tgn-st

This repository contains a PyTorch implementation of the TGN-ST model (the paper has been accepted to https://techno-srj.itc.edu.kh/ and will appear in the next volume, 2025).

https://github.com/horhang/tgn-st



Basic Info
  • Host: GitHub
  • Owner: HorHang
  • License: mit
  • Language: Jupyter Notebook
  • Default Branch: main
  • Homepage:
  • Size: 36.2 MB
Statistics
  • Stars: 4
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 1
Created almost 2 years ago · Last pushed almost 2 years ago

README.md

TGN-ST

Temporal Graph Learning with Application to Large-Scale Flight Traffic Prediction (Hang et al., 2024)

This repository contains a PyTorch implementation of the paper TGN-ST (submitted to https://techno-srj.itc.edu.kh/ and under review).


1. Introduction

Graphs are everywhere!

Graph applications:

  1. Transportation networks (e.g. @Google Maps)
  2. Social networks (e.g. @Meta)
  3. Finance (e.g. Lecture)
  4. Computer vision (e.g. Paper)
  5. Natural Language Processing (e.g. Paper)
  6. Web applications (e.g. Web Image Search)
  7. Computational biology (e.g. @Deepmind)
  8. Recommendation systems (e.g. @Twitter, @Pinterest)

2. Temporal Graph Learning (TGL)

Systems of interaction can be modeled as a graph, with entities as nodes and interactions as edges (links).

Real-world networks inherently evolve over time. Temporal graph learning has emerged as a new subfield of machine learning and a promising framework for extracting structural and temporal information from dynamic graphs, assuming both a relational inductive bias and a temporal inductive bias.

An example of a dynamic graph. The numbers attached to the edges are timestamps. Visualized in LynxKite.
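As a minimal sketch of the idea, a dynamic graph can be stored as a list of timestamped edges. The toy events and the `snapshot` helper below are hypothetical illustrations and are not taken from this repository.

```python
from collections import defaultdict

# Hypothetical toy data: (source, destination, timestamp) interaction events.
events = [
    ("A", "B", 1),
    ("B", "C", 2),
    ("A", "B", 3),
    ("C", "D", 3),
]

def snapshot(events, t_start, t_end):
    """Return adjacency sets for edges with t_start <= timestamp <= t_end."""
    adj = defaultdict(set)
    for src, dst, t in events:
        if t_start <= t <= t_end:
            adj[src].add(dst)
    return dict(adj)

print(snapshot(events, 1, 2))  # edges observed at timestamps 1 and 2
print(snapshot(events, 3, 3))  # edges observed at timestamp 3 only
```

Querying different time windows on the same event list is what distinguishes a temporal graph from a single static adjacency structure.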

TGL Methods Taxonomy

| Approaches | GNN-based | Sequence-Based | Graph Walk Based | None |
| :---: | :---: | :---: | :---: | :---: |
| Memory-based | JODIE | - | CAWN | EdgeBank** |
| Self-attn-based | TGAT, TCL, GAT* | DyGFormer | - | - |
| Self-attn + Memory-based | DyREP, TGN | - | - | - |
| MLP-based | - | GraphMixer | - | - |
| None | GraphSAGE*, GCN | - | - | - |

Note:
(*) Static graph method
(**) Pure memorization method

3. Flight Network Dataset

1. Dataset Overviews

TGB tgbl-flight is a crowdsourced international flight network covering 2019 to 2022.

| Name | #Nodes | #Edges | #Steps |
| :---: | :---: | :---: | :---: |
| tgbl-flight | 18,143 | 67,169,570 | 1,385 |

Geographic visualization of the sampled flight network. Visualized in LynxKite.

2. Exploratory Data Analysis (EDA)

Because the dataset is large, any graph analysis operation is computationally expensive. We therefore subsample the graph as follows:

  • Dynamic Graph Subsampling and Discretization

Graph subsampling performs random sampling over a given number of nodes. It is also possible to supply a deterministic node list or a region name (i.e. a continent such as NA, AS, OC, SA, EU). The subsampled graph is then discretized into a given period (i.e. daily, weekly, monthly, or yearly). The graph subsampling and discretization pipeline is illustrated as follows.

Run Code: {bash} python tgb/datasets/dataset_scripts/tgbl_flight_neg_generator.py

Note: The negative edge sampling of the validation and test set are also generated.
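The subsampling and discretization steps described above can be sketched as follows. This is a simplified illustration under stated assumptions, not the logic of the repository script: the function name, its parameters, and the toy events are hypothetical, and timestamps are assumed to be integer seconds.

```python
import random
from collections import Counter

def subsample_and_discretize(events, num_nodes, period, node_list=None, seed=0):
    """Keep edges among a node subset, then count edges per time bin.

    events: iterable of (src, dst, timestamp) tuples (timestamps in seconds, assumed).
    node_list: optional deterministic node subset; otherwise nodes are sampled randomly.
    period: bin width in the same unit as the timestamps (86400 for daily bins).
    """
    if node_list is not None:
        kept = set(node_list)
    else:
        rng = random.Random(seed)
        nodes = sorted({n for src, dst, _ in events for n in (src, dst)})
        kept = set(rng.sample(nodes, min(num_nodes, len(nodes))))
    counts = Counter()
    for src, dst, t in events:
        if src in kept and dst in kept:
            counts[(src, dst, t // period)] += 1  # aggregate repeated edges per bin
    return kept, dict(counts)

# Hypothetical toy events: three airports, two days.
toy = [("a", "b", 0), ("a", "b", 10), ("b", "c", 86400), ("a", "b", 86401)]
kept, daily = subsample_and_discretize(toy, num_nodes=3, period=86400)
print(daily)
```

Passing `node_list` mirrors the deterministic option mentioned above; a region filter would need an airport-to-continent mapping, which is not shown here.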

The resulting subgraph can then be visualized and analyzed, as discussed next.

  • Dynamic graph visualization

To understand the dynamics and structure of the graph, it helps to visualize it. Large graphs are very challenging to analyze and visualize because of the computational resources required, so the whole network must be subsampled to make visualization feasible.

Code to generate the .csv files to load into Gephi is available at visualization/generate_gephi_viz_file.ipynb

The following layouts were produced by running the Force Atlas 2 algorithm on the sampled graph, with nodes sized by degree and colored by continent.

Airline routes during a period of low Covid-19 impact in Feb-21. Visualized in Gephi:

Airline routes during a period of high Covid-19 impact in Feb-21. Visualized in Gephi:

Note: North America: pink; Europe: blue; Asia: orange; Oceania: red; South America: green.

Temporal Graph Network Video.

https://github.com/user-attachments/assets/b1a4e248-870a-4c0f-9b69-37997b5c6d1e

The Force Atlas 2 algorithm is able to group the airports by region. DORD (O'Hare International Airport) has the highest number of connections. During periods of high Covid-19 cases, the network becomes sparser: very few routes connect the airports, especially routes between continents.

  • Dynamic Graph Analysis

While topological graph visualization provides crucial structural information about the network, it makes it tedious to follow the evolution of the network over time. Poursafaei et al. and TGX introduce novel techniques for visualizing temporal graphs.

The Temporal Edge Appearance (TEA) plot illustrates the proportion of repeated edges versus newly observed edges at each timestamp of a dynamic graph.

The Temporal Edge Traffic (TET) plot visualizes the recurrence pattern of edges in different dynamic networks over time.
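The quantity behind a TEA plot can be sketched in a few lines. This is a simplified illustration, not the TGX implementation; the function name and the toy events are hypothetical.

```python
def tea_counts(events):
    """For each timestamp, count repeated edges vs. newly observed edges.

    events: iterable of (src, dst, timestamp) tuples.
    Returns {timestamp: (num_repeated, num_new)}.
    """
    seen = set()   # edges observed at any earlier timestamp
    result = {}
    for t in sorted({ts for _, _, ts in events}):
        current = {(s, d) for s, d, ts in events if ts == t}
        result[t] = (len(current & seen), len(current - seen))
        seen |= current
    return result

# Hypothetical toy events.
toy = [("A", "B", 1), ("B", "C", 1), ("A", "B", 2), ("C", "D", 2), ("A", "B", 3)]
print(tea_counts(toy))
```

A high share of repeated edges suggests that memorization-style baselines can score well, which is one reason the plot helps guide model and evaluation choices.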

[Figures: Temporal Edge Appearance | Temporal Edge Traffic]

[Figures: Heatmap of Node Degree over Time | Nodes vs Edges over Time]

Note: the last timestamp does not cover a full month of aggregation, so its values in the charts are relatively low.

The TEA and TET plots guide the choice of model and evaluation method, as discussed next.

For the complete visualization, please see visualization/dynamic_graph_analysis.ipynb.

4. Evaluation Methods

Because the graph is sparse, a more stringent evaluation method is required. Mean Reciprocal Rank (MRR) is the most suitable metric for this task, as suggested by Poursafaei et al. and TGB: each positive edge is ranked against 20 negative edges (edges that do not exist in the network). A simple illustration of negative edge sampling and of the MRR evaluation and calculation is shown below.

For each link to be predicted, 50% of the negative edges are sampled from historical edges and 50% from new edges. It is crucial to test both the transductive and the inductive learning ability of the TGL model.
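The MRR calculation described above can be sketched as follows. This is a minimal illustration with hypothetical scores; ties are resolved in favor of the positive edge here, which is an assumption and may differ from the TGB evaluator.

```python
def reciprocal_rank(pos_score, neg_scores):
    """Reciprocal rank of the positive edge among its negative candidates."""
    # Rank 1 means the positive edge outscored every negative edge.
    rank = 1 + sum(1 for s in neg_scores if s > pos_score)
    return 1.0 / rank

def mean_reciprocal_rank(queries):
    """queries: list of (pos_score, neg_scores) pairs, one per positive edge."""
    return sum(reciprocal_rank(p, n) for p, n in queries) / len(queries)

# Hypothetical model scores (3 negatives per positive to keep the example short;
# the evaluation described above uses 20).
queries = [
    (0.9, [0.1, 0.3, 0.2]),  # positive ranked 1st
    (0.4, [0.8, 0.1, 0.6]),  # positive ranked 3rd
]
print(mean_reciprocal_rank(queries))
```

With these two queries the reciprocal ranks are 1 and 1/3, so the MRR is their mean, 2/3.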

5. Temporal Graph Learning Pipeline

With insight into the data and the metric to use, we are ready to build a scalable pipeline for learning on temporal graphs. An end-to-end TGL pipeline was built and is displayed below.

The script has been tested running under Python 3.10, with the following packages installed (along with their dependencies):

py-tgb==0.9.2
torch_geometric==2.5.2
wandb==0.17.4
torch==2.2.1+cu121
torch_cluster==1.6.3+pt22cu121
torch_scatter==2.1.2+pt22cu121
torch_sparse==0.6.18+pt22cu121
torch_spline_conv==1.2.2+pt22cu121
torchaudio==2.2.1+cu121
torchdata==0.7.1
scipy==1.13.0
triton==2.2.0
scikit-learn==1.4.2
networkx==3.2.1
onnx==1.16.1

Since we need to run a large number of experiments, wandb is a useful package for experiment tracking and result visualization. As we use the torch.compile framework for FX graph compilation, we need a GPU with a sufficiently high CUDA compute capability, i.e. torch.cuda.get_device_capability() >= (7, 0), together with torch >= 2.0 and torch_geometric >= 2.5.
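The three requirements above can be checked with a small helper. This is a sketch only: `meets_requirements` is a hypothetical name, and it works on plain tuples and version strings so that it runs without a GPU.

```python
def meets_requirements(capability, torch_version, pyg_version):
    """Check compute capability >= (7, 0), torch >= 2.0, torch_geometric >= 2.5."""
    def major_minor(version):
        # "2.2.1+cu121" -> (2, 2); the local build suffix after "+" is ignored.
        parts = version.split("+")[0].split(".")
        return int(parts[0]), int(parts[1])

    return (
        tuple(capability) >= (7, 0)
        and major_minor(torch_version) >= (2, 0)
        and major_minor(pyg_version) >= (2, 5)
    )

print(meets_requirements((8, 6), "2.2.1+cu121", "2.5.2"))  # True
print(meets_requirements((6, 1), "2.2.1+cu121", "2.5.2"))  # False
```

On a CUDA machine the arguments would come from `torch.cuda.get_device_capability()`, `torch.__version__`, and `torch_geometric.__version__`. Note that the capability is a `(major, minor)` tuple, so it is compared against `(7, 0)` rather than the float `7.0`.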

6. Model Selection

With the TGL pipeline built, we can run many experiments without overburdening the GPU or spending an excessive amount of time. The experimental results are illustrated below.

[Figures: Test MRR | GPU Memory]

[Figures: Training Time | Validation Time]

| Top Performance (MRR) | Top Efficiency (Training Speed) |
| :---: | :---: |
| DyGFormer: 81% | EdgeBank (No Training) |
| GraphMixer: 80.57% | JODIE: 2.223s |
| CAWN: 78.47% | TGN: 2.649s |
| TGN: 71.01% | TCL: 5.586s |
| JODIE: 68.95% | GraphMixer: 6.111s |

TGN is selected to balance the training speed and performance.

Code for running the experiments:

  • TGN model:

{bash} python models/TGN.py --data "tgbl-flight" --num_run 3 --seed 1

  • Non-TGN-based models (DyGFormer, CAWN, GraphMixer, JODIE, TCL):

{bash} python train_link_prediction.py --dataset_name tgbl-flight --model_name DyGFormer --patch_size 1 --max_input_sequence_length 32 --num_runs 3 --gpu 0

7. Temporal Graph Network with Static Time Encoder (TGN-ST)

After model selection, we take a closer look at the selected model. onnx captures the FX graph and provides the visualizations below of the TGN model for ease of understanding. Three of TGN's five modules are examined using onnx. An example of graph compilation using torch.compile is illustrated for the Time Encoder.

[Figures: Time Encoder | Embedding | Decoder]

To improve model performance, GraphMixer suggests using a static time encoder, which improves training stability.
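As a rough sketch of what a static time encoder looks like, the snippet below follows the fixed cosine encoding commonly attributed to GraphMixer, with frequencies that decay geometrically and are never trained. The constants (`alpha = beta = sqrt(dim)`) and the function name are assumptions for illustration; the encoder in models/TGN_ST.py may differ.

```python
import math

def static_time_encoding(delta_t, dim=100):
    """Fixed (non-trainable) encoding: cos(delta_t * omega_i) for i = 0..dim-1."""
    alpha = beta = math.sqrt(dim)  # assumed constants, not taken from this repository
    # Frequencies decay geometrically from 1: omega_i = alpha ** (-i / beta).
    omegas = [alpha ** (-i / beta) for i in range(dim)]
    return [math.cos(delta_t * w) for w in omegas]

enc = static_time_encoding(3600.0, dim=8)
print(len(enc), [round(v, 3) for v in enc[:3]])
```

Because the frequencies are constants rather than learned weights, the encoding of a given time gap never changes during training, which is the property associated with improved stability.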

To improve model efficiency, torch.compile is wrapped around the embedding and link_predictor modules. torch.compile is a PyTorch API that captures the graph using TorchDynamo and pairs it with the TorchInductor backend to compile it into fast code, such as OpenAI Triton kernels, for training on GPU.

Furthermore, the standard TGN forward pass implemented by pytorch_geometric (Algorithm 1, lines 7-10) is computationally costly. This bottleneck arises from the loop that runs the negative neighbor query, memory_updater, embedding, and link_predictor separately for every node. To address this and reduce the computational cost, we propose the TGN-ST forward pass in Algorithm 2. TGN-ST leverages the parallel processing architecture of modern GPUs. Unlike the standard approach, TGN-ST avoids individual loops over each node's negative neighbors (lines 6-9 in Algorithm 2). Instead, it exploits the GPU's ability to perform large matrix operations simultaneously.
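The difference between looping over negatives and scoring them in one batched operation can be illustrated with a toy example. This sketch does not reproduce Algorithm 1 or Algorithm 2: it uses NumPy instead of PyTorch, random embeddings, and a plain dot product as a stand-in for the link predictor, all of which are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_src, num_neg, dim = 4, 20, 8
src = rng.normal(size=(num_src, dim))           # one embedding per source node
neg = rng.normal(size=(num_src, num_neg, dim))  # 20 negative destinations per source

def score_loop(src, neg):
    """Score each (source, negative) pair one at a time."""
    out = np.empty(neg.shape[:2])
    for i in range(neg.shape[0]):
        for j in range(neg.shape[1]):
            out[i, j] = src[i] @ neg[i, j]
    return out

def score_batched(src, neg):
    """Score all (source, negative) pairs in a single batched contraction."""
    return np.einsum("id,ijd->ij", src, neg)

# Both approaches produce the same scores; only the execution pattern differs.
assert np.allclose(score_loop(src, neg), score_batched(src, neg))
print(score_batched(src, neg).shape)
```

The batched form issues one large operation instead of many small ones, which is the kind of workload that parallel hardware handles efficiently.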

  • Code for running the TGN-ST experiments:

{bash} python models/TGN_ST.py --data "tgbl-flight" --num_run 3 --seed 1

7. Hyperparameter Tuning

We performed 200 random search experiments on wandb. The main observations are listed below:

  • AdamW with cosine annealing is the best training strategy
  • Standard SGD is the worst optimizer
  • Leaky_relu is the worst activation function
  • Batch size and patience directly impact the training speed

[Figures: Hyperparameter Sweeping | Hyperparameter Sweeping Top Performers]

To trade off performance and efficiency, the most suitable parameters are selected and evaluated on the full dataset.

  • Code for running the hyperparameter sweep experiments:

{bash} python models/TGN_ST_Sweep.py --data "tgbl-flight" --num_run 3 --seed 1

8. Results

TGN-ST sets a new state-of-the-art result of 72.49% MRR, with 15% faster training and 5x faster validation.

[Figures: Training | Validation]

The consistent improvement on both datasets provides evidence for the generalizability of the pipeline and the TGN-ST forward pass algorithm.

| Datasets | Model and Gains | Test MRR (%) | Training/Epoch (s) | Validation/Epoch (s) |
| :---: | :---: | :---: | :---: | :---: |
| Sample | TGN | 72.94 | 2.3743 | 9.8546 |
| | TGN-ST | 80.19 | 1.7763 | 2.2687 |
| | Gain | 7.25 | 0.598 (25%) | 7.586 (4.34x) |
| Full Dataset | Vanilla TGN | 70.50 | 2800 | 9485 |
| | TGN-ST | 72.49 | 2366 | 1887 |
| | Gain | 1.99 | 434 (15%) | 7598 (5.02x) |
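For readers who want to see how the Gain rows relate to the timings, the arithmetic can be reproduced as follows. The helper name is hypothetical, and the inputs are the full-dataset values from the table.

```python
def gains(baseline, improved):
    """Absolute saving, relative saving (%), and speed-up factor for a timing pair."""
    saved = baseline - improved
    return saved, 100.0 * saved / baseline, baseline / improved

# Full-dataset timings from the table (seconds per epoch).
train_saved, train_pct, _ = gains(2800, 2366)
val_saved, _, val_speedup = gains(9485, 1887)
print(train_saved, round(train_pct, 1))  # 434 15.5
print(val_saved, round(val_speedup, 2))  # 7598 5.03
```

The training gain is reported as a percentage of time saved, while the validation gain is reported as a speed-up factor; the 5.02x in the table corresponds to this ratio truncated to two decimals.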

Running Code:

{bash} python models/TGN_ST.py --data tgbl-flight --num_run 3 --seed 1 --optimizer adamw --bs 288 --lr 0.00072 --t_0 14 --t_mult 4 --wd 0.01292
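To give a feel for the schedule these flags suggest, the sketch below computes the learning rate under cosine annealing with warm restarts. It assumes that `--t_0` and `--t_mult` correspond to the `T_0` and `T_mult` arguments of PyTorch's `CosineAnnealingWarmRestarts` and that the minimum learning rate is 0; neither assumption is confirmed by this README.

```python
import math

def cosine_warm_restart_lr(epoch, base_lr, t_0, t_mult, eta_min=0.0):
    """Learning rate at an integer epoch under cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    # Skip completed cycles; each cycle is t_mult times longer than the previous one.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

# Values taken from the command above: --lr 0.00072 --t_0 14 --t_mult 4
for epoch in (0, 7, 13, 14):
    print(epoch, cosine_warm_restart_lr(epoch, 0.00072, t_0=14, t_mult=4))
```

Under these assumptions the rate decays from 0.00072 over the first 14 epochs, restarts at epoch 14, and the second cycle lasts 56 epochs.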

9. Conclusion

The research contributions are two-fold:

  • Scalable Framework: Established a scalable framework covering graph subsampling, visualization, and hyperparameter tuning.

  • Model Improvement: Proposed the TGN-ST model, which achieves new state-of-the-art performance with a significant speed-up in training and evaluation.

10. Citations

Please consider citing the following reference when using this project.

    @software{horhang_2024_13234006,
      author    = {HorHang},
      title     = {HorHang/TGN-ST: Initial Release},
      month     = aug,
      year      = 2024,
      publisher = {Zenodo},
      version   = {0.1.0},
      doi       = {10.5281/zenodo.13234006},
      url       = {https://doi.org/10.5281/zenodo.13234006}
    }

11. License

License: MIT

Owner

  • Login: HorHang
  • Kind: user

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Hor"
  given-names: "Hang"
  orcid: "https://orcid.org/0009-0001-2418-4663"
title: "TGN-ST"
version: 0.1.0
doi: 10.5281/zenodo.13234006
date-released: 2024-08-06
url: "https://doi.org/10.5281/zenodo.13234006"


Dependencies

.github/workflows/mkdocs.yaml actions
  • actions/cache v3 composite
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.github/workflows/pypi.yaml actions
  • JRubics/poetry-publish v1.17 composite
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.devcontainer/Dockerfile docker
  • mcr.microsoft.com/devcontainers/python 3.10 build
pyproject.toml pypi
  • mkdocs ^1.4.3 develop
  • mkdocs-jupyter ^0.24.1 develop
  • mkdocs-material ^9.1.15 develop
  • mkdocstrings-python ^1.1.2 develop
  • poetry ^1.5.1 develop
  • clint ^0.5.1
  • pandas ^1.5.3
  • python ^3.9
  • requests ^2.28.2
  • scikit-learn ^1.2.2
  • torch-geometric ^2.3.0
  • tqdm ^4.65.0
setup.py pypi