refined-gae
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file (found CITATION.cff file)
- ✓ codemeta.json file (found codemeta.json file)
- ✓ .zenodo.json file (found .zenodo.json file)
- ○ DOI references
- ✓ Academic publication links (links to: arxiv.org)
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity, 4.9%, to scientific vocabulary)
Repository
Basic Info
- Host: GitHub
- Owner: GraphPKU
- Language: Python
- Default Branch: main
- Size: 156 KB
Statistics
- Stars: 6
- Watchers: 1
- Forks: 3
- Open Issues: 0
- Releases: 0
Metadata Files
ReadMe.md
Update: We are thrilled to announce that our paper, "Reconsidering the Performance of GAE in Link Prediction," has been accepted at the 34th ACM International Conference on Information and Knowledge Management (CIKM 2025)!
You can read the pre-print on arXiv:
https://arxiv.org/abs/2411.03845
We achieve comparable or better performance than recent models on standard link-prediction benchmarks, including the OGB datasets ogbl-ddi, ogbl-collab, ogbl-ppa, and ogbl-citation2:
| Metric | Cora | Citeseer | Pubmed | Collab | PPA | Citation2 | DDI |
|---------------|---------------|---------------|---------------|---------------|---------------|---------------|---------------|
| | Hits@100 | Hits@100 | Hits@100 | Hits@50 | Hits@100 | MRR | Hits@20 |
| CN | $33.92 \pm 0.46$ | $29.79 \pm 0.90$ | $23.13 \pm 0.15$ | $56.44 \pm 0.00$ | $27.65 \pm 0.00$ | $51.47 \pm 0.00$ | $17.73 \pm 0.00$ |
| AA | $39.85 \pm 1.34$ | $35.19 \pm 1.33$ | $27.38 \pm 0.11$ | $64.35 \pm 0.00$ | $32.45 \pm 0.00$ | $51.89 \pm 0.00$ | $18.61 \pm 0.00$ |
| RA | $41.07 \pm 0.48$ | $33.56 \pm 0.17$ | $27.03 \pm 0.35$ | $64.00 \pm 0.00$ | $49.33 \pm 0.00$ | $51.98 \pm 0.00$ | $27.60 \pm 0.00$ |
| SEAL | $81.71 \pm 1.30$ | $83.89 \pm 2.15$ | $75.54 \pm 1.32$ | $64.74 \pm 0.43$ | $48.80 \pm 3.16$ | $87.67 \pm 0.32$ | $30.56 \pm 3.86$ |
| NBFNet | $71.65 \pm 2.27$ | $74.07 \pm 1.75$ | $58.73 \pm 1.99$ | OOM | OOM | OOM | $4.00 \pm 0.58$ |
| Neo-GNN | $80.42 \pm 1.31$ | $84.67 \pm 2.16$ | $73.93 \pm 1.19$ | $57.52 \pm 0.37$ | $49.13 \pm 0.60$ | $87.26 \pm 0.84$ | $63.57 \pm 3.52$ |
| BUDDY | $88.00 \pm 0.44$ | $\mathbf{92.93 \pm 0.27}$ | $74.10 \pm 0.78$ | $65.94 \pm 0.58$ | $49.85 \pm 0.20$ | $87.56 \pm 0.11$ | $78.51 \pm 1.36$ |
| NCN | $\mathbf{89.05 \pm 0.96}$ | $91.56 \pm 1.43$ | $\underline{79.05 \pm 1.16}$ | $64.76 \pm 0.87$ | $61.19 \pm 0.85$ | $88.09 \pm 0.06$ | $\underline{82.32 \pm 6.10}$ |
| MPLP+ | - | - | - | $\mathbf{66.99 \pm 0.40}$ | $\underline{65.24 \pm 1.50}$ | $\mathbf{90.72 \pm 0.12}$ | - |
| GAE(GCN) | $66.79 \pm 1.65$ | $67.08 \pm 2.94$ | $53.02 \pm 1.39$ | $47.14 \pm 1.45$ | $18.67 \pm 1.32$ | $84.74 \pm 0.21$ | $37.07 \pm 5.07$ |
| GAE(SAGE) | $55.02 \pm 4.03$ | $57.01 \pm 3.74$ | $39.66 \pm 0.72$ | $54.63 \pm 1.12$ | $16.55 \pm 2.40$ | $82.60 \pm 0.36$ | $53.90 \pm 4.74$ |
| Optimized-GAE | $\underline{88.17 \pm 0.93}$ | $\underline{92.40 \pm 1.23}$ | $\mathbf{80.09 \pm 1.72}$ | $\underline{66.11 \pm 0.35}$ | $\mathbf{78.41 \pm 0.83}$ | $\underline{88.74 \pm 0.06}$ | $\mathbf{94.43 \pm 0.57}$ |
The code is based on the DGL library and the OGB library. To run the code, first set up the environment specified in the env.yaml file:

```shell
conda env create -f env.yaml
```

Below we give the commands to reproduce the results on Cora, CiteSeer, and PubMed:

```shell
python train_w_feat_small.py --dataset Cora --activation silu --batch_size 2048 --dropout 0.6 --hidden 1024 --lr 0.005 --maskinput --mlp_layers 4 --res --norm --num_neg 3 --optimizer adamw --prop_step 4 --model LightGCN
python train_w_feat_small.py --dataset CiteSeer --activation relu --batch_size 4096 --dropout 0.6 --hidden 1024 --lr 0.001 --maskinput --norm --prop_step 4 --num_neg 1
python train_w_feat_small.py --dataset PubMed --activation gelu --batch_size 4096 --dropout 0.4 --exp --hidden 512 --lr 0.001 --maskinput --mlp_layers 2 --norm --num_neg 3 --prop_step 2 --model LightGCN
```
Below we give the commands to run the code on the four datasets in the OGB benchmark:

```shell
python train_wo_feat.py --dataset ogbl-ddi --lr 0.001 --hidden 1024 --batch_size 8192 --dropout 0.6 --num_neg 1 --epochs 500 --prop_step 2 --metric hits@20 --residual 0.1 --maskinput --mlp_layers 8 --mlp_res --emb_dim 1024
python collab.py --dataset ogbl-collab --lr 0.0004 --emb_hidden 0 --hidden 1024 --batch_size 16384 --dropout 0.2 --num_neg 3 --epoch 500 --prop_step 4 --metric hits@50 --mlp_layers 5 --res --norm --dp4norm 0.2 --scale
python train_wo_feat.py --dataset ogbl-ppa --lr 0.001 --hidden 512 --batch_size 65536 --dropout 0.2 --num_neg 3 --epoch 800 --prop_step 2 --metric hits@100 --residual 0.1 --mlp_layers 5 --mlp_res --emb_dim 512
python citation.py --dataset ogbl-citation2 --lr 0.0003 --clip_norm 1 --emb_hidden 256 --hidden 256 --batch_size 65536 --dropout 0.2 --num_neg 3 --epochs 200 --prop_step 3 --metric MRR --norm --dp4norm 0.2 --mlp_layers 5
```

For the ogbl-citation2 dataset, you need a GPU with at least 40 GB of memory.
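All of the commands above score candidate edges with the repository's Hadamard_MLPPredictor: the two node embeddings are combined with an elementwise (Hadamard) product and passed through an MLP to produce a logit. A minimal, framework-free sketch of that idea (the layer sizes and initialization here are illustrative, not the repository's exact implementation):

```python
import random

def hadamard_mlp_score(h_u, h_v, w1, b1, w2, b2):
    """Score edge (u, v): Hadamard product of the node embeddings,
    one hidden ReLU layer, then a scalar logit."""
    x = [a * b for a, b in zip(h_u, h_v)]  # elementwise (Hadamard) product
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]   # ReLU(W1 @ x + b1)
    return sum(w * h for w, h in zip(w2, hidden)) + b2  # scalar logit

random.seed(0)
dim, hid = 4, 8  # illustrative sizes, not the paper's hyperparameters
w1 = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(hid)]
b1 = [0.0] * hid
w2 = [random.gauss(0, 1) for _ in range(hid)]
h_u = [random.gauss(0, 1) for _ in range(dim)]
h_v = [random.gauss(0, 1) for _ in range(dim)]
score = hadamard_mlp_score(h_u, h_v, w1, b1, w2, 0.0)
print(f"edge logit: {score:.4f}")
```

In citation.py this corresponds to `pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])`, where `h` holds the GCN node embeddings.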
Owner
- Name: GraphPKU
- Login: GraphPKU
- Kind: organization
- Repositories: 3
- Profile: https://github.com/GraphPKU
Citation (citation.py)
import itertools
import os
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from dgl.dataloading.negative_sampler import GlobalUniform
from torch.utils.data import DataLoader
import tqdm
import argparse
from loss import auc_loss, hinge_auc_loss, log_rank_loss
from model import Hadamard_MLPPredictor, GCN_with_feature, DotPredictor, LightGCN
import time
import wandb
def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default='ogbl-citation2', choices=['ogbl-citation2'], type=str)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--prop_step", default=8, type=int)
    parser.add_argument("--emb_hidden", default=64, type=int)
    parser.add_argument("--hidden", default=64, type=int)
    parser.add_argument("--batch_size", default=8192, type=int)
    parser.add_argument("--dropout", default=0.05, type=float)
    parser.add_argument("--num_neg", default=1, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--interval", default=100, type=int)
    parser.add_argument("--step_lr_decay", action='store_true', default=True)
    parser.add_argument("--metric", default='hits@20', type=str)
    parser.add_argument("--gpu", default=0, type=int)
    parser.add_argument("--relu", action='store_true', default=False)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--model", default='GCN', choices=['GCN', 'GCN_with_MLP', 'GCN_no_para', 'LightGCN'], type=str)
    parser.add_argument("--maskinput", action='store_true', default=False)
    parser.add_argument("--norm", action='store_true', default=False)
    parser.add_argument("--dp4norm", default=0, type=float)
    parser.add_argument("--dpe", default=0, type=float)
    parser.add_argument("--drop_edge", action='store_true', default=False)
    parser.add_argument("--loss", default='bce', choices=['bce', 'auc', 'hauc', 'rank'], type=str)
    parser.add_argument("--residual", default=0, type=float)
    parser.add_argument("--mlp_layers", default=2, type=int)
    parser.add_argument("--pred", default='mlp', type=str)
    parser.add_argument("--res", action='store_true', default=False)
    parser.add_argument("--conv", default='GCN', type=str)
    parser.add_argument('--alpha', default=0.5, type=float)
    parser.add_argument('--exp', action='store_true', default=False)
    parser.add_argument('--scale', action='store_true', default=False)
    parser.add_argument('--linear', action='store_true', default=False)
    parser.add_argument('--clip_norm', default=1.0, type=float)
    args = parser.parse_args()
    return args
args = parse()
print(args)
wandb.init(project='Refined-GAE', config=args)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
dgl.seed(args.seed)
def eval_hits(y_pred_pos, y_pred_neg, K):
    '''
    compute Hits@K
    For each positive target node, the negative target nodes are the same.
    y_pred_neg is an array.
    rank y_pred_pos[i] against y_pred_neg for each i
    '''
    if len(y_pred_neg) < K:
        return {'hits@{}'.format(K): 1.}
    kth_score_in_negative_edges = torch.topk(y_pred_neg, K)[0][-1]
    hitsK = float(torch.sum(y_pred_pos > kth_score_in_negative_edges).cpu()) / len(y_pred_pos)
    return {'hits@{}'.format(K): hitsK}
def adjustlr(optimizer, decay_ratio, lr):
    lr_ = lr * max(1 - decay_ratio, 0.0001)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_
def train(model, g, train_pos_edge, optimizer, neg_sampler, pred, emb=None):
    model.train()
    pred.train()
    dataloader = DataLoader(range(train_pos_edge.size(0)), args.batch_size, shuffle=True)
    total_loss = 0
    if args.maskinput:
        mask = torch.ones(train_pos_edge.size(0), dtype=torch.bool)
    for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
        xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        if args.maskinput:
            # Remove the batch's positive edges from the message-passing
            # graph so the model cannot see the edges it is asked to predict.
            mask[edge_index] = 0
            tei = train_pos_edge[mask]
            src, dst = tei.t()
            re_tei = torch.stack((dst, src), dim=0).t()
            tei = torch.cat((tei, re_tei), dim=0)
            g_mask = dgl.graph((tei[:, 0], tei[:, 1]), num_nodes=g.num_nodes())
            g_mask = dgl.add_self_loop(g_mask)
            h = model(g_mask, xemb)
            mask[edge_index] = 1
        else:
            h = model(g, xemb)
        pos_edge = train_pos_edge[edge_index]
        neg_train_edge = neg_sampler(g, pos_edge.t()[0])
        neg_train_edge = torch.stack(neg_train_edge, dim=0)
        neg_train_edge = neg_train_edge.t()
        neg_edge = neg_train_edge
        pos_score = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
        neg_score = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
        if args.loss == 'auc':
            loss = auc_loss(pos_score, neg_score, args.num_neg)
        elif args.loss == 'hauc':
            loss = hinge_auc_loss(pos_score, neg_score, args.num_neg)
        elif args.loss == 'rank':
            loss = log_rank_loss(pos_score, neg_score, args.num_neg)
        else:
            loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score)) + F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
        torch.nn.utils.clip_grad_norm_(pred.parameters(), args.clip_norm)
        if emb is not None:
            torch.nn.utils.clip_grad_norm_(emb.parameters(), args.clip_norm)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
def test(model, g, pos_test_edge, neg_test_edge, evaluator, pred, emb=None):
    model.eval()
    pred.eval()
    with torch.no_grad():
        xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        h = model(g, xemb)
        dataloader = DataLoader(range(pos_test_edge.size(0)), args.batch_size)
        pos_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_test_edge[edge_index]
            pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        dataloader = DataLoader(range(neg_test_edge.size(0)), args.batch_size)
        neg_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            neg_edge = neg_test_edge[edge_index]
            neg_pred = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
            neg_score.append(neg_pred)
        neg_score = torch.cat(neg_score, dim=0)
        neg_score_multi = neg_score.view(-1, 1000)
        results = {}
        for k in [20, 50, 100]:
            results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
    return results
def eval(model, g, pos_train_edge, pos_valid_edge, neg_valid_edge, evaluator, pred, emb=None):
    model.eval()
    pred.eval()
    with torch.no_grad():
        xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        h = model(g, xemb)
        dataloader = DataLoader(range(pos_valid_edge.size(0)), args.batch_size)
        pos_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_valid_edge[edge_index]
            pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        dataloader = DataLoader(range(neg_valid_edge.size(0)), args.batch_size)
        neg_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            neg_edge = neg_valid_edge[edge_index]
            neg_pred = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
            neg_score.append(neg_pred)
        neg_score = torch.cat(neg_score, dim=0)
        neg_score_multi = neg_score.view(-1, 1000)
        valid_results = {}
        for k in [20, 50, 100]:
            valid_results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        valid_results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
        # Score training edges on a subset the size of the validation set,
        # reusing the validation negatives for the ranking.
        pos_score = []
        dataloader = DataLoader(range(pos_valid_edge.size(0)), args.batch_size)
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_train_edge[edge_index]
            pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        train_results = {}
        for k in [20, 50, 100]:
            train_results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        train_results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
    return valid_results, train_results
# Load the dataset
dataset = DglLinkPropPredDataset(name=args.dataset)
split_edge = dataset.get_edge_split()
device = torch.device('cuda', args.gpu) if torch.cuda.is_available() else torch.device('cpu')
graph = dataset[0]
graph = dgl.add_self_loop(graph)
graph = dgl.to_bidirected(graph, copy_ndata=True)
graph = graph.to(device)
for name in ['train', 'valid', 'test']:
    u = split_edge[name]["source_node"]
    v = split_edge[name]["target_node"]
    split_edge[name]['edge'] = torch.stack((u, v), dim=0).t()
for name in ['valid', 'test']:
    u = split_edge[name]["source_node"].repeat(1, 1000).view(-1)
    v = split_edge[name]["target_node_neg"].view(-1)
    split_edge[name]['edge_neg'] = torch.stack((u, v), dim=0).t()
train_pos_edge = split_edge['train']['edge'].to(device)
valid_pos_edge = split_edge['valid']['edge'].to(device)
valid_neg_edge = split_edge['valid']['edge_neg'].to(device)
test_pos_edge = split_edge['test']['edge'].to(device)
test_neg_edge = split_edge['test']['edge_neg'].to(device)
if args.emb_hidden > 0:
    embedding = torch.nn.Embedding(graph.num_nodes(), args.emb_hidden).to(device)
    torch.nn.init.orthogonal_(embedding.weight)
else:
    embedding = None
# Create negative samples for training
neg_sampler = GlobalUniform(args.num_neg)
if args.pred == 'dot':
    pred = DotPredictor().to(device)
elif args.pred == 'mlp':
    pred = Hadamard_MLPPredictor(args.hidden, args.dropout, args.mlp_layers, args.res, args.norm, args.scale).to(device)
else:
    raise NotImplementedError
input_dim = graph.ndata['feat'].shape[1] + args.emb_hidden
if args.model == 'GCN':
    model = GCN_with_feature(input_dim, args.hidden, args.norm, args.dp4norm, args.prop_step, args.dropout, args.residual, args.relu, args.linear, args.conv).to(device)
elif args.model == 'LightGCN':
    model = LightGCN(input_dim, args.hidden, args.prop_step, args.dropout, args.alpha, args.exp, args.relu, args.norm, args.conv).to(device)
parameter = itertools.chain(model.parameters(), pred.parameters())
if args.emb_hidden > 0:
    parameter = itertools.chain(parameter, embedding.parameters())
optimizer = torch.optim.Adam(parameter, lr=args.lr)
evaluator = Evaluator(name=args.dataset)
best_val = 0
final_test_result = None
best_epoch = 0
losses = []
valid_list = []
test_list = []
if embedding is not None:
    print(f'number of parameters: {sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in pred.parameters()) + sum(p.numel() for p in embedding.parameters())}')
else:
    print(f'number of parameters: {sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in pred.parameters())}')
for epoch in range(args.epochs):
    loss = train(model, graph, train_pos_edge, optimizer, neg_sampler, pred, embedding)
    losses.append(loss)
    if epoch % args.interval == 0 and args.step_lr_decay:
        adjustlr(optimizer, epoch / args.epochs, args.lr)
    valid_results, train_results = eval(model, graph, train_pos_edge, valid_pos_edge, valid_neg_edge, evaluator, pred, embedding)
    valid_list.append(valid_results[args.metric])
    for k, v in valid_results.items():
        print(f"Validation {k}: {v:.4f}")
    for k, v in train_results.items():
        print(f"Train {k}: {v:.4f}")
    test_results = test(model, graph, test_pos_edge, test_neg_edge, evaluator, pred, embedding)
    test_list.append(test_results[args.metric])
    for k, v in test_results.items():
        print(f"Test {k}: {v:.4f}")
    if valid_results[args.metric] > best_val:
        best_val = valid_results[args.metric]
        best_epoch = epoch
        final_test_result = test_results
    if epoch - best_epoch >= 100:
        break
    print(f"Epoch {epoch}, Loss: {loss:.4f}, Train hit: {train_results[args.metric]:.4f}, Valid hit: {valid_results[args.metric]:.4f}, Test hit: {test_results[args.metric]:.4f}")
    wandb.log({'loss': loss, 'train_hit': train_results[args.metric], 'valid_hit': valid_results[args.metric], 'test_hit': test_results[args.metric]})
print(f"Test hit: {final_test_result[args.metric]:.4f}")
wandb.log({'best_test': final_test_result[args.metric]})
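The eval_hits helper above ranks every positive edge score against a shared pool of negative scores: a positive counts as a hit if it beats the K-th highest negative. The same logic, sketched framework-free with plain lists (the scores below are illustrative):

```python
def hits_at_k(pos_scores, neg_scores, k):
    """Fraction of positives scoring strictly above the k-th highest
    negative (mirrors eval_hits in citation.py, without torch)."""
    if len(neg_scores) < k:
        return 1.0  # fewer than k negatives: every positive is a hit
    kth_neg = sorted(neg_scores, reverse=True)[k - 1]
    return sum(p > kth_neg for p in pos_scores) / len(pos_scores)

pos = [0.9, 0.7, 0.2]
neg = [0.8, 0.5, 0.4, 0.1]
print(hits_at_k(pos, neg, 2))  # 2nd-highest negative is 0.5; 0.9 and 0.7 beat it
```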
GitHub Events
Total
- Watch event: 7
- Member event: 1
- Push event: 4
- Fork event: 2
- Create event: 2
Last Year
- Watch event: 7
- Member event: 1
- Push event: 4
- Fork event: 2
- Create event: 2