g2cn

Code for "Graph Gaussian Convolution Networks with Concentrated Graph Filters"

https://github.com/homles11/g2cn

Science Score: 18.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (1.9%) to scientific vocabulary
Last synced: 10 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: homles11
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 43.9 KB
Statistics
  • Stars: 2
  • Watchers: 2
  • Forks: 0
  • Open Issues: 1
  • Releases: 0
Created almost 4 years ago · Last pushed over 3 years ago
Metadata Files
Readme Citation

README.md

G2CN

Code for "Graph Gaussian Convolution Networks with Concentrated Graph Filters" (ICML 2022).

Usage

We provide the hyper-parameters for different datasets in hyperparameter.txt.

To train G2CN with a hyper-parameter search, run: python tuning_dgc_gaussian_band2.py
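The README does not describe how tuning_dgc_gaussian_band2.py explores the search space, so the following is only a generic sketch under that caveat: an exhaustive grid search that keeps the best-scoring configuration, plus a helper that stores a chosen weight decay in the layout read by the args.tuned branch of citation.py (below), namely a pickled dict with a 'weight_decay' key at <model>-tuning/<dataset>.txt. The names grid_search, save_tuned_weight_decay and toy_objective, and the search space itself, are hypothetical; in practice the objective would train the model and return validation accuracy.

```python
import itertools
import os
import pickle as pkl
import tempfile


def grid_search(objective, space):
    """Try every combination in `space` and return (best_params, best_score)."""
    names = list(space)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(space[n] for n in names)):
        params = dict(zip(names, values))
        score = objective(**params)  # e.g. validation accuracy of one training run
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def save_tuned_weight_decay(weight_decay, model, dataset, root="."):
    """Pickle {'weight_decay': ...} to <root>/<model>-tuning/<dataset>.txt,
    the file layout citation.py loads when args.tuned is set."""
    folder = os.path.join(root, "{}-tuning".format(model))
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, "{}.txt".format(dataset))
    with open(path, "wb") as f:
        pkl.dump({"weight_decay": weight_decay}, f)
    return path


# Hypothetical stand-in for "train, then return validation accuracy";
# it peaks at weight_decay=1e-4, lr=0.2.
def toy_objective(weight_decay, lr):
    return -(weight_decay - 1e-4) ** 2 - (lr - 0.2) ** 2


space = {"weight_decay": [1e-5, 1e-4, 1e-3], "lr": [0.1, 0.2]}
best, score = grid_search(toy_objective, space)
print(best)  # {'weight_decay': 0.0001, 'lr': 0.2}

with tempfile.TemporaryDirectory() as tmp:
    path = save_tuned_weight_decay(best["weight_decay"], "SGC", "cora", root=tmp)
    with open(path, "rb") as f:
        print(pkl.load(f))  # {'weight_decay': 0.0001}
```

A grid search is the simplest exhaustive strategy; a real tuning script may instead use random or Bayesian search, which would only change how grid_search picks candidates.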

Owner

  • Name: lmj4869
  • Login: homles11
  • Kind: user

I am currently a Ph.D. student at Peking University.

Citation (citation.py)

import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from utils import load_citation, sgc_precompute, set_seed, gen_rand_split, accuracy
from models import get_model
# from metrics import accuracy
import pickle as pkl
from args import get_citation_args
from time import perf_counter


def train_regression(model,
                     train_features, train_labels,
                     val_features, val_labels,
                     epochs=100, weight_decay=5e-6,
                     lr=0.2, dropout=0.):

    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    t = perf_counter()
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        output = model(train_features)
        loss_train = F.cross_entropy(output, train_labels)
        loss_train.backward()
        optimizer.step()
    train_time = perf_counter()-t

    with torch.no_grad():
        model.eval()
        output = model(val_features)
        acc_val = accuracy(output, val_labels)

    return model, acc_val, train_time

def test_regression(model, test_features, test_labels):
    model.eval()
    with torch.no_grad():  # inference only; no autograd graph is needed
        return accuracy(model(test_features), test_labels)

def main():

    args = get_citation_args()

    if args.tuned:
        if args.model == "SGC":
            with open("{}-tuning/{}.txt".format(args.model, args.dataset), 'rb') as f:
                args.weight_decay = pkl.load(f)['weight_decay']
                print("using tuned weight decay: {}".format(args.weight_decay))
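        else:
            # NotImplemented is a constant, not an exception; raising it is a TypeError
            raise NotImplementedError("tuned hyper-parameters are only available for SGC")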

    # setting random seeds
    set_seed(args.seed, args.cuda)

    adj, features, labels, idx_train, idx_val, idx_test = load_citation(args.dataset, args.normalization, args.cuda)

    if args.random_split:
        split_path = f'data/data_splits/{args.dataset}_idx.pt'
        try:
            idx_train, idx_val, idx_test = torch.load(split_path)
            print('data_split loaded from', split_path)
        except FileNotFoundError:
            # No cached split yet: generate one over all nodes and cache it for later runs
            idx_train, idx_val, idx_test = gen_rand_split(len(features), device=features.device)
            torch.save([idx_train, idx_val, idx_test], split_path)
            print('gen data split and save to', split_path)

    # Only the SGC training path is implemented below; fail early for other models
    # instead of hitting undefined variables (precompute_time, acc_val, acc_test).
    if args.model != "SGC":
        raise NotImplementedError("citation.py only implements training for SGC, got {}".format(args.model))

    model = get_model(args.model, features.size(1), labels.max().item()+1, args.hidden, args.dropout, args.cuda)

    # SGC propagates the features once up front; training then fits only a linear model.
    features, precompute_time = sgc_precompute(features, adj, args.degree)
    print("{:.4f}s".format(precompute_time))

    model, acc_val, train_time = train_regression(model, features[idx_train], labels[idx_train],
                                                  features[idx_val], labels[idx_val],
                                                  args.epochs, args.weight_decay, args.lr, args.dropout)
    acc_test = test_regression(model, features[idx_test], labels[idx_test])

    print("Validation Accuracy: {:.4f} Test Accuracy: {:.4f}".format(acc_val, acc_test))
    print("Pre-compute time: {:.4f}s, train time: {:.4f}s, total: {:.4f}s".format(precompute_time, train_time, precompute_time+train_time))

if __name__ == '__main__':
    main()
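citation.py imports sgc_precompute and accuracy from a utils module that is not shown on this page. As a rough guide to what the SGC path computes, here is a minimal NumPy sketch, assuming the standard SGC recipe of repeatedly multiplying the features by the self-loop, symmetrically normalized adjacency D^{-1/2}(A + I)D^{-1/2}. The helper names normalized_adjacency, sgc_precompute_sketch and accuracy_sketch are hypothetical. In SGC-style code the normalization is usually applied inside load_citation (selected by args.normalization) and sgc_precompute only repeats the multiplication and also returns the elapsed time; here both steps are bundled, and dense arrays replace sparse tensors for readability.

```python
import numpy as np


def normalized_adjacency(adj: np.ndarray) -> np.ndarray:
    """Symmetric normalization with self-loops: D^{-1/2} (A + I) D^{-1/2}."""
    a_hat = adj + np.eye(adj.shape[0])
    deg = a_hat.sum(axis=1)  # degree of each node, counting its self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale rows and columns by d^{-1/2} without materializing the diagonal matrix.
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def sgc_precompute_sketch(features: np.ndarray, adj: np.ndarray, degree: int) -> np.ndarray:
    """Return S^degree @ X, where S is the normalized adjacency above."""
    s = normalized_adjacency(adj)
    out = features
    for _ in range(degree):
        out = s @ out  # one round of neighborhood averaging
    return out


def accuracy_sketch(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of rows whose arg-max class matches the label."""
    return float((logits.argmax(axis=1) == labels).mean())


# Usage on a 3-node path graph 0-1-2 with 2-dimensional node features.
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
x = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
print(sgc_precompute_sketch(x, adj, degree=2).shape)  # (3, 2)
```

Because the propagation has no trainable parameters, it runs once before training, which is why citation.py reports pre-compute time and train time separately.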

GitHub Events

Total
  • Watch event: 1
Last Year
  • Watch event: 1