Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (4.9%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: yiming421
  • Language: Python
  • Default Branch: main
  • Size: 152 KB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 9 months ago · Last pushed 9 months ago
Metadata Files
Readme Citation

ReadMe.md

Update: We are thrilled to announce that our paper, "Reconsidering the Performance of GAE in Link Prediction," has been accepted at the 34th ACM International Conference on Information and Knowledge Management (CIKM 2025)!

You can read the pre-print on arXiv:

https://arxiv.org/abs/2411.03845

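We achieve performance comparable to or better than recent models, both on the Planetoid datasets (Cora, Citeseer, Pubmed) and on the OGB benchmark datasets ogbl-ddi, ogbl-collab, ogbl-ppa, and ogbl-citation2. Our model is listed as Optimized-GAE. In each column, bold marks the best result, underline marks the second best, and OOM denotes out of memory: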

| Model | Cora (Hits@100) | Citeseer (Hits@100) | Pubmed (Hits@100) | Collab (Hits@50) | PPA (Hits@100) | Citation2 (MRR) | DDI (Hits@20) |
|---|---|---|---|---|---|---|---|
| CN | $33.92 \pm 0.46$ | $29.79 \pm 0.90$ | $23.13 \pm 0.15$ | $56.44 \pm 0.00$ | $27.65 \pm 0.00$ | $51.47 \pm 0.00$ | $17.73 \pm 0.00$ |
| AA | $39.85 \pm 1.34$ | $35.19 \pm 1.33$ | $27.38 \pm 0.11$ | $64.35 \pm 0.00$ | $32.45 \pm 0.00$ | $51.89 \pm 0.00$ | $18.61 \pm 0.00$ |
| RA | $41.07 \pm 0.48$ | $33.56 \pm 0.17$ | $27.03 \pm 0.35$ | $64.00 \pm 0.00$ | $49.33 \pm 0.00$ | $51.98 \pm 0.00$ | $27.60 \pm 0.00$ |
| SEAL | $81.71 \pm 1.30$ | $83.89 \pm 2.15$ | $75.54 \pm 1.32$ | $64.74 \pm 0.43$ | $48.80 \pm 3.16$ | $87.67 \pm 0.32$ | $30.56 \pm 3.86$ |
| NBFNet | $71.65 \pm 2.27$ | $74.07 \pm 1.75$ | $58.73 \pm 1.99$ | OOM | OOM | OOM | $4.00 \pm 0.58$ |
| Neo-GNN | $80.42 \pm 1.31$ | $84.67 \pm 2.16$ | $73.93 \pm 1.19$ | $57.52 \pm 0.37$ | $49.13 \pm 0.60$ | $87.26 \pm 0.84$ | $63.57 \pm 3.52$ |
| BUDDY | $88.00 \pm 0.44$ | $\mathbf{92.93 \pm 0.27}$ | $74.10 \pm 0.78$ | $65.94 \pm 0.58$ | $49.85 \pm 0.20$ | $87.56 \pm 0.11$ | $78.51 \pm 1.36$ |
| NCN | $\mathbf{89.05 \pm 0.96}$ | $91.56 \pm 1.43$ | $\underline{79.05 \pm 1.16}$ | $64.76 \pm 0.87$ | $61.19 \pm 0.85$ | $88.09 \pm 0.06$ | $\underline{82.32 \pm 6.10}$ |
| MPLP+ | - | - | - | $\mathbf{66.99 \pm 0.40}$ | $\underline{65.24 \pm 1.50}$ | $\mathbf{90.72 \pm 0.12}$ | - |
| GAE(GCN) | $66.79 \pm 1.65$ | $67.08 \pm 2.94$ | $53.02 \pm 1.39$ | $47.14 \pm 1.45$ | $18.67 \pm 1.32$ | $84.74 \pm 0.21$ | $37.07 \pm 5.07$ |
| GAE(SAGE) | $55.02 \pm 4.03$ | $57.01 \pm 3.74$ | $39.66 \pm 0.72$ | $54.63 \pm 1.12$ | $16.55 \pm 2.40$ | $82.60 \pm 0.36$ | $53.90 \pm 4.74$ |
| Optimized-GAE | $\underline{88.17 \pm 0.93}$ | $\underline{92.40 \pm 1.23}$ | $\mathbf{80.09 \pm 1.72}$ | $\underline{66.11 \pm 0.35}$ | $\mathbf{78.41 \pm 0.83}$ | $\underline{88.74 \pm 0.06}$ | $\mathbf{94.43 \pm 0.57}$ |
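The table reports Hits@K on most datasets and MRR on ogbl-citation2. As a rough, dependency-free illustration of what these numbers measure, the sketch below computes Hits@K by ranking every positive edge against one shared pool of negative scores (the same rule as `eval_hits` in citation.py) and MRR by ranking each positive against its own list of negatives. The tie handling in MRR (averaging the optimistic and pessimistic rank) is our understanding of the OGB evaluator's convention and should be checked against the OGB source; the function names are hypothetical.

```python
from typing import List


def hits_at_k(pos_scores: List[float], neg_scores: List[float], k: int) -> float:
    """Fraction of positives scoring strictly above the k-th best negative.

    All positives share a single pool of negatives, as in eval_hits().
    """
    if len(neg_scores) < k:
        return 1.0
    kth_best_negative = sorted(neg_scores, reverse=True)[k - 1]
    return sum(s > kth_best_negative for s in pos_scores) / len(pos_scores)


def mean_reciprocal_rank(pos_scores: List[float],
                         neg_scores: List[List[float]]) -> float:
    """MRR where pos_scores[i] is ranked against its own negatives neg_scores[i].

    Ties are handled by averaging the optimistic and the pessimistic rank.
    """
    reciprocal_ranks = []
    for pos, negs in zip(pos_scores, neg_scores):
        optimistic = sum(n > pos for n in negs)    # negatives strictly better
        pessimistic = sum(n >= pos for n in negs)  # negatives better or tied
        rank = 0.5 * (optimistic + pessimistic) + 1
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


print(hits_at_k([0.9, 0.4, 0.7], [0.8, 0.5, 0.3, 0.1], k=2))  # 2 of 3 positives beat 0.5
print(mean_reciprocal_rank([0.9, 0.4], [[0.5, 0.95], [0.1, 0.2]]))  # 0.75
```

Hits@K uses one global threshold, so it is cheap to compute, while MRR needs a fixed-size negative list per positive, which is why ogbl-citation2 ships 1000 negatives per validation and test edge.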

The code builds on the DGL and OGB libraries. To run it, first create the conda environment specified in env.yaml:

conda env create -f env.yaml

To run the code on the Planetoid datasets (Cora, CiteSeer, PubMed):

python train_w_feat_small.py --dataset Cora --activation silu --batch_size 2048 --dropout 0.6 --hidden 1024 --lr 0.005 --maskinput --mlp_layers 4 --res --norm --num_neg 3 --optimizer adamw --prop_step 4 --model LightGCN

python train_w_feat_small.py --dataset CiteSeer --activation relu --batch_size 4096 --dropout 0.6 --hidden 1024 --lr 0.001 --maskinput --norm --prop_step 4 --num_neg 1

python train_w_feat_small.py --dataset PubMed --activation gelu --batch_size 4096 --dropout 0.4 --exp --hidden 512 --lr 0.001 --maskinput --mlp_layers 2 --norm --num_neg 3 --prop_step 2 --model LightGCN

Below we give the commands to run the code on the four datasets in the OGB benchmark.

python train_wo_feat.py --dataset ogbl-ddi --lr 0.001 --hidden 1024 --batch_size 8192 --dropout 0.6 --num_neg 1 --epochs 500 --prop_step 2 --metric hits@20 --residual 0.1 --maskinput --mlp_layers 8 --mlp_res --emb_dim 1024

python collab.py --dataset ogbl-collab --lr 0.0004 --emb_hidden 0 --hidden 1024 --batch_size 16384 --dropout 0.2 --num_neg 3 --epoch 500 --prop_step 4 --metric hits@50 --mlp_layers 5 --res --norm --dp4norm 0.2 --scale

python train_wo_feat.py --dataset ogbl-ppa --lr 0.001 --hidden 512 --batch_size 65536 --dropout 0.2 --num_neg 3 --epoch 800 --prop_step 2 --metric hits@100 --residual 0.1 --mlp_layers 5 --mlp_res --emb_dim 512

python citation.py --dataset ogbl-citation2 --lr 0.0003 --clip_norm 1 --emb_hidden 256 --hidden 256 --batch_size 65536 --dropout 0.2 --num_neg 3 --epochs 200 --prop_step 3 --metric MRR --norm --dp4norm 0.2 --mlp_layers 5
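The ogbl-citation2 command above passes `--norm` as a bare flag. With argparse, an option declared with `type=str2bool` alone demands an explicit value (`--norm True`); a bare flag is accepted only when the option also sets `nargs='?'` and `const=True`. A minimal sketch of that pattern, reusing the `str2bool` logic from citation.py (the `build_parser` helper and the reduced option set are hypothetical):

```python
import argparse


def str2bool(v):
    """Parse common spellings of a boolean (same rules as citation.py)."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # nargs="?" makes the value optional; const=True is used when the flag is bare.
    parser.add_argument("--norm", type=str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--dp4norm", type=float, default=0.0)
    return parser


parser = build_parser()
print(parser.parse_args(["--norm", "--dp4norm", "0.2"]).norm)  # True
print(parser.parse_args(["--norm", "false"]).norm)             # False
print(parser.parse_args([]).norm)                              # False
```

In the first call argparse sees that `--dp4norm` looks like an option, so `--norm` consumes no value and falls back to `const`.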

For the ogbl-citation2 dataset, you need a GPU with at least 40 GB of memory.
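A back-of-envelope sketch of where that memory goes, assuming ogbl-citation2 has 2,927,963 nodes with 128-dimensional features (per the OGB dataset description) and the `--emb_hidden 256 --hidden 256` setting from the command above. citation.py runs full-graph message passing, so every node-level tensor scales with the node count. This is arithmetic under those assumptions, not a measurement; the real peak also depends on the model internals, DGL's message buffers, and PyTorch's caching allocator.

```python
# Rough fp32 sizes of a few full-graph tensors for ogbl-citation2.
NUM_NODES = 2_927_963   # assumed ogbl-citation2 node count
FEAT_DIM = 128          # assumed original node feature dimension
EMB_HIDDEN = 256        # --emb_hidden in the README command
HIDDEN = 256            # --hidden in the README command
BYTES_PER_FP32 = 4


def gib(num_elements: int, bytes_per_element: int = BYTES_PER_FP32) -> float:
    """Convert an element count to GiB."""
    return num_elements * bytes_per_element / 1024**3


# Input = original features concatenated with the learnable embedding.
input_features = gib(NUM_NODES * (FEAT_DIM + EMB_HIDDEN))
# One node-level activation of width --hidden (several are kept for backward).
one_hidden_layer = gib(NUM_NODES * HIDDEN)
# Embedding weight + its gradient + Adam's two moment buffers.
embedding_with_adam = gib(NUM_NODES * EMB_HIDDEN) * 4

print(f"input features  : {input_features:.2f} GiB")    # ≈ 4.19
print(f"one hidden layer: {one_hidden_layer:.2f} GiB")  # ≈ 2.79
print(f"embedding+Adam  : {embedding_with_adam:.2f} GiB")  # ≈ 11.17
```

With `--prop_step 3`, autograd keeps several hidden-width activations alive at once, so these few tensors alone already approach 20 GiB.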

Owner

  • Login: yiming421
  • Kind: user

Citation (citation.py)

import itertools
import os
os.environ['DGLBACKEND'] = 'pytorch'

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import psutil
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from dgl.dataloading.negative_sampler import GlobalUniform
from torch.utils.data import DataLoader
import tqdm
import argparse
from loss import auc_loss, hinge_auc_loss, log_rank_loss
from model import Hadamard_MLPPredictor, GCN_with_feature, DotPredictor, LightGCN
from unified_model import UnifiedGNN
import time
import wandb

def log_memory_usage(stage, device, verbose=True):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(device) / 1024**3
        cached = torch.cuda.memory_reserved(device) / 1024**3
        if verbose:
            print(f"[{stage}] GPU Memory - Allocated: {allocated:.3f}GB, Cached: {cached:.3f}GB")
        return allocated, cached
    
    process = psutil.Process(os.getpid())
    ram_usage = process.memory_info().rss / 1024**3
    if verbose:
        print(f"[{stage}] RAM Usage: {ram_usage:.3f}GB")
    return ram_usage

def get_memory_stats(device):
    """Get current memory statistics without logging"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(device) / 1024**3
        cached = torch.cuda.memory_reserved(device) / 1024**3
        return allocated, cached
    return 0, 0

def clear_gpu_cache(stage="", device=0):
    """Clear GPU cache and log memory reduction"""
    if torch.cuda.is_available():
        cached_before = torch.cuda.memory_reserved(device) / 1024**3
        torch.cuda.empty_cache()
        cached_after = torch.cuda.memory_reserved(device) / 1024**3
        cache_freed = cached_before - cached_after
        if cache_freed > 0.01:  # Only log if significant memory freed
            print(f"[CACHE CLEAR{' ' + stage if stage else ''}] Freed {cache_freed:.3f}GB cached memory ({cached_before:.3f}GB → {cached_after:.3f}GB)")
        return cache_freed
    return 0

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.') 

def parse():
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default='ogbl-citation2', choices=['ogbl-citation2'], type=str)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--prop_step", default=8, type=int)
    parser.add_argument("--emb_hidden", default=64, type=int)
    parser.add_argument("--hidden", default=64, type=int)
    parser.add_argument("--batch_size", default=8192, type=int)
    parser.add_argument("--dropout", default=0.05, type=float)
    parser.add_argument("--num_neg", default=1, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--interval", default=100, type=int)
    # Boolean options use nargs='?' with const=True so that both a bare flag
    # (e.g. --norm, as in the README commands) and an explicit value (--norm False) parse.
    parser.add_argument("--step_lr_decay", type=str2bool, nargs='?', const=True, default=True)
    parser.add_argument("--metric", default='mrr', type=str)
    parser.add_argument("--gpu", default=0, type=int)
    parser.add_argument("--relu", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--model", default='UnifiedGNN', choices=['GCN', 'GCN_with_MLP', 'GCN_no_para', 'LightGCN', 'UnifiedGNN'], type=str)
    parser.add_argument("--maskinput", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--norm", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--dp4norm", default=0, type=float)
    parser.add_argument("--dpe", default=0, type=float)
    parser.add_argument("--drop_edge", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--loss", default='bce', choices=['bce', 'auc', 'hauc', 'rank'], type=str)
    parser.add_argument("--residual", default=0, type=float)
    parser.add_argument("--mlp_layers", default=2, type=int)
    parser.add_argument("--pred", default='mlp', type=str)
    parser.add_argument("--res", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--conv", default='GCN', type=str)
    parser.add_argument('--alpha', default=0.5, type=float)
    parser.add_argument('--exp', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--scale', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--linear', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--clip_norm', default=1.0, type=float)

    # UnifiedGNN specific parameters
    parser.add_argument('--unified_model_type', default='gcn', choices=['gcn', 'lightgcn'], type=str)
    parser.add_argument('--multilayer', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--no_parameters', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--input_norm', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--gin_aggr', default='sum', choices=['sum', 'mean', 'max'], type=str)

    # Advanced optimization
    parser.add_argument('--weight_decay', default=0.0, type=float)
    parser.add_argument('--scheduler', default='none', choices=['none', 'cosine', 'step'], type=str)
    parser.add_argument('--optimizer', default='adam', choices=['adam', 'adamw'], type=str)
    parser.add_argument('--use_amp', type=str2bool, nargs='?', const=True, default=False, help='Enable Automatic Mixed Precision training')
    parser.add_argument('--emb_only', type=str2bool, nargs='?', const=True, default=False, help='Use only learnable embeddings without original features')
    parser.add_argument('--mlp_res', type=str2bool, nargs='?', const=True, default=False, help='Use residual connections in MLP predictor')

    args = parser.parse_args()
    return args

args = parse()

# Force norm=True when gin_aggr='sum'
if args.gin_aggr == 'sum':
    args.norm = True
    print(f"[INFO] Forcing norm=True because gin_aggr='sum'")

print(args)
wandb.init(project='Refined-GAE', config=args)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
dgl.seed(args.seed)

def eval_hits(y_pred_pos, y_pred_neg, K):
    '''
        compute Hits@K
        For each positive target node, the negative target nodes are the same.

        y_pred_neg is an array.
        rank y_pred_pos[i] against y_pred_neg for each i
    '''

    if len(y_pred_neg) < K:
        return {'hits@{}'.format(K): 1.}

    kth_score_in_negative_edges = torch.topk(y_pred_neg, K)[0][-1]
    hitsK = float(torch.sum(y_pred_pos > kth_score_in_negative_edges).cpu()) / len(y_pred_pos)

    return {'hits@{}'.format(K): hitsK}

def adjustlr(optimizer, decay_ratio, lr):
    lr_ = lr * max(1 - decay_ratio, 0.0001)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_

def train(model, g, train_pos_edge, optimizer, neg_sampler, pred, emb=None, scaler=None):
    model.train()
    pred.train()

    train_start_time = time.time()
    log_memory_usage("train_start", args.gpu)
    
    dataloader = DataLoader(range(train_pos_edge.size(0)), args.batch_size, shuffle=True)
    total_loss = 0
    peak_memory = 0
    
    if args.maskinput:
        mask = torch.ones(train_pos_edge.size(0), dtype=torch.bool)
    for batch_idx, edge_index in enumerate(tqdm.tqdm(dataloader)):
        if args.emb_only:
            xemb = emb.weight if emb is not None else g.ndata['feat']
        else:
            xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        
        # Forward pass with AMP autocast
        with autocast(enabled=(scaler is not None)):
            if args.maskinput:
                mask[edge_index] = 0
                tei = train_pos_edge[mask]
                src, dst = tei.t()
                re_tei = torch.stack((dst, src), dim=0).t()
                tei = torch.cat((tei, re_tei), dim=0)
                g_mask = dgl.graph((tei[:, 0], tei[:, 1]), num_nodes=g.num_nodes())
                g_mask = dgl.add_self_loop(g_mask)
                h = model(g_mask, xemb)
                mask[edge_index] = 1
            else:
                h = model(g, xemb)

            pos_edge = train_pos_edge[edge_index]
            neg_train_edge = neg_sampler(g, pos_edge.t()[0])
            neg_train_edge = torch.stack(neg_train_edge, dim=0)
            neg_train_edge = neg_train_edge.t()
            neg_edge = neg_train_edge
            pos_score = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            neg_score = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])

        pos_score = pos_score.float()
        neg_score = neg_score.float()
        # Loss computation (outside autocast for stability)
        if args.loss == 'auc':
            loss = auc_loss(pos_score, neg_score, args.num_neg)
        elif args.loss == 'hauc':
            loss = hinge_auc_loss(pos_score, neg_score, args.num_neg)
        elif args.loss == 'rank':
            loss = log_rank_loss(pos_score, neg_score, args.num_neg)
        else:
            loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score)) + F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))
        
        # Check for NaN and fail instantly
        if torch.isnan(loss).any():
            print(f"[ERROR] NaN detected in loss at batch {batch_idx}!")
            print(f"pos_score stats: min={pos_score.min():.6f}, max={pos_score.max():.6f}, mean={pos_score.mean():.6f}")
            print(f"neg_score stats: min={neg_score.min():.6f}, max={neg_score.max():.6f}, mean={neg_score.mean():.6f}")
            raise ValueError("NaN detected in loss computation!")
        
        optimizer.zero_grad()
        
        # Backward pass with AMP scaling
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        
        # Gradient clipping
        if scaler is not None:
            scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
        torch.nn.utils.clip_grad_norm_(pred.parameters(), args.clip_norm)
        if emb is not None:
            torch.nn.utils.clip_grad_norm_(emb.parameters(), args.clip_norm)
        
        # Optimizer step with AMP scaling
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
            
        total_loss += loss.item()
        
        # Track peak memory occasionally
        if batch_idx % 50 == 0:
            current_memory = get_memory_stats(args.gpu)[0]
            peak_memory = max(peak_memory, current_memory)
    
    train_end_time = time.time()
    train_duration = train_end_time - train_start_time
    
    clear_gpu_cache("after_training", args.gpu)
    log_memory_usage("train_end", args.gpu)
    
    print(f"[TIMING] Training epoch: {train_duration:.3f}s ({len(dataloader) / train_duration:.1f} batches/s)")
    print(f"[MEMORY] Peak GPU: {peak_memory:.3f}GB")
    if scaler is not None:
        print(f"[AMP] Grad scale: {scaler.get_scale():.0f}")

    return total_loss / len(dataloader)

def test(model, g, pos_test_edge, neg_test_edge, evaluator, pred, emb=None, scaler=None):
    model.eval()
    pred.eval()

    with torch.no_grad():
        if args.emb_only:
            xemb = emb.weight if emb is not None else g.ndata['feat']
        else:
            xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        with autocast(enabled=(scaler is not None)):
            h = model(g, xemb)
        dataloader = DataLoader(range(pos_test_edge.size(0)), args.batch_size)
        pos_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_test_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        dataloader = DataLoader(range(neg_test_edge.size(0)), args.batch_size)
        neg_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            neg_edge = neg_test_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                neg_pred = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
            neg_score.append(neg_pred)
        neg_score = torch.cat(neg_score, dim=0)
        neg_score_multi = neg_score.view(-1, 1000)
        results = {}
        for k in [20, 50, 100]:
            results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
    
    clear_gpu_cache("after_test", args.gpu)
    return results

def eval(model, g, pos_train_edge, pos_valid_edge, neg_valid_edge, evaluator, pred, emb=None, scaler=None):
    model.eval()
    pred.eval()

    with torch.no_grad():
        if args.emb_only:
            xemb = emb.weight if emb is not None else g.ndata['feat']
        else:
            xemb = torch.cat((g.ndata['feat'], emb.weight), dim=1) if emb is not None else g.ndata['feat']
        with autocast(enabled=(scaler is not None)):
            h = model(g, xemb)
        dataloader = DataLoader(range(pos_valid_edge.size(0)), args.batch_size)
        pos_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_valid_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        dataloader = DataLoader(range(neg_valid_edge.size(0)), args.batch_size)
        neg_score = []
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            neg_edge = neg_valid_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                neg_pred = pred(h[neg_edge[:, 0]], h[neg_edge[:, 1]])
            neg_score.append(neg_pred)
        neg_score = torch.cat(neg_score, dim=0)
        neg_score_multi = neg_score.view(-1, 1000)
        valid_results = {}
        for k in [20, 50, 100]:
            valid_results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        valid_results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()
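        # Train metrics: score the first len(pos_valid_edge) training edges against the
        # validation negatives, a cheap proxy for metrics over the full training set.
        pos_score = []
        dataloader = DataLoader(range(pos_valid_edge.size(0)), args.batch_size)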
        for _, edge_index in enumerate(tqdm.tqdm(dataloader)):
            pos_edge = pos_train_edge[edge_index]
            with autocast(enabled=(scaler is not None)):
                pos_pred = pred(h[pos_edge[:, 0]], h[pos_edge[:, 1]])
            pos_score.append(pos_pred)
        pos_score = torch.cat(pos_score, dim=0)
        train_results = {}
        for k in [20, 50, 100]:
            train_results[f'hits@{k}'] = eval_hits(pos_score, neg_score, k)[f'hits@{k}']
        train_results[args.metric] = evaluator.eval({
            'y_pred_pos': pos_score,
            'y_pred_neg': neg_score_multi,
        })['mrr_list'].mean().item()

    clear_gpu_cache("after_eval", args.gpu)
    return valid_results, train_results

# Load the dataset
dataset = DglLinkPropPredDataset(name=args.dataset)
split_edge = dataset.get_edge_split()

device = torch.device('cuda', args.gpu) if torch.cuda.is_available() else torch.device('cpu')

graph = dataset[0]
    
graph = dgl.add_self_loop(graph)
graph = dgl.to_bidirected(graph, copy_ndata=True)
graph = graph.to(device)

for name in ['train','valid','test']:
    u=split_edge[name]["source_node"]
    v=split_edge[name]["target_node"]
    split_edge[name]['edge']=torch.stack((u,v),dim=0).t()
for name in ['valid','test']:
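    # view(-1, 1) first so each source node is repeated 1000 times in a row, matching the
    # row-major layout of target_node_neg (shape [N, 1000]) after view(-1). Without it,
    # repeat() tiles the whole source list and pairs sources with the wrong negatives.
    u=split_edge[name]["source_node"].view(-1, 1).repeat(1, 1000).view(-1)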
    v=split_edge[name]["target_node_neg"].view(-1)
    split_edge[name]['edge_neg']=torch.stack((u,v),dim=0).t()   

train_pos_edge = split_edge['train']['edge'].to(device)
valid_pos_edge = split_edge['valid']['edge'].to(device)
valid_neg_edge = split_edge['valid']['edge_neg'].to(device)
test_pos_edge = split_edge['test']['edge'].to(device)
test_neg_edge = split_edge['test']['edge_neg'].to(device)

if args.emb_hidden > 0:
    embedding = torch.nn.Embedding(graph.num_nodes(), args.emb_hidden).to(device)
    torch.nn.init.orthogonal_(embedding.weight)
else:
    embedding = None

# Create negative samples for training
neg_sampler = GlobalUniform(args.num_neg)

if args.pred == 'dot':
    pred = DotPredictor().to(device)
elif args.pred == 'mlp':
    pred = Hadamard_MLPPredictor(args.hidden, args.dropout, args.mlp_layers, args.mlp_res, args.norm, args.scale).to(device)
else:
    raise NotImplementedError

if args.emb_only:
    input_dim = args.emb_hidden
else:
    input_dim = graph.ndata['feat'].shape[1] + args.emb_hidden

if args.model == 'GCN':
    model = GCN_with_feature(input_dim, args.hidden, args.norm, args.dp4norm, args.prop_step, args.dropout, args.residual, args.relu, args.linear, args.conv).to(device)
elif args.model == 'LightGCN':
    model = LightGCN(input_dim, args.hidden, args.prop_step, args.dropout, args.alpha, args.exp, args.relu, args.norm, args.conv).to(device)
elif args.model == 'UnifiedGNN':
    model = UnifiedGNN(
        model_type=args.unified_model_type,
        in_feats=input_dim,
        h_feats=args.hidden,
        prop_step=args.prop_step,
        conv=args.conv,
        multilayer=args.multilayer,
        norm=args.norm,
        relu=args.relu,
        dropout=args.dropout,
        residual=args.residual,
        linear=args.linear,
        alpha=args.alpha,
        exp=args.exp,
        res=args.res,
        gin_aggr=args.gin_aggr,
        no_parameters=args.no_parameters,
        input_norm=args.input_norm,
        supports_edge_weight=True
    ).to(device)
else:
    raise NotImplementedError

parameter = itertools.chain(model.parameters(), pred.parameters())
if args.emb_hidden > 0:
    parameter = itertools.chain(parameter, embedding.parameters())

# --- Optimizer ---
if args.optimizer == 'adam':
    optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer == 'adamw':
    optimizer = torch.optim.AdamW(parameter, lr=args.lr, weight_decay=args.weight_decay)
else:
    raise NotImplementedError

# --- Learning Rate Scheduler ---
if args.scheduler == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)
elif args.scheduler == 'step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs//3, gamma=0.5)
else:
    scheduler = None

# Initialize AMP scaler if using mixed precision
scaler = GradScaler() if args.use_amp else None

evaluator = Evaluator(name=args.dataset)

best_val = 0
final_test_result = None
best_epoch = 0

losses = []
valid_list = []
test_list = []

if embedding is not None:
    print(f'number of parameters: {sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in pred.parameters()) + sum(p.numel() for p in embedding.parameters())}')
else:
    print(f'number of parameters: {sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in pred.parameters())}')

for epoch in range(args.epochs):
    loss = train(model, graph, train_pos_edge, optimizer, neg_sampler, pred, embedding, scaler)
    losses.append(loss)
    if epoch % args.interval == 0 and args.step_lr_decay:
        adjustlr(optimizer, epoch / args.epochs, args.lr)
    valid_results, train_results = eval(model, graph, train_pos_edge, valid_pos_edge, valid_neg_edge, evaluator, pred, embedding, scaler)
    valid_list.append(valid_results[args.metric])
    for k, v in valid_results.items():
        print(f"Validation {k}: {v:.4f}")
    for k, v in train_results.items():
        print(f"Train {k}: {v:.4f}")
    test_results = test(model, graph, test_pos_edge, test_neg_edge, evaluator, pred, embedding, scaler)
    test_list.append(test_results[args.metric])
    for k, v in test_results.items():
        print(f"Test {k}: {v:.4f}")
    if valid_results[args.metric] > best_val:
        best_val = valid_results[args.metric]
        best_epoch = epoch
        final_test_result = test_results
    
    # Learning rate scheduling
    if scheduler is not None:
        scheduler.step()
        
    if epoch - best_epoch >= 100:
        break
    print(f"Epoch {epoch}, Loss: {loss:.4f}, Train hit: {train_results[args.metric]:.4f}, Valid hit: {valid_results[args.metric]:.4f}, Test hit: {test_results[args.metric]:.4f}")
    wandb.log({'loss': loss, 'train_hit': train_results[args.metric], 'valid_hit': valid_results[args.metric], 'test_hit': test_results[args.metric], 'lr': optimizer.param_groups[0]['lr']})

print(f"Test hit: {final_test_result[args.metric]:.4f}")
wandb.log({'best_test': final_test_result[args.metric]})
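The `--maskinput` branch of `train()` rebuilds the message-passing graph for every batch so that the positive edges being scored are not visible to the encoder: it drops the batch's edges from `train_pos_edge`, adds the reverse of each remaining edge, and adds self-loops. A minimal, dependency-free sketch of that masking step on a plain edge list follows; `masked_message_graph` is a hypothetical helper, not part of the repository, and because it uses a Python set it also collapses duplicate edges, which the DGL graph in the listing may keep.

```python
from typing import List, Set, Tuple

Edge = Tuple[int, int]


def masked_message_graph(train_edges: List[Edge], batch_ids: List[int],
                         num_nodes: int) -> Set[Edge]:
    """Return the edge set used for message passing while scoring batch_ids.

    Drops the batch's positive edges, adds the reverse of every remaining
    edge, then adds one self-loop per node.
    """
    held_out = set(batch_ids)
    kept = [edge for i, edge in enumerate(train_edges) if i not in held_out]
    edges = set(kept)
    edges.update((dst, src) for src, dst in kept)            # reverse edges
    edges.update((node, node) for node in range(num_nodes))  # self-loops
    return edges


if __name__ == "__main__":
    train_edges = [(0, 1), (1, 2), (2, 3)]
    # Hold out the edge at index 1, i.e. (1, 2), as if it were the current batch.
    graph = masked_message_graph(train_edges, batch_ids=[1], num_nodes=4)
    print(sorted(graph))
    # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 2), (2, 3), (3, 2), (3, 3)]
```

Rebuilding the graph per batch costs an extra graph construction each step, but it keeps the target edges out of the encoder's input, which is the point of the option.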

GitHub Events

Total
  • Push event: 2
  • Create event: 1
Last Year
  • Push event: 2
  • Create event: 1