stgm

STGM: Spatio-Temporal Graph Mixformer for Traffic Forecasting

https://github.com/mouradost/stgm

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 6 DOI reference(s) in README
  • Academic publication links
    Links to: zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (8.9%) to scientific vocabulary

Keywords

attention-mechanism gnn-model spatio-temporal-prediction traffic-forecasting

Repository

STGM: Spatio-Temporal Graph Mixformer for Traffic Forecasting

Basic Info
  • Host: GitHub
  • Owner: Mouradost
  • License: mit
  • Language: Jupyter Notebook
  • Default Branch: main
  • Homepage:
  • Size: 4.07 MB
Statistics
  • Stars: 98
  • Watchers: 3
  • Forks: 20
  • Open Issues: 5
  • Releases: 2
Topics
attention-mechanism gnn-model spatio-temporal-prediction traffic-forecasting
Created over 3 years ago · Last pushed over 2 years ago
Metadata Files
Readme License Citation

README.md

STGM: Spatio-Temporal Graph Mixformer for Traffic Forecasting

Architecture


This is the original repository for the paper entitled "Spatio-Temporal Graph Mixformer for Traffic Forecasting".


Abstract

Traffic forecasting is of great importance for intelligent transportation systems (ITS). Because of the intricacy of traffic behavior and the non-Euclidean nature of traffic data, accurate traffic prediction is challenging. Although previous studies considered the relationship between different nodes, the majority relied on a static representation and failed to capture the dynamic node interactions over time. Additionally, prior studies employed RNN-based models to capture the temporal dependency. While RNNs are a popular choice for forecasting problems, they tend to be memory hungry and slow to train. Furthermore, recent studies have started utilizing similarity algorithms to better express the influence of one node over another. However, to our knowledge, none have explored the contribution of node $i$'s past to the future state of node $j$. In this paper, we propose a Spatio-Temporal Graph Mixformer (STGM) network, a highly optimized model with a low memory footprint. We address the aforementioned limitations by utilizing a novel attention mechanism to capture the correlation between temporal and spatial dependencies. Specifically, we use convolution layers with a variable field of view for each head to capture long- and short-term temporal dependencies. Additionally, we train an estimator model that expresses the contribution of a node to the desired prediction. The estimation is fed alongside a distance matrix to the attention mechanism. Meanwhile, we use a gated mechanism and a mixer layer to further select and incorporate the different perspectives. Extensive experiments show that the proposed model achieves a performance gain over the baselines while maintaining the lowest parameter count.


Performance and Visualization

Table

Prediction Sample

Table of Contents

```text
datasets
├─pemsbay
| ├─distances.csv
| ├─speed.h5
├─metrla
| ├─distances.csv
| ├─speed.h5
├─pemsd7m
| ├─distances.csv
| ├─speed.csv
src
├─configs
| ├─datasets
| | ├─pemsbay.yaml
| | ├─metrla.yaml
| | ├─pemsd7m.yaml
| ├─estimator
| | ├─default.yaml
| ├─model
| | ├─stgm.yaml
| | ├─stgm_full.yaml
| ├─trainer
| | ├─default.yaml
| | ├─full.yaml
| ├─sweeps
| | ├─pemsbay.yaml
| | ├─metrla.yaml
| ├─paths
| | ├─default.yaml
| ├─log
| | ├─default.yaml
| ├─hydra
| | ├─default.yaml
| ├─config.yaml
├─helpers
| ├─__init__.py
| ├─dtw.py
| ├─metrics.py
| ├─utils.py
| ├─viz.py
├─datasets
| ├─__init__.py
| ├─base.py
| ├─pems.py
| ├─metrla.py
| ├─pemsd7m.py
├─models
| ├─__init__.py
| ├─base.py
| ├─estimator.py
| ├─stgm.py
| ├─stgm_full.py
├─trainers
| ├─__init__.py
| ├─base.py
| ├─default.py
| ├─full.py
├─logs -> Run logs
| ├─METR-LA -> Dataset name (auto-generated)
| ├─<DATE> -> Starting date (auto-generated)
| ├─<TIME> -> Starting time (auto-generated)
| ├─<MODEL_NAME> -> Model name (auto-generated)
| ├─config.yaml -> Config used for this run (auto-generated)
| ├─run.log -> Run logs (auto-generated)
| ├─model_weights.pt -> Latest model weights (auto-generated)
| ├─model_weights_*.pt -> Model weights for a specific checkpoint (auto-generated)
| ├─performance.txt -> Prediction metrics overview (auto-generated)
| ├─train_history.pdf -> Train history figure (auto-generated)
├─run.py
├─inference.py
├─__init__.py
.gitignore
LICENSE
README.md
requirements.txt
```

Requirements

The code is built on Python 3.11.0, PyTorch 2.0.0 and Hydra 1.3.1.

You can install PyTorch by following the instructions on the PyTorch website. For example:

```shell
pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
```

After ensuring that PyTorch is installed correctly, the user can clone the repo and install the remaining dependencies as follows:

```shell
git clone https://github.com/Mouradost/STGM.git
cd ./STGM
pip install -r requirements.txt
```

Run STGM (or your custom version)

Our implementation uses the following frameworks for logging and configuration:

  • hydra for a configurable and customizable training pipeline. All the options for each module are in the folder src/configs; refer to the hydra docs for further instructions.
  • wandb as an online logger and logging as a local logger. To choose a logger, pass it as an argument to the run.py script as follows:

```shell
./run.py log.type=[wandb | default]
```

We also provide a base class for each of Trainer, Model and Dataset that can be extended to customize your run to your needs; please refer to Customize your run.

The following sections assume the user is located in src/.

```shell
cd [PATH_TO_STGM_REPO_LOCAL_FOLDER]/src/
```

The similarity matrices can be pre-computed for faster training using the snippet below:

```shell
./helpers/dtw.py [PATH_DATA_FOLDER]
```

For more options try:

```shell
./helpers/dtw.py --help
```
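The helper's file name suggests that the similarity is based on dynamic time warping (DTW). As a rough illustration only, the sketch below computes a DTW distance between two 1-D series in pure Python, assuming an absolute-difference cost. The function name `dtw_distance` is hypothetical, and the repository's helper may use a different cost, a window constraint, or a vectorized implementation.

```python
import math


def dtw_distance(a: list[float], b: list[float]) -> float:
    """Dynamic time warping distance between two 1-D sequences (illustrative sketch)."""
    n, m = len(a), len(b)
    # acc[i][j] holds the minimal accumulated cost of aligning a[:i] with b[:j].
    acc = [[math.inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = abs(a[i - 1] - b[j - 1])  # assumed point-wise cost
            # Extend the cheapest of the three admissible predecessor alignments.
            acc[i][j] = cost + min(acc[i - 1][j], acc[i][j - 1], acc[i - 1][j - 1])
    return acc[n][m]


print(dtw_distance([1, 2, 3], [2, 3, 4]))  # 2.0
```

Computing such a distance for every pair of nodes costs time quadratic in both the number of nodes and the series length, which is one reason to compute the matrices once before training.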

Train STGM

Start the run by executing ./run.py or python run.py, which uses the default configuration. Any parameter that alters the default configuration can be provided using the schema below.

```shell
./run.py [MODULE]=$value [MODULE].[PARAM]=$value [TOP_LEVEL_PARAM]=$value
```

For example:

```shell
./run.py trainer.epochs=100 trainer.batch_size=128 dataset=metrla device=cuda:1 model.nb_blocks=1 model=stgm log.logger=wandb
```
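To make the dotted `key=value` schema concrete, the sketch below applies such overrides to a nested dictionary. It is not Hydra's implementation: `parse_value` and `apply_overrides` are hypothetical helpers, the starting values in `cfg` are made up for illustration, and config-group selection such as `dataset=metrla` (which in Hydra swaps in a whole config file) is not modeled.

```python
from typing import Any


def parse_value(raw: str) -> Any:
    """Interpret a raw override value as int, then float, else keep the string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Apply 'a.b.c=value' style overrides to a nested dict in place."""
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            # Create intermediate levels on demand.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return config


# Illustrative defaults, not the repository's actual configuration.
cfg = {"trainer": {"epochs": 10, "batch_size": 64}, "device": "cpu"}
apply_overrides(cfg, ["trainer.epochs=100", "trainer.batch_size=128", "device=cuda:1"])
print(cfg)  # {'trainer': {'epochs': 100, 'batch_size': 128}, 'device': 'cuda:1'}
```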

Predict using STGM

Perform a prediction using the script ./inference.py --path=[PATH_TO_SAVED_FOLDER] or python inference.py --path=[PATH_TO_SAVED_FOLDER], which uses the configuration provided in PATH_TO_SAVED_FOLDER.

You can find all available options using ./inference.py --help or python inference.py --help.

The folder should follow this schema:

```text
<SAVED_FOLDER> -> Any folder that contains the model weights and the configuration file
├─config.yaml -> Config file
├─model_weights.pt -> Model weights
```
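Before launching inference, it can help to verify that a folder matches this schema. The sketch below is a small stand-alone check built only from the two file names listed above; `missing_files` is a hypothetical helper and is not part of the repository.

```python
from pathlib import Path

# File names taken from the folder schema above.
REQUIRED_FILES = ("config.yaml", "model_weights.pt")


def missing_files(saved_folder: str) -> list[str]:
    """Return the required files that are absent from the saved folder."""
    folder = Path(saved_folder)
    return [name for name in REQUIRED_FILES if not (folder / name).is_file()]
```

An empty list means the folder contains both files expected by the inference script.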

Customize your run

Dataset

Training and inference can be performed on any dataset that contains the following files:

  • A signal file (any time-series data) $X \in \mathbb{R}^{N \times F}$
  • A timestamp for each signal entry $T \in \mathbb{R}$
  • An adjacency matrix $A \in \mathbb{R}^{N \times N}$

The user only needs to override the load_data function in order to provide self.data, self.time_index, self.timestamps, self.adj, self.data_min, self.data_max, self.data_mean, and self.data_std.

The script for the custom dataset needs to be placed in the datasets folder as follows: ./src/datasets/../[DATASET_NAME].py.

Follow the below snippet to provide the necessary data to the BaseDataset:

```python
import pandas as pd
from datasets.base import BaseDataset  # Base dataset class


class Dataset(BaseDataset):
    # The data_folder_path is provided by the BaseDataset.
    # BaseDataset reads the path from ./src/config/paths/[defaults | user defined config].yaml
    def load_data(self):
        # Loading the adjacency matrix
        self.adj = pd.read_csv(
            self.data_folder_path / "PEMSBAY" / "Wpemsbay.csv"
        ).values  # [NB_NODES, NB_NODES]
        # Loading the signal data
        data = pd.read_hdf(self.data_folder_path / "PEMSBAY" / "pems-bay.h5")
        self.data = data.values  # [NB_SAMPLES, NB_NODES]
        # Metadata used for normalization
        self.data_mean, self.data_std = self.data.mean(), self.data.std()
        self.data_min, self.data_max = self.data.min(), self.data.max()
        # Timestamp inference
        self.timestamps = data.index  # Pandas datetime used for visualization
        self.timestamps.freq = self.timestamps.inferred_freq
        # IDX representing time of the day and day of the week [NB_SAMPLES, 2]
        self.time_index = self.time_to_idx(self.timestamps, freq="5min")
```
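The snippet above describes the time index as the time of day and the day of the week for each sample, with shape [NB_SAMPLES, 2]. As a hedged illustration of what such an index could look like, the sketch below derives both values with the standard library. `time_to_idx_sketch` is a hypothetical stand-in; the column order, the slot numbering, and the signature of the base class helper are assumptions.

```python
from datetime import datetime, timedelta


def time_to_idx_sketch(
    timestamps: list[datetime], freq_minutes: int = 5
) -> list[tuple[int, int]]:
    """Map each timestamp to (slot of the day, day of the week)."""
    out = []
    for ts in timestamps:
        # Slot of the day: 0..287 for 5-minute data.
        slot = (ts.hour * 60 + ts.minute) // freq_minutes
        # datetime.weekday(): Monday is 0, Sunday is 6.
        out.append((slot, ts.weekday()))
    return out


start = datetime(2017, 1, 1, 23, 55)  # a Sunday
stamps = [start + timedelta(minutes=5 * k) for k in range(3)]
print(time_to_idx_sketch(stamps))  # [(287, 6), (0, 0), (1, 0)]
```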

The user needs to provide a YAML config file for the custom dataset in ./src/config/dataset/. Refer to the PEMS-BAY config file for all the options.

Model

Follow the snippet below to create a custom model. Please subclass the BaseModel; otherwise the program won't run properly, since BaseModel provides the model-saving mechanism and other functionality used by other classes:

```python
import logging

import torch
from models.base import BaseModel


class Model(BaseModel):
    def __init__(
        self,
        in_channels: int = 1,
        h_channels: int = 64,
        out_channels: int = 1,
        channels_last: bool = True,  # Assumed default; the original signature omitted this argument
        device: str = "cpu",
        name: str = "STGM",
        log_level=logging.WARNING,
        *args,  # Absorbs any extra args provided by run.py or inference.py
        **kwargs,  # Absorbs any extra kwargs provided by run.py or inference.py
    ):
        super().__init__(name=name, log_level=log_level)
        self.channels_last = channels_last

        # Model definition

    def forward(
        self,
        x: torch.Tensor,  # [B, L, N, F] | [B, F, N, L]
        adj: torch.Tensor,  # [N, N]
        idx: torch.Tensor = None,  # [B, 2, L]
        *args,  # Absorbs any extra args provided by the trainer
        **kwargs,  # Absorbs any extra kwargs provided by the trainer
    ) -> torch.Tensor:
        # Permutation if necessary
        if self.channels_last:
            x = x.transpose(1, -1).contiguous()  # [B, L, N, F] -> [B, F, N, L]
            self.logger.debug("transpose -> x: %s", x.shape)

        # Model logic

        # Permutation if necessary
        if self.channels_last:
            x = x.transpose(1, -1).contiguous()  # [B, F', N, L] -> [B, L, N, F']
            self.logger.debug("transpose -> x: %s", x.shape)
        return x  # [B, F', N, L] | [B, L, N, F']
```

The user needs to provide a YAML config file for the custom model in ./src/config/model/. Refer to the STGM config file for all the options. The user can extend the provided model config with any additional model parameters.

Trainer

The user only needs to override train, train_step, validate, and val_step from BaseTrainer to create custom training logic.

Follow the below snippet for an easy start 😉.

Note that there is no need for normalization since it is handled by the BaseTrainer.

```python
import torch
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from trainers.base import BaseTrainer


class Trainer(BaseTrainer):
    def val_step(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        idx: torch.Tensor,
        adj: torch.Tensor,
        sim: torch.Tensor,
    ) -> float:
        # Compute the loss.
        with torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=self.use_amp
        ):
            loss = self.loss(x=x, y=y, idx=idx, adj=adj, sim=sim)
        return loss.item()

    @torch.no_grad()
    def validate(self, val_data: torch.utils.data.DataLoader):
        for idx, adj, sim, x, y in val_data:
            idx, adj, sim, x, y = (
                idx.to(self.device),  # [B, 2, L]
                adj.to(self.device),  # [B, N, N]
                sim.to(self.device),  # [B, N, N]
                x.to(self.device),  # [B, L, N, F]
                y.to(self.device),  # [B, L, N, F']
            )
            # Validation step: accumulate the mean loss over all batches
            self.metrics["val/loss"] += self.val_step(
                x=x, y=y, idx=idx, adj=adj, sim=sim
            ) / len(val_data)

    def train_step(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        idx: torch.Tensor,
        adj: torch.Tensor,
        sim: torch.Tensor,
    ) -> float:
        # Compute the loss.
        with torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=self.use_amp
        ):
            loss = self.loss(x=x, y=y, idx=idx, adj=adj, sim=sim)
        # Before the backward pass, zero all of the network gradients
        self.opt.zero_grad(set_to_none=True)
        # Backward pass: compute gradient of the loss with respect to parameters
        self.grad_sclaer.scale(loss).backward()
        self.grad_sclaer.step(self.opt)
        self.grad_sclaer.update()
        return loss.item()

    def train(
        self,
        train_data: torch.utils.data.DataLoader,
        val_data: torch.utils.data.DataLoader | None = None,
    ) -> dict:
        self.model.train()
        epoch_loop = tqdm(
            range(self.epochs),
            desc=f"Epoch 0/{self.epochs} train",
            leave=True,
            disable=not self.verbose,
        )
        with logging_redirect_tqdm():
            for epoch in epoch_loop:
                self.metrics["train/loss"] = 0
                self.metrics["val/loss"] = 0
                for idx, adj, sim, x, y in train_data:
                    idx, adj, sim, x, y = (
                        idx.to(self.device),  # [B, 2, L]
                        adj.to(self.device),  # [B, N, N]
                        sim.to(self.device),  # [B, N, N]
                        x.to(self.device),  # [B, L, N, F]
                        y.to(self.device),  # [B, L, N, F']
                    )
                    # Train step (once per batch): accumulate the mean loss
                    self.metrics["train/loss"] += self.train_step(
                        x=x, y=y, idx=idx, adj=adj, sim=sim
                    ) / len(train_data)
                # Validation
                if val_data is not None:
                    self.model.eval()
                    self.validate(val_data=val_data)
                    self.model.train()
                # Log metrics
                self.log(epoch=epoch + 1)
                epoch_loop.set_postfix(
                    loss=self.metrics["train/loss"],
                    loss_val=self.metrics["val/loss"],
                )
                epoch_loop.set_description_str(f"Epoch {epoch + 1}/{self.epochs} train")
        return self.history
```
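The note above says that normalization is handled by the BaseTrainer, and the dataset exposes data_mean, data_std, data_min and data_max for that purpose. The sketch below shows one common scheme, z-score normalization with its inverse, purely as an assumption about what such a step might do. The function names are hypothetical, and the base class may instead use min-max scaling or another scheme.

```python
def normalize(values: list[float], mean: float, std: float) -> list[float]:
    """Z-score normalization: shift by the mean, scale by the standard deviation."""
    if std == 0:
        raise ValueError("std must be non-zero")
    return [(v - mean) / std for v in values]


def denormalize(values: list[float], mean: float, std: float) -> list[float]:
    """Inverse transform, mapping model outputs back to the original units."""
    return [v * std + mean for v in values]


speeds = [50.0, 60.0, 70.0]  # made-up readings for illustration
z = normalize(speeds, mean=60.0, std=10.0)
print(z)  # [-1.0, 0.0, 1.0]
print(denormalize(z, mean=60.0, std=10.0))  # [50.0, 60.0, 70.0]
```

Because the trainer applies this step itself, a custom train_step or val_step receives data that is already scaled and should not normalize it again.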

The user needs to provide a YAML config file for the custom trainer in ./src/config/trainer/. Refer to the Trainer config file for all the options.

Hyperparameters search

We also provide an integration of hydra with wandb hyperparameter search, also called sweeps. For more information and tutorials the user can check the following links:

Some sweep samples are provided under ./src/config/sweeps/*.yaml. Users can provide their own set of rules and expect output similar to the one shown below:

Note: The sweep shown below is an example for a simple model, the user needs to provide a sweep configuration adapted to the size of the model and the dataset for good results.

Sweep Example

The user can run the sweeps as follows:

```shell
wandb sweep ./src/config/sweeps/[SWEEP_CONFIG_NAME].yaml
```

Disclaimer

The code in this repo is a simplified and polished version of the original code used for the paper. If you encounter any bugs or a large difference between the metrics reported in the paper and your run, please open an issue following the provided guidelines.

The base configurations provided in this repo are not the ones used in the paper but a scaled-down version meant to run on a personal computer as a starting point and a proof of concept. Refer to the paper for the full version, which we used to train our model on a server.

Citing STGM

If you find this repository useful for your work, please consider citing it using the following BibTeX entry:

```bibtex
@article{lablack2023spatio,
  title={Spatio-temporal graph mixformer for traffic forecasting},
  author={Lablack, Mourad and Shen, Yanming},
  journal={Expert Systems with Applications},
  volume={228},
  pages={120281},
  year={2023},
  publisher={Elsevier}
}
```

Owner

  • Name: Lablack Mourad
  • Login: Mouradost
  • Kind: user
  • Location: Dalian, China

Ph.D. student at Dalian University of Technology, China. Working on traffic flow forecasting.

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Lablack"
  given-names: "Mourad"
  orcid: "https://orcid.org/0000-0001-5997-7199"
keywords:
  - Traffic Prediction
  - Traffic Forecasting
  - Graph Neural Network
  - Attention Mechanism
  - Multivariate Spatio-Temporal Neural Network
title: "STGM: Spatio-Temporal Graph Mixformer for Traffic Forecasting"
version: 1.0.0
doi: 10.5281/zenodo.1234
date-released: 2023-06-18
url: "https://github.com/Mouradost/STGM"
preferred-citation:
  type: article
  authors:
  - family-names: "Lablack"
    given-names: "Mourad"
    orcid: "https://orcid.org/0000-0001-5997-7199"
  - family-names: "Shen"
    given-names: "Yanming"
    orcid: "https://orcid.org/0000-0003-4108-0230"
  doi: "10.1016/j.eswa.2023.120281"
  journal: "Expert Systems with Applications"
  month: 4
  start: 120281
  end: 120292
  title: "Spatio-Temporal Graph Mixformer for Traffic Forecasting"
  volume: 228
  year: 2023

GitHub Events

Total
  • Watch event: 10
  • Fork event: 6
Last Year
  • Watch event: 10
  • Fork event: 6