tgn-st
This repository contains a PyTorch implementation of the TGN-ST paper (accepted to https://techno-srj.itc.edu.kh/; it will appear in the next volume, 2025).
TGN-ST
Temporal Graph Learning with Application to Large-Scale Flight Traffic Prediction (Hang et al., 2024)
This repository contains a PyTorch implementation of the paper TGN-ST (submitted to https://techno-srj.itc.edu.kh/ and under review).
Table of Contents
- TGN-ST
- Table of Contents
- 1. Introduction
- 2. Temporal Graph Learning (TGL)
- 3. Flight Network Dataset
- 1. Dataset Overviews
- 2. Exploratory Data Analysis (EDA)
- 4. Evaluation Methods
- 5. Temporal Graph Learning Pipeline
- 6. Model Selection
- 7. Temporal Graph Network with Static Time Encoder (TGN-ST)
- 7. Hyperparameter Tuning
- 8. Results
- 9. Conclusion
- 10. Citations
- 11. License
1. Introduction
Graphs are everywhere!
Graph applications:
- Transportation networks (e.g. @Google Maps)
- Social networks (e.g. @Meta)
- Finance (e.g. Lecture)
- Computer vision (e.g. Paper)
- Natural language processing (e.g. Paper)
- Web applications (e.g. Web Image Search)
- Computational biology (e.g. @Deepmind)
- Recommendation systems (e.g. @Twitter, @Pinterest)
2. Temporal Graph Learning (TGL)
Systems of interaction can be modeled as graphs, with entities as nodes and interactions as edges (links).
Real-world networks inherently evolve over time. Temporal graph learning is an emerging subfield of machine learning and a promising framework for extracting structural and temporal information from dynamic graphs, assuming both a relational inductive bias and a temporal inductive bias.
An example of a dynamic graph. The numbers attached to the edges are timestamps. Visualized in LynxKite.
TGL Methods Taxonomy
Approaches | GNN-based | Sequence-Based | Graph Walk Based | None
:---: | :---: | :---: | :---: | :---:
Memory-based | JODIE | - | CAWN | EdgeBank**
Self-attn-based | TGAT, TCL, GAT* | DyGFormer | - | -
Self-attn + Memory-based | DyREP, TGN | - | - | -
MLP-based | - | GraphMixer | - | -
None | GraphSAGE*, GCN | - | - | -

Note:
(*) Static graph method
(**) Pure memorization method
3. Flight Network Dataset
1. Dataset Overviews
TGB tgbl-flight is a crowdsourced international flight network covering 2019 to 2022.

Name | #Nodes | #Edges | #Steps
:---: | :---: | :---: | :---:
tgbl-flight | 18,143 | 67,169,570 | 1,385

Geographic visualization of the sampled flight network. Visualized in LynxKite.
2. Exploratory Data Analysis (EDA)
Because the dataset is large, any graph insight operation is computationally expensive. We therefore subsample the graph as follows:
- Dynamic Graph Subsampling and Discretization
Graph subsampling performs random sampling of a given number of nodes. It is also possible to supply a deterministic node list or a region name (e.g. a continent such as NA, AS, OC, SA, or EU). The subsampled graph is further discretized into a given period (daily, weekly, monthly, or yearly). The graph subsampling and discretization pipeline is illustrated as follows:
Run code:
```bash
python tgb/datasets/dataset_scripts/tgbl_flight_neg_generator.py
```
Note: negative edge samples for the validation and test sets are also generated.
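The subsampling and discretization idea described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the repository's implementation: it assumes an edge list of `(source, destination, Unix timestamp)` tuples, and the names `subsample_edges`, `discretize`, and `toy_edges` are hypothetical.

```python
import random
from collections import defaultdict

SECONDS_PER_DAY = 86_400


def subsample_edges(edges, num_nodes, seed=0):
    """Keep only edges whose two endpoints both fall in a random node sample."""
    nodes = sorted({n for src, dst, _ in edges for n in (src, dst)})
    rng = random.Random(seed)
    kept = set(rng.sample(nodes, min(num_nodes, len(nodes))))
    return [(s, d, t) for s, d, t in edges if s in kept and d in kept]


def discretize(edges, period_seconds=SECONDS_PER_DAY):
    """Group edges into snapshots by integer-dividing each timestamp."""
    snapshots = defaultdict(list)
    for src, dst, ts in edges:
        snapshots[ts // period_seconds].append((src, dst))
    return dict(snapshots)


# Toy edge list (hypothetical data): airport pairs with Unix timestamps.
toy_edges = [
    ("A", "B", 0),
    ("B", "C", 3_600),
    ("A", "C", 90_000),
    ("C", "D", 100_000),
    ("A", "B", 180_000),
]

sub = subsample_edges(toy_edges, num_nodes=3, seed=0)
daily = discretize(toy_edges)
```

Passing a fixed node list or a per-region filter instead of `rng.sample` would give the deterministic variants mentioned above.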
The resulting subgraph can then be visualized and analyzed, as discussed next.
- Dynamic Graph Visualization
Understanding the dynamics and structure of the graph requires visualizing it. A large graph is very challenging to analyze and visualize because of the computational resources required, so the whole network is subsampled to make visualization feasible.
Code to generate the .csv input for Gephi is available at visualization/generate_gephi_viz_file.ipynb.
The sampled graph is laid out with the Force Atlas 2 algorithm, with nodes sized by degree and colored by continent.
Airline routes during a period of low Covid-19 impact in Feb-21. Visualized in Gephi:
Airline routes during a period of high Covid-19 impact in Feb-21. Visualized in Gephi:
Note: North America: pink; Europe: blue; Asia: orange; Oceania: red; South America: green.
Temporal Graph Network Video.
https://github.com/user-attachments/assets/b1a4e248-870a-4c0f-9b69-37997b5c6d1e
The Force Atlas 2 algorithm is able to group the airports by region. DORD (O'Hare International Airport) has the highest number of connections. During periods of high Covid-19 cases, the network becomes sparser: very few routes connect the airports, especially intercontinental routes.
- Dynamic Graph Analysis
While topological graph visualization provides crucial structural information about the network, it makes the evolution of the network over time tedious to follow. Poursafaei et al. and TGX introduce novel techniques for visualizing temporal graph networks.
The Temporal Edge Appearance (TEA) plot illustrates the proportion of repeated edges versus newly observed edges at each timestamp in a dynamic graph.
The Temporal Edge Traffic (TET) plot visualizes the recurrence pattern of edges in different dynamic networks over time.
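The quantity behind a TEA plot can be sketched as follows. This is a minimal illustration rather than the TGX implementation: it assumes the graph has already been discretized into a dictionary mapping each time step to a list of `(source, destination)` edges, and the names `edge_appearance_counts` and `toy_snapshots` are hypothetical.

```python
def edge_appearance_counts(snapshots):
    """For each time step, count repeated edges (seen in an earlier step)
    and new edges (seen for the first time in this step)."""
    seen = set()
    counts = []
    for step in sorted(snapshots):
        current = set(snapshots[step])
        repeated = len(current & seen)
        new = len(current - seen)
        counts.append((step, repeated, new))
        seen |= current
    return counts


# Toy snapshots (hypothetical data).
toy_snapshots = {
    0: [("A", "B"), ("B", "C")],
    1: [("A", "B"), ("C", "D")],
    2: [("A", "B"), ("B", "C"), ("D", "E")],
}

tea = edge_appearance_counts(toy_snapshots)
```

Plotting the two counts as stacked bars per time step gives the shape of a TEA plot. A high share of new edges suggests that pure memorization of past edges will not be enough.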
Temporal Edge Appearance | Temporal Edge Traffic
:-------------------------:|:-------------------------:
| 
Heatmap of Node Degree over Time | Nodes vs Edge over Time
:-------------------------:|:-------------------------:
| 
Note: the last timestamp does not cover a full month of aggregated data, so its values in the charts are relatively low.
The TEA and TET plots guide the choice of model and evaluation method, as discussed next.
For the complete visualization, please see visualization/dynamic_graph_analysis.ipynb.
4. Evaluation Methods
Because the graph is sparse, a more stringent evaluation method is required. Mean Reciprocal Rank (MRR) is the most suitable metric for this task, as suggested by Poursafaei et al. and TGB: each positive edge is contrasted with 20 negative edges (edges that do not exist in the network). A simple negative edge sampling, MRR evaluation, and calculation method is illustrated below.
For a given link to predict, the negative edges are sampled 50% from historical edges and 50% from new edges. This makes it possible to test both the transductive and the inductive learning ability of the TGL model.
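The MRR calculation for one positive edge against its negatives can be sketched as below. This is a minimal illustration, not the TGB evaluator: the scores are made up, the names `reciprocal_rank` and `mean_reciprocal_rank` are hypothetical, and ties are resolved by averaging the optimistic and pessimistic ranks, which is an assumption of this sketch.

```python
def reciprocal_rank(pos_score, neg_scores):
    """Reciprocal rank of the positive edge among its negative edges.

    Ties are handled by averaging the optimistic and pessimistic ranks
    (an assumption of this sketch).
    """
    higher = sum(1 for s in neg_scores if s > pos_score)
    ties = sum(1 for s in neg_scores if s == pos_score)
    rank = 1 + higher + ties / 2
    return 1.0 / rank


def mean_reciprocal_rank(queries):
    """Average the reciprocal rank over (positive score, negative scores) pairs."""
    return sum(reciprocal_rank(p, n) for p, n in queries) / len(queries)


# Three toy queries, each with 1 positive score and 20 negative scores.
queries = [
    (0.9, [0.1] * 20),               # positive ranked 1st
    (0.5, [0.8, 0.7] + [0.1] * 18),  # positive ranked 3rd
    (0.2, [0.9] * 4 + [0.1] * 16),   # positive ranked 5th
]

mrr = mean_reciprocal_rank(queries)
```

With 20 negatives per positive, a model that scores edges at random lands near the bottom of the range, so MRR separates models more sharply than a balanced accuracy-style metric on a sparse graph.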
5. Temporal Graph Learning Pipeline
With insight into the data and the metric to use, we are ready to build a scalable pipeline for learning on temporal graphs. An end-to-end TGL pipeline is built and shown below.
The scripts have been tested under Python 3.10 with the following packages installed (along with their dependencies):
py-tgb==0.9.2
torch_geometric==2.5.2
wandb==0.17.4
torch==2.2.1+cu121
torch_cluster==1.6.3+pt22cu121
torch_scatter==2.1.2+pt22cu121
torch_sparse==0.6.18+pt22cu121
torch_spline_conv==1.2.2+pt22cu121
torchaudio==2.2.1+cu121
torchdata==0.7.1
scipy==1.13.0
triton==2.2.0
scikit-learn==1.4.2
networkx==3.2.1
onnx==1.16.1
Since we need to run a large number of experiments, we use wandb for experiment tracking and result visualization. Because we use torch.compile for FX graph compilation, we need a GPU with a CUDA compute capability of at least 7.0 (as reported by torch.cuda.get_device_capability()), together with torch >= 2.0 and torch_geometric >= 2.5.
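The requirements above can be checked before launching a run. The sketch below is a hypothetical helper, not part of this repository: `version_tuple` and `supports_compile` are illustrative names, and the function takes plain values so that it can be read and tested without a GPU.

```python
def version_tuple(version):
    """Parse a version such as '2.2.1+cu121' into (2, 2, 1),
    ignoring any local build suffix after '+'."""
    core = version.split("+")[0]
    return tuple(int(part) for part in core.split(".") if part.isdigit())


def supports_compile(capability, torch_version, pyg_version):
    """Return True when the GPU and package versions meet the stated minimums."""
    return (
        tuple(capability) >= (7, 0)
        and version_tuple(torch_version) >= (2, 0)
        and version_tuple(pyg_version) >= (2, 5)
    )


# Values matching the tested environment listed above, with an assumed
# compute capability of (8, 6) for illustration.
ok = supports_compile((8, 6), "2.2.1+cu121", "2.5.2")
too_old_gpu = supports_compile((6, 1), "2.2.1+cu121", "2.5.2")
```

On a machine with PyTorch installed, the three arguments would typically come from `torch.cuda.get_device_capability()`, `torch.__version__`, and `torch_geometric.__version__`.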
6. Model Selection
With the TGL pipeline built, we can run many experiments without overloading the GPU or spending an excessive amount of time. The experimental results are shown below.
Test MRR | GPU Memory
:-------------------------:|:-------------------------:
| 
Training Time | Validation Time
:-------------------------:|:-------------------------:
| 
Top Performance (MRR) | Top Efficiency (Training Speed)
:-------------------------:|:-------------------------:
DyGFormer: 81% | EdgeBank (no training)
GraphMixer: 80.57% | JODIE: 2.223s
CAWN: 78.47% | TGN: 2.649s
TGN: 71.01% | TCL: 5.586s
JODIE: 68.95% | GraphMixer: 6.111s
TGN is selected to balance the training speed and performance.
Code for running the experiments:
- TGN model:
```bash
python models/TGN.py --data "tgbl-flight" --num_run 3 --seed 1
```
- Non-TGN-based models such as {DyGFormer, CAWN, GraphMixer, JODIE, TCL}:
```bash
python train_link_prediction.py --dataset_name tgbl-flight --model_name DyGFormer --patch_size 1 --max_input_sequence_length 32 --num_runs 3 --gpu 0
```
7. Temporal Graph Network with Static Time Encoder (TGN-ST)
After model selection, we take a closer look at the selected model. onnx captures the FX graph and provides the visualizations of the TGN model below for ease of understanding. Three of the five TGN modules are examined using onnx. An example of graph compilation using torch.compile is illustrated for the Time Encoder.
Time Encoder | Embedding | Decoder
:-------: | :-------: | :-------:
|
| 
To improve model performance, we follow GraphMixer's suggestion to use a static time encoder, which improves training stability.
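A static time encoder can be sketched as a cosine encoding with fixed, non-trainable frequencies. The block below is a minimal pure-Python illustration under stated assumptions: the frequency schedule `alpha ** (-i / beta)` with `alpha = beta = sqrt(dim)` follows our reading of the GraphMixer description, the exact schedule used by TGN-ST may differ, and the names `make_fixed_frequencies` and `encode_time` are hypothetical.

```python
import math


def make_fixed_frequencies(dim, alpha=None, beta=None):
    """Fixed frequencies w_i = alpha ** (-i / beta) for i = 0 .. dim - 1.

    The defaults alpha = beta = sqrt(dim) are an assumption of this sketch.
    These values are computed once and never updated during training.
    """
    alpha = math.sqrt(dim) if alpha is None else alpha
    beta = math.sqrt(dim) if beta is None else beta
    return [alpha ** (-i / beta) for i in range(dim)]


def encode_time(delta_t, frequencies):
    """Map a time difference to a vector of cosines, one per frequency."""
    return [math.cos(delta_t * w) for w in frequencies]


freqs = make_fixed_frequencies(4)
zero_gap = encode_time(0.0, freqs)    # every component is cos(0)
one_hour = encode_time(3600.0, freqs)
```

Because the frequencies carry no gradient, the encoder adds no trainable parameters, which is the property the stability argument relies on.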
To improve model efficiency, torch.compile is wrapped around the embedding and link_predictor modules. torch.compile is a PyTorch API that performs graph capture using TorchDynamo and pairs it with the TorchInductor backend to compile the captured graph into fast code, such as OpenAI Triton kernels, for training on GPU.
Furthermore, the standard TGN forward pass (Algorithm 1), as implemented in pytorch_geometric, is computationally costly (lines 7-10). The bottleneck arises from the loop that runs the negative neighbor query, memory_updater, embedding, and link_predictor for every node. To address this and reduce the computational cost, we propose the TGN-ST forward pass in Algorithm 2. TGN-ST leverages the parallel processing architecture of modern GPUs. Unlike the standard approach, TGN-ST avoids separate loops over each node's negative neighbors (lines 6-9 in Algorithm 2) and instead exploits the GPU's ability to perform large matrix operations simultaneously.
- Code for experimenting with TGN-ST:
```bash
python models/TGN_ST.py --data "tgbl-flight" --num_run 3 --seed 1
```
7. Hyperparameter Tuning
We performed 200 random search experiments on wandb. The main observations are listed below:
- AdamW with cosine annealing is the best training strategy
- Standard SGD is the worst optimizer
- Leaky_relu is the worst activation function
- Batch size and patience directly impact the training speed
Hyperparameter Sweeping | Hyperparameter Sweeping Top Performers
:-------: | :-------:
| 
To trade off performance and efficiency, the most suitable hyperparameters are selected and evaluated on the full dataset.
- Code for running the hyperparameter sweep experiments:
```bash
python models/TGN_ST_Sweep.py --data "tgbl-flight" --num_run 3 --seed 1
```
8. Results
TGN-ST sets a new state-of-the-art result of 72.49% MRR, with 15% faster training and 5x faster validation.
Training | Validation
:-------: | :-------:
|
|
The consistent improvement on both datasets provides evidence for the generalizability of the pipeline and the TGN-ST forward pass algorithm.

| Datasets | Model and Gains | Test MRR (%) | Training/Epoch (s) | Validation/Epoch (s) |
|---|---|---|---|---|
| Sample | TGN | 72.94 | 2.3743 | 9.8546 |
| | TGN-ST | 80.19 | 1.7763 | 2.2687 |
| | Gain | 7.25 | 0.598 (25%) | 7.586 (4.34x) |
| Full Dataset | Vanilla TGN | 70.50 | 2800 | 9485 |
| | TGN-ST | 72.49 | 2366 | 1887 |
| | Gain | 1.99 | 434 (15%) | 7598 (5.02x) |
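The gain entries for the full dataset can be re-derived from the per-epoch timings in the results table. The short check below is only arithmetic on those reported numbers; the helper name `gains` is hypothetical.

```python
def gains(baseline, improved):
    """Return (absolute saving, relative saving in percent, speed-up factor)."""
    saved = baseline - improved
    return saved, 100 * saved / baseline, baseline / improved


# Full dataset, seconds per epoch: Vanilla TGN vs. TGN-ST.
train_saved, train_pct, _ = gains(2800, 2366)
val_saved, _, val_speedup = gains(9485, 1887)

print(f"training: {train_saved} s saved ({train_pct:.1f}%)")
print(f"validation: {val_saved} s saved ({val_speedup:.2f}x faster)")
```

The training gain is expressed as a percentage of the baseline time, while the validation gain is expressed as a ratio of baseline to improved time, which is why the two columns use different units.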
Running code:
```bash
python models/TGN_ST.py --data tgbl-flight --num_run 3 --seed 1 --optimizer adamw --bs 288 --lr 0.00072 --t_0 14 --t_mult 4 --wd 0.01292
```
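To see what the learning-rate flags imply, the sketch below evaluates a cosine annealing schedule with warm restarts. It assumes that `--t_0` and `--t_mult` correspond to the `T_0` and `T_mult` parameters of that schedule (as in `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`) and that the scheduler is stepped once per epoch; both are assumptions, and `cosine_warm_restart_lr` is a hypothetical name.

```python
import math


def cosine_warm_restart_lr(epoch, base_lr, t_0, t_mult, eta_min=0.0):
    """Learning rate at a given epoch under cosine annealing with warm restarts.

    The first cycle lasts t_0 epochs; each later cycle is t_mult times longer.
    """
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:  # skip over completed cycles
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))


# Values taken from the command above: --lr 0.00072 --t_0 14 --t_mult 4
schedule = [cosine_warm_restart_lr(e, 0.00072, 14, 4) for e in range(20)]
```

Under these assumptions the rate decays from 0.00072 over the first 14 epochs, resets to 0.00072 at epoch 14, and then decays over a longer cycle of 56 epochs.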
9. Conclusion
Research contributions are two-fold:
- Scalable Framework: established a scalable framework and performed graph subsampling, visualization, and hyperparameter tuning.
- Model Improvement: proposed the TGN-ST model, which achieves new state-of-the-art performance with a significant training and evaluation speed boost.
10. Citations
Please consider citing the following reference when using this project.
@software{horhang_2024_13234006,
author = {HorHang},
title = {HorHang/TGN-ST: Initial Release},
month = aug,
year = 2024,
publisher = {Zenodo},
version = {0.1.0},
doi = {10.5281/zenodo.13234006},
url = {https://doi.org/10.5281/zenodo.13234006}
}
11. License
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Hor"
    given-names: "Hang"
    orcid: "https://orcid.org/0009-0001-2418-4663"
title: "TGN-ST"
version: 0.1.0
doi: 10.5281/zenodo.13234006
date-released: 2024-08-06
url: "https://doi.org/10.5281/zenodo.13234006"