g2cn
Code for "Graph Gaussian Convolution Networks with Concentrated Graph Filters"
Repository
Basic Info
Statistics
- Stars: 2
- Watchers: 2
- Forks: 0
- Open Issues: 1
- Releases: 0
Created almost 4 years ago · Last pushed over 3 years ago
Metadata Files
Readme
Citation
README.md
G2CN
Code for "Graph Gaussian Convolution Networks with Concentrated Graph Filters" (ICML 2022)
Usage
We provide the hyper-parameters for different datasets in hyperparameter.txt.
To train G2CN with hyper-parameter search, run:
python tuning_dgc_gaussian_band2.py
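The README does not describe how tuning_dgc_gaussian_band2.py performs its search. As a rough illustration only, an exhaustive grid search that tries every combination of learning rate and weight decay and keeps the one with the best validation accuracy could look like the sketch below. `grid_search` and `toy_evaluate` are hypothetical names, and the toy scoring function merely stands in for a real train/validate run; the actual script, its search space, and its selection rule may differ.

```python
import itertools
from typing import Callable, Dict, Iterable, Optional, Tuple


def grid_search(
    evaluate: Callable[..., float],
    grid: Dict[str, Iterable[float]],
) -> Tuple[Optional[Dict[str, float]], float]:
    """Try every combination in `grid`; keep the one with the best validation score.

    `evaluate` is a hypothetical callback that trains a model with the given
    keyword arguments and returns its validation accuracy.
    """
    names = list(grid)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(grid[n] for n in names)):
        params = dict(zip(names, values))
        score = evaluate(**params)
        if score > best_score:  # strict '>' keeps the first of any ties
            best_params, best_score = params, score
    return best_params, best_score


# Hypothetical stand-in for "train on the training split, report validation accuracy".
def toy_evaluate(lr: float, weight_decay: float) -> float:
    return 1.0 - abs(lr - 0.2) - abs(weight_decay - 5e-6) * 1e4


best, score = grid_search(
    toy_evaluate,
    {"lr": [0.02, 0.2, 2.0], "weight_decay": [5e-6, 5e-5, 5e-4]},
)
print(best, score)
```

Replacing `toy_evaluate` with a function that trains on the training split and returns validation accuracy turns this into a real search. Selecting on validation accuracy (never test accuracy) keeps the reported test number unbiased; random search is a common alternative once the grid grows large.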
Owner
- Name: lmj4869
- Login: homles11
- Kind: user
- Website: https://zero-lab-pku.github.io/personwise/limingjie/
- Repositories: 1
- Profile: https://github.com/homles11
I am currently a Ph.D. student at Peking University.
Citation (citation.py)
import os
import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from utils import load_citation, sgc_precompute, set_seed, gen_rand_split, accuracy
from models import get_model
import pickle as pkl
from args import get_citation_args
from time import perf_counter


def train_regression(model,
                     train_features, train_labels,
                     val_features, val_labels,
                     epochs=100, weight_decay=5e-6,
                     lr=0.2, dropout=0.):
    # Full-batch training of the classifier on (precomputed) node features.
    # `dropout` is accepted for interface compatibility but is not used here.
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    t = perf_counter()
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        output = model(train_features)
        loss_train = F.cross_entropy(output, train_labels)
        loss_train.backward()
        optimizer.step()
    train_time = perf_counter() - t

    # Evaluate once on the validation set after training.
    with torch.no_grad():
        model.eval()
        output = model(val_features)
        acc_val = accuracy(output, val_labels)
    return model, acc_val, train_time


def test_regression(model, test_features, test_labels):
    model.eval()
    with torch.no_grad():
        return accuracy(model(test_features), test_labels)


def main():
    args = get_citation_args()

    if args.tuned:
        if args.model == "SGC":
            # Load the weight decay selected by a previous tuning run.
            with open("{}-tuning/{}.txt".format(args.model, args.dataset), 'rb') as f:
                args.weight_decay = pkl.load(f)['weight_decay']
                print("using tuned weight decay: {}".format(args.weight_decay))
        else:
            raise NotImplementedError

    # setting random seeds
    set_seed(args.seed, args.cuda)

    adj, features, labels, idx_train, idx_val, idx_test = load_citation(args.dataset, args.normalization, args.cuda)

    if args.random_split:
        # Reuse a cached random split if one exists; otherwise generate and cache it.
        split_path = f'data/data_splits/{args.dataset}_idx.pt'
        try:
            idx_train, idx_val, idx_test = torch.load(split_path)
            print('data_split loaded from', split_path)
        except FileNotFoundError:
            idx_train, idx_val, idx_test = gen_rand_split(len(features), device=features.device)
            os.makedirs(os.path.dirname(split_path), exist_ok=True)
            torch.save([idx_train, idx_val, idx_test], split_path)
            print('gen data split and save to', split_path)

    model = get_model(args.model, features.size(1), labels.max().item()+1, args.hidden, args.dropout, args.cuda)

    if args.model == "SGC":
        # Propagate features over the graph once, before training.
        features, precompute_time = sgc_precompute(features, adj, args.degree)
        print("{:.4f}s".format(precompute_time))
        model, acc_val, train_time = train_regression(model, features[idx_train], labels[idx_train], features[idx_val], labels[idx_val],
                                                      args.epochs, args.weight_decay, args.lr, args.dropout)
        acc_test = test_regression(model, features[idx_test], labels[idx_test])
    else:
        # Only the SGC pipeline is implemented in this script.
        raise NotImplementedError(args.model)

    print("Validation Accuracy: {:.4f} Test Accuracy: {:.4f}".format(acc_val, acc_test))
    print("Pre-compute time: {:.4f}s, train time: {:.4f}s, total: {:.4f}s".format(precompute_time, train_time, precompute_time+train_time))


if __name__ == '__main__':
    main()
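In citation.py the SGC branch calls `sgc_precompute(features, adj, args.degree)` before training, so the classifier only ever sees propagated features. The helper itself lives in the repository's utils module and is not shown here. A minimal dense NumPy sketch of that kind of precomputation, assuming the usual SGC formulation S = D^{-1/2} (A + I) D^{-1/2} applied K = degree times, is given below. `sgc_precompute_sketch` is a hypothetical name; unlike the real helper, which receives an adjacency that `load_citation` has already normalized, this sketch takes a raw 0/1 adjacency and normalizes it itself.

```python
from time import perf_counter

import numpy as np


def sgc_precompute_sketch(features: np.ndarray, adj: np.ndarray, degree: int):
    """Return (S^degree @ features, elapsed seconds), S = D^{-1/2} (A + I) D^{-1/2}.

    Hypothetical dense illustration; real graphs would use a sparse matrix.
    """
    t = perf_counter()
    a_hat = adj + np.eye(adj.shape[0])               # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))    # degrees >= 1 thanks to self-loops
    s = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]  # symmetric normalization
    out = features
    for _ in range(degree):
        out = s @ out                                # one propagation hop per step
    return out, perf_counter() - t


# Triangle graph: every node is adjacent to the other two.
adj = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
x = np.array([[3.0], [0.0], [0.0]])
out, elapsed = sgc_precompute_sketch(x, adj, degree=2)
print(out.ravel(), elapsed)
```

On the triangle graph S is the all-ones matrix divided by 3, so one hop already replaces every node's feature by the mean of the three inputs (here, 1.0 each), and further hops leave it unchanged. Because the propagation is parameter-free, it runs once and its cost is reported separately as the pre-compute time.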
GitHub Events
Total
- Watch event: 1
Last Year
- Watch event: 1