refined-gae
Science Score: 54.0%
This score indicates how likely this project is to be science-related, based on the following indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references: not found
- ✓ Academic publication links: links to arxiv.org
- ○ Academic email domains: not found
- ○ Institutional organization owner: not found
- ○ JOSS paper metadata: not found
- ○ Scientific vocabulary similarity: low similarity (4.9%) to scientific vocabulary
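For intuition only: a score like this is typically a weighted combination of the indicators above. Below is a minimal sketch assuming made-up weights and indicator names; the actual scorer's implementation is not part of this repository.

```python
# Hypothetical sketch: the real scorer's weights and indicator names are unknown.
INDICATOR_WEIGHTS = {
    "citation_cff": 0.15,
    "codemeta_json": 0.10,
    "zenodo_json": 0.10,
    "doi_references": 0.10,
    "publication_links": 0.15,
    "academic_emails": 0.10,
    "institutional_owner": 0.10,
    "joss_metadata": 0.10,
    "vocab_similarity": 0.10,  # graded 0..1 instead of boolean
}

def science_score(indicators):
    """Weighted sum of indicator values (each in 0..1), as a percentage."""
    total = sum(INDICATOR_WEIGHTS[k] * indicators.get(k, 0.0) for k in INDICATOR_WEIGHTS)
    return 100.0 * total / sum(INDICATOR_WEIGHTS.values())

# Indicators detected for this repository
print(science_score({
    "citation_cff": 1, "codemeta_json": 1, "zenodo_json": 1,
    "publication_links": 1, "vocab_similarity": 0.049,
}))
```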
Repository
Basic Info
- Host: GitHub
- Owner: yiming421
- Language: Python
- Default Branch: main
- Size: 152 KB
Statistics
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
ReadMe.md
Update: We are thrilled to announce that our paper, "Reconsidering the Performance of GAE in Link Prediction," has been accepted at the 34th ACM International Conference on Information and Knowledge Management (CIKM 2025)!
You can read the pre-print on arXiv:
https://arxiv.org/abs/2411.03845
We achieve performance comparable to or better than that of recent models on the OGB benchmark datasets, including ogbl-ddi, ogbl-collab, ogbl-ppa, and ogbl-citation2:
| Model | Cora (Hits@100) | Citeseer (Hits@100) | Pubmed (Hits@100) | Collab (Hits@50) | PPA (Hits@100) | Citation2 (MRR) | DDI (Hits@20) |
|---------------|---------------|---------------|---------------|---------------|---------------|---------------|---------------|
| CN | $33.92 \pm 0.46$ | $29.79 \pm 0.90$ | $23.13 \pm 0.15$ | $56.44 \pm 0.00$ | $27.65 \pm 0.00$ | $51.47 \pm 0.00$ | $17.73 \pm 0.00$ |
| AA | $39.85 \pm 1.34$ | $35.19 \pm 1.33$ | $27.38 \pm 0.11$ | $64.35 \pm 0.00$ | $32.45 \pm 0.00$ | $51.89 \pm 0.00$ | $18.61 \pm 0.00$ |
| RA | $41.07 \pm 0.48$ | $33.56 \pm 0.17$ | $27.03 \pm 0.35$ | $64.00 \pm 0.00$ | $49.33 \pm 0.00$ | $51.98 \pm 0.00$ | $27.60 \pm 0.00$ |
| SEAL | $81.71 \pm 1.30$ | $83.89 \pm 2.15$ | $75.54 \pm 1.32$ | $64.74 \pm 0.43$ | $48.80 \pm 3.16$ | $87.67 \pm 0.32$ | $30.56 \pm 3.86$ |
| NBFNet | $71.65 \pm 2.27$ | $74.07 \pm 1.75$ | $58.73 \pm 1.99$ | OOM | OOM | OOM | $4.00 \pm 0.58$ |
| Neo-GNN | $80.42 \pm 1.31$ | $84.67 \pm 2.16$ | $73.93 \pm 1.19$ | $57.52 \pm 0.37$ | $49.13 \pm 0.60$ | $87.26 \pm 0.84$ | $63.57 \pm 3.52$ |
| BUDDY | $88.00 \pm 0.44$ | $\mathbf{92.93 \pm 0.27}$ | $74.10 \pm 0.78$ | $65.94 \pm 0.58$ | $49.85 \pm 0.20$ | $87.56 \pm 0.11$ | $78.51 \pm 1.36$ |
| NCN | $\mathbf{89.05 \pm 0.96}$ | $91.56 \pm 1.43$ | $\underline{79.05 \pm 1.16}$ | $64.76 \pm 0.87$ | $61.19 \pm 0.85$ | $88.09 \pm 0.06$ | $\underline{82.32 \pm 6.10}$ |
| MPLP+ | - | - | - | $\mathbf{66.99 \pm 0.40}$ | $\underline{65.24 \pm 1.50}$ | $\mathbf{90.72 \pm 0.12}$ | - |
| GAE(GCN) | $66.79 \pm 1.65$ | $67.08 \pm 2.94$ | $53.02 \pm 1.39$ | $47.14 \pm 1.45$ | $18.67 \pm 1.32$ | $84.74 \pm 0.21$ | $37.07 \pm 5.07$ |
| GAE(SAGE) | $55.02 \pm 4.03$ | $57.01 \pm 3.74$ | $39.66 \pm 0.72$ | $54.63 \pm 1.12$ | $16.55 \pm 2.40$ | $82.60 \pm 0.36$ | $53.90 \pm 4.74$ |
| Optimized-GAE | $\underline{88.17 \pm 0.93}$ | $\underline{92.40 \pm 1.23}$ | $\mathbf{80.09 \pm 1.72}$ | $\underline{66.11 \pm 0.35}$ | $\mathbf{78.41 \pm 0.83}$ | $\underline{88.74 \pm 0.06}$ | $\mathbf{94.43 \pm 0.57}$ |

(Bold marks the best result in each column; underline marks the second best.)
The code is based on the DGL and OGB libraries. To run it, first create the conda environment specified in the env.yaml file:

```bash
conda env create -f env.yaml
```
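After the environment is created and activated (its name is defined inside env.yaml, so it is not reproduced here), a quick sanity check confirms that the core dependencies import and CUDA is visible:

```python
# Sanity check: run inside the freshly created conda environment.
import dgl
import ogb
import torch

print("dgl", dgl.__version__, "| ogb", ogb.__version__, "| torch", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```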
The following commands reproduce the results on the three Planetoid datasets (Cora, CiteSeer, PubMed):

```bash
python train_w_feat_small.py --dataset Cora --activation silu --batch_size 2048 --dropout 0.6 --hidden 1024 --lr 0.005 --maskinput --mlp_layers 4 --res --norm --num_neg 3 --optimizer adamw --prop_step 4 --model LightGCN
python train_w_feat_small.py --dataset CiteSeer --activation relu --batch_size 4096 --dropout 0.6 --hidden 1024 --lr 0.001 --maskinput --norm --prop_step 4 --num_neg 1
python train_w_feat_small.py --dataset PubMed --activation gelu --batch_size 4096 --dropout 0.4 --exp --hidden 512 --lr 0.001 --maskinput --mlp_layers 2 --norm --num_neg 3 --prop_step 2 --model LightGCN
```
Below we give the commands to run the code on the four datasets in the OGB benchmark.
```bash
python train_wo_feat.py --dataset ogbl-ddi --lr 0.001 --hidden 1024 --batch_size 8192 --dropout 0.6 --num_neg 1 --epochs 500 --prop_step 2 --metric hits@20 --residual 0.1 --maskinput --mlp_layers 8 --mlp_res --emb_dim 1024
python collab.py --dataset ogbl-collab --lr 0.0004 --emb_hidden 0 --hidden 1024 --batch_size 16384 --dropout 0.2 --num_neg 3 --epoch 500 --prop_step 4 --metric hits@50 --mlp_layers 5 --res --norm --dp4norm 0.2 --scale
python train_wo_feat.py --dataset ogbl-ppa --lr 0.001 --hidden 512 --batch_size 65536 --dropout 0.2 --num_neg 3 --epoch 800 --prop_step 2 --metric hits@100 --residual 0.1 --mlp_layers 5 --mlp_res --emb_dim 512
python citation.py --dataset ogbl-citation2 --lr 0.0003 --clip_norm 1 --emb_hidden 256 --hidden 256 --batch_size 65536 --dropout 0.2 --num_neg 3 --epochs 200 --prop_step 3 --metric MRR --norm --dp4norm 0.2 --mlp_layers 5
```
For the ogbl-citation2 dataset, you need a GPU with at least 40 GB of memory.
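Before launching the ogbl-citation2 command, a minimal pre-flight sketch (assuming CUDA device 0 is the one used) can verify that the GPU meets this requirement:

```python
# Pre-flight check: verify the visible GPU has roughly 40 GB (device 0 assumed).
import torch

assert torch.cuda.is_available(), "A CUDA GPU is required"
props = torch.cuda.get_device_properties(0)
total_gb = props.total_memory / 1024**3
print(f"{props.name}: {total_gb:.1f} GB")
if total_gb < 40:
    print("Warning: ogbl-citation2 training expects at least 40 GB of GPU memory.")
```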
Owner
- Login: yiming421
- Kind: user
- Repositories: 1
- Profile: https://github.com/yiming421
Citation (citation.py)
import itertools
import os
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import psutil
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from dgl.dataloading.negative_sampler import GlobalUniform
from torch.utils.data import DataLoader
import tqdm
import argparse
from loss import auc_loss, hinge_auc_loss, log_rank_loss
from model import Hadamard_MLPPredictor, GCN_with_feature, DotPredictor, LightGCN
from unified_model import UnifiedGNN
import time
import wandb
def log_memory_usage(stage, device, verbose=True):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(device) / 1024**3
        cached = torch.cuda.memory_reserved(device) / 1024**3
        if verbose:
            print(f"[{stage}] GPU Memory - Allocated: {allocated:.3f}GB, Cached: {cached:.3f}GB")
        return allocated, cached
    process = psutil.Process(os.getpid())
    ram_usage = process.memory_info().rss / 1024**3
    if verbose:
        print(f"[{stage}] RAM Usage: {ram_usage:.3f}GB")
    return ram_usage

def get_memory_stats(device):
    """Get current memory statistics without logging."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(device) / 1024**3
        cached = torch.cuda.memory_reserved(device) / 1024**3
        return allocated, cached
    return 0, 0

def clear_gpu_cache(stage="", device=0):
    """Clear the GPU cache and log how much cached memory was freed."""
    if torch.cuda.is_available():
        cached_before = torch.cuda.memory_reserved(device) / 1024**3
        torch.cuda.empty_cache()
        cached_after = torch.cuda.memory_reserved(device) / 1024**3
        cache_freed = cached_before - cached_after
        if cache_freed > 0.01:  # only log if a significant amount was freed
            print(f"[CACHE CLEAR{' ' + stage if stage else ''}] Freed {cache_freed:.3f}GB cached memory ({cached_before:.3f}GB → {cached_after:.3f}GB)")
        return cache_freed
    return 0

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default='ogbl-citation2', choices=['ogbl-citation2'], type=str)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--prop_step", default=8, type=int)
    parser.add_argument("--emb_hidden", default=64, type=int)
    parser.add_argument("--hidden", default=64, type=int)
    parser.add_argument("--batch_size", default=8192, type=int)
    parser.add_argument("--dropout", default=0.05, type=float)
    parser.add_argument("--num_neg", default=1, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--interval", default=100, type=int)
    parser.add_argument("--step_lr_decay", type=str2bool, default=True)
    parser.add_argument("--metric", default='mrr', type=str)
    parser.add_argument("--gpu", default=0, type=int)
    parser.add_argument("--relu", type=str2bool, default=False)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--model", default='UnifiedGNN', choices=['GCN', 'GCN_with_MLP', 'GCN_no_para', 'LightGCN', 'UnifiedGNN'], type=str)
    parser.add_argument("--maskinput", type=str2bool, default=False)
    parser.add_argument("--norm", type=str2bool, default=False)
    parser.add_argument("--dp4norm", default=0, type=float)
    parser.add_argument("--dpe", default=0, type=float)
    parser.add_argument("--drop_edge", type=str2bool, default=False)
    parser.add_argument("--loss", default='bce', choices=['bce', 'auc', 'hauc', 'rank'], type=str)
    parser.add_argument("--residual", default=0, type=float)
    parser.add_argument("--mlp_layers", default=2, type=int)
    parser.add_argument("--pred", default='mlp', type=str)
    parser.add_argument("--res", type=str2bool, default=False)
    parser.add_argument("--conv", default='GCN', type=str)
    parser.add_argument('--alpha', default=0.5, type=float)
    parser.add_argument('--exp', type=str2bool, default=False)
    parser.add_argument('--scale', type=str2bool, default=False)
    parser.add_argument('--linear', type=str2bool, default=False)
    parser.add_argument('--clip_norm', default=1.0, type=float)
    # UnifiedGNN-specific parameters
    parser.add_argument('--unified_model_type', default='gcn', choices=['gcn', 'lightgcn'], type=str)
    parser.add_argument('--multilayer', type=str2bool, default=False)
    parser.add_argument('--no_parameters', type=str2bool, default=False)
    parser.add_argument('--input_norm', type=str2bool, default=False)
    parser.add_argument('--gin_aggr', default='sum', choices=['sum', 'mean', 'max'], type=str)
    # Advanced optimization
    parser.add_argument('--weight_decay', default=0.0, type=float)
    parser.add_argument('--scheduler', default='none', choices=['none', 'cosine', 'step'], type=str)
    parser.add_argument('--optimizer', default='adam', choices=['adam', 'adamw'], type=str)
    parser.add_argument('--use_amp', type=str2bool, default=False, help='Enable Automatic Mixed Precision training')
    parser.add_argument('--emb_only', type=str2bool, default=False, help='Use only learnable embeddings without original features')
    parser.add_argument('--mlp_res', type=str2bool, default=False, help='Use residual connections in MLP predictor')
    args = parser.parse_args()
    return args

args = parse()

# Force norm=True when gin_aggr='sum'
if args.gin_aggr == 'sum':
    args.norm = True
    print("[INFO] Forcing norm=True because gin_aggr='sum'")

print(args)
wandb.init(project='Refined-GAE', config=args)

torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
dgl.seed(args.seed)
def eval_hits(y_pred_pos, y_pred_neg, K):
    '''
    Compute Hits@K.
    For each positive target node, the negative target nodes are the same.
    y_pred_neg is an array; rank y_pred_pos[i] against y_pred_neg for each i.
    '''
    if len(y_pred_neg) < K:
        return {'hits@{}'.format(K): 1.}
    kth_score_in_negative_edges = torch.topk(y_pred_neg, K)[0][-1]
    hitsK = float(torch.sum(y_pred_pos > kth_score_in_negative_edges).cpu()) / len(y_pred_pos)
    return {'hits@{}'.format(K): hitsK}

def adjustlr(optimizer, decay_ratio, lr):
    lr_ = lr * max(1 - decay_ratio, 0.0001)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_
def train(model, g, train_pos_edge, optimizer, neg_sampler, pred, emb=None, scaler=None):
    model.train()
    pred.train()
    train_start_time = time.time()
    log_memory_usage("train_start", args.gpu)
    dataloader = DataLoader(range(train_pos_edge.size(0)), args.batch_size, shuffle=True)
    total_loss = 0
    peak_memory = 0
    if args.maskinput:
        mask = torch.ones(train_pos_edge.size(0), dtype=torch.bool)
    for batch_idx, edge_index in enumerate(tqdm.tqdm(dataloader)):
        if args.emb_only:
            xemb = emb.weight if emb is not None else g.ndata['feat']
        else:
            xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        # Forward pass with AMP autocast
        with autocast(enabled=(scaler is not None)):
            if args.maskinput:
                # Hide the current batch of positive edges from message passing
                mask[edge_index] = 0
                tei = train_pos_edge[mask]
                src, dst = tei.t()
                re_tei = torch.stack((dst, src), dim=0).t()
                tei = torch.cat((tei, re_tei), dim=0)
                g_mask = dgl.graph((tei[:, 0], tei[:, 1]), num_nodes=g.num_nodes())
                g_mask = dgl.add_self_loop(g_mask)
                h = model(g_mask, xemb)
                mask[edge_index] = 1
            else:
                h = model(g, xemb)
            pos_edge = train_pos_edge[edge_index]
            neg_train_edge = neg_sampler(g, pos_edge.t()[0])
            neg_train_edge = torch.stack(neg_train_edge, dim=0)
            neg_train_edge = neg_train_edge.t()
            neg_edge = neg_train_edge
            pos_score = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            neg_score = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
        pos_score = pos_score.float()
        neg_score = neg_score.float()
        # Loss computation (outside autocast for stability)
        if args.loss == 'auc':
            loss = auc_loss(pos_score, neg_score, args.num_neg)
        elif args.loss == 'hauc':
            loss = hinge_auc_loss(pos_score, neg_score, args.num_neg)
        elif args.loss == 'rank':
            loss = log_rank_loss(pos_score, neg_score, args.num_neg)
        else:
            loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score)) + F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))
        # Check for NaN and fail fast
        if torch.isnan(loss).any():
            print(f"[ERROR] NaN detected in loss at batch {batch_idx}!")
            print(f"pos_score stats: min={pos_score.min():.6f}, max={pos_score.max():.6f}, mean={pos_score.mean():.6f}")
            print(f"neg_score stats: min={neg_score.min():.6f}, max={neg_score.max():.6f}, mean={neg_score.mean():.6f}")
            raise ValueError("NaN detected in loss computation!")
        optimizer.zero_grad()
        # Backward pass with AMP scaling
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        # Gradient clipping (unscale first when using AMP)
        if scaler is not None:
            scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
        torch.nn.utils.clip_grad_norm_(pred.parameters(), args.clip_norm)
        if emb is not None:
            torch.nn.utils.clip_grad_norm_(emb.parameters(), args.clip_norm)
        # Optimizer step with AMP scaling
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        total_loss += loss.item()
        # Track peak memory occasionally
        if batch_idx % 50 == 0:
            current_memory = get_memory_stats(args.gpu)[0]
            peak_memory = max(peak_memory, current_memory)
    train_end_time = time.time()
    train_duration = train_end_time - train_start_time
    clear_gpu_cache("after_training", args.gpu)
    log_memory_usage("train_end", args.gpu)
    print(f"[TIMING] Training epoch: {train_duration:.3f}s ({len(dataloader) / train_duration:.1f} batches/s)")
    print(f"[MEMORY] Peak GPU: {peak_memory:.3f}GB")
    if scaler is not None:
        print(f"[AMP] Grad scale: {scaler.get_scale():.0f}")
    return total_loss / len(dataloader)
def test(model, g, pos_test_edge, neg_test_edge, evaluator, pred, emb=None, scaler=None):
    model.eval()
    pred.eval()
    with torch.no_grad():
        if args.emb_only:
            xemb = emb.weight if emb is not None else g.ndata['feat']
        else:
            xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        with autocast(enabled=(scaler is not None)):
            h = model(g, xemb)
        dataloader = DataLoader(range(pos_test_edge.size(0)), args.batch_size)
        pos_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_test_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        dataloader = DataLoader(range(neg_test_edge.size(0)), args.batch_size)
        neg_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            neg_edge = neg_test_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                neg_pred = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
            neg_score.append(neg_pred)
        neg_score = torch.cat(neg_score, dim=0)
        # Each positive edge is ranked against its own 1000 negative targets
        neg_score_multi = neg_score.view(-1, 1000)
        results = {}
        for k in [20, 50, 100]:
            results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
    clear_gpu_cache("after_test", args.gpu)
    return results
def eval(model, g, pos_train_edge, pos_valid_edge, neg_valid_edge, evaluator, pred, emb=None, scaler=None):
    model.eval()
    pred.eval()
    with torch.no_grad():
        if args.emb_only:
            xemb = emb.weight if emb is not None else g.ndata['feat']
        else:
            xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        with autocast(enabled=(scaler is not None)):
            h = model(g, xemb)
        dataloader = DataLoader(range(pos_valid_edge.size(0)), args.batch_size)
        pos_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_valid_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        dataloader = DataLoader(range(neg_valid_edge.size(0)), args.batch_size)
        neg_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            neg_edge = neg_valid_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                neg_pred = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
            neg_score.append(neg_pred)
        neg_score = torch.cat(neg_score, dim=0)
        neg_score_multi = neg_score.view(-1, 1000)
        valid_results = {}
        for k in [20, 50, 100]:
            valid_results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        valid_results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
        # Also score a subset of training edges (same count as the validation
        # split) against the validation negatives, to monitor overfitting.
        pos_score = []
        dataloader = DataLoader(range(pos_valid_edge.size(0)), args.batch_size)
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_train_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        train_results = {}
        for k in [20, 50, 100]:
            train_results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        train_results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
    clear_gpu_cache("after_eval", args.gpu)
    return valid_results, train_results
# Load the dataset
dataset = DglLinkPropPredDataset(name=args.dataset)
split_edge = dataset.get_edge_split()
device = torch.device('cuda', args.gpu) if torch.cuda.is_available() else torch.device('cpu')
graph = dataset[0]
graph = dgl.add_self_loop(graph)
graph = dgl.to_bidirected(graph, copy_ndata=True)
graph = graph.to(device)

for name in ['train', 'valid', 'test']:
    u = split_edge[name]["source_node"]
    v = split_edge[name]["target_node"]
    split_edge[name]['edge'] = torch.stack((u, v), dim=0).t()
for name in ['valid', 'test']:
    # repeat_interleave pairs each source with its own 1000 negative targets;
    # a plain repeat would tile the sources out of alignment with target_node_neg
    u = split_edge[name]["source_node"].repeat_interleave(1000)
    v = split_edge[name]["target_node_neg"].view(-1)
    split_edge[name]['edge_neg'] = torch.stack((u, v), dim=0).t()

train_pos_edge = split_edge['train']['edge'].to(device)
valid_pos_edge = split_edge['valid']['edge'].to(device)
valid_neg_edge = split_edge['valid']['edge_neg'].to(device)
test_pos_edge = split_edge['test']['edge'].to(device)
test_neg_edge = split_edge['test']['edge_neg'].to(device)

if args.emb_hidden > 0:
    embedding = torch.nn.Embedding(graph.num_nodes(), args.emb_hidden).to(device)
    torch.nn.init.orthogonal_(embedding.weight)
else:
    embedding = None

# Create negative samples for training
neg_sampler = GlobalUniform(args.num_neg)
if args.pred == 'dot':
    pred = DotPredictor().to(device)
elif args.pred == 'mlp':
    pred = Hadamard_MLPPredictor(args.hidden, args.dropout, args.mlp_layers, args.mlp_res, args.norm, args.scale).to(device)
else:
    raise NotImplementedError

if args.emb_only:
    input_dim = args.emb_hidden
else:
    input_dim = graph.ndata['feat'].shape[1] + args.emb_hidden

if args.model == 'GCN':
    model = GCN_with_feature(input_dim, args.hidden, args.norm, args.dp4norm, args.prop_step, args.dropout, args.residual, args.relu, args.linear, args.conv).to(device)
elif args.model == 'LightGCN':
    model = LightGCN(input_dim, args.hidden, args.prop_step, args.dropout, args.alpha, args.exp, args.relu, args.norm, args.conv).to(device)
elif args.model == 'UnifiedGNN':
    model = UnifiedGNN(
        model_type=args.unified_model_type,
        in_feats=input_dim,
        h_feats=args.hidden,
        prop_step=args.prop_step,
        conv=args.conv,
        multilayer=args.multilayer,
        norm=args.norm,
        relu=args.relu,
        dropout=args.dropout,
        residual=args.residual,
        linear=args.linear,
        alpha=args.alpha,
        exp=args.exp,
        res=args.res,
        gin_aggr=args.gin_aggr,
        no_parameters=args.no_parameters,
        input_norm=args.input_norm,
        supports_edge_weight=True,
    ).to(device)
else:
    raise NotImplementedError

parameter = itertools.chain(model.parameters(), pred.parameters())
if args.emb_hidden > 0:
    parameter = itertools.chain(parameter, embedding.parameters())

# --- Optimizer ---
if args.optimizer == 'adam':
    optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer == 'adamw':
    optimizer = torch.optim.AdamW(parameter, lr=args.lr, weight_decay=args.weight_decay)
else:
    raise NotImplementedError

# --- Learning Rate Scheduler ---
if args.scheduler == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)
elif args.scheduler == 'step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs // 3, gamma=0.5)
else:
    scheduler = None

# Initialize the AMP scaler if mixed precision is enabled
scaler = GradScaler() if args.use_amp else None

evaluator = Evaluator(name=args.dataset)
best_val = 0
final_test_result = None
best_epoch = 0
losses = []
valid_list = []
test_list = []

if embedding is not None:
    print(f'number of parameters: {sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in pred.parameters()) + sum(p.numel() for p in embedding.parameters())}')
else:
    print(f'number of parameters: {sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in pred.parameters())}')
for epoch in range(args.epochs):
    loss = train(model, graph, train_pos_edge, optimizer, neg_sampler, pred, embedding, scaler)
    losses.append(loss)
    if epoch % args.interval == 0 and args.step_lr_decay:
        adjustlr(optimizer, epoch / args.epochs, args.lr)
    valid_results, train_results = eval(model, graph, train_pos_edge, valid_pos_edge, valid_neg_edge, evaluator, pred, embedding, scaler)
    valid_list.append(valid_results[args.metric])
    for k, v in valid_results.items():
        print(f"Validation {k}: {v:.4f}")
    for k, v in train_results.items():
        print(f"Train {k}: {v:.4f}")
    test_results = test(model, graph, test_pos_edge, test_neg_edge, evaluator, pred, embedding, scaler)
    test_list.append(test_results[args.metric])
    for k, v in test_results.items():
        print(f"Test {k}: {v:.4f}")
    if valid_results[args.metric] > best_val:
        best_val = valid_results[args.metric]
        best_epoch = epoch
        final_test_result = test_results
    # Learning rate scheduling
    if scheduler is not None:
        scheduler.step()
    # Early stopping: quit once validation has not improved for 100 epochs
    if epoch - best_epoch >= 100:
        break
    print(f"Epoch {epoch}, Loss: {loss:.4f}, Train hit: {train_results[args.metric]:.4f}, Valid hit: {valid_results[args.metric]:.4f}, Test hit: {test_results[args.metric]:.4f}")
    wandb.log({'loss': loss, 'train_hit': train_results[args.metric], 'valid_hit': valid_results[args.metric], 'test_hit': test_results[args.metric], 'lr': optimizer.param_groups[0]['lr']})

print(f"Test hit: {final_test_result[args.metric]:.4f}")
wandb.log({'best_test': final_test_result[args.metric]})
GitHub Events
Total
- Push event: 2
- Create event: 1
Last Year
- Push event: 2
- Create event: 1