https://github.com/compvis/tread


Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.1%) to scientific vocabulary
Last synced: 9 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: CompVis
  • License: other
  • Language: Python
  • Default Branch: master
  • Size: 6.09 MB
Statistics
  • Stars: 62
  • Watchers: 7
  • Forks: 4
  • Open Issues: 3
  • Releases: 0
Created over 1 year ago · Last pushed 10 months ago
Metadata Files
Readme License

README.md

👟TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training

Felix Krause · Timy Phan · Ming Gui · Stefan Baumann · Vincent Tao Hu · Björn Ommer

CompVis Group @ LMU Munich


This repository contains the official implementation of the paper "TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training".

We propose TREAD, a new method that makes diffusion training more efficient by improving both iteration speed and model performance. To do so, we use uni-directional token transportation to modulate the information flow through the network.


🚀 Training

To train a diffusion model, we provide a minimal training script, train.py. In its simplest form, it can be launched with:

accelerate launch train.py model=tread

or

accelerate launch train.py model=dit

Here, configs/config.yaml contains all relevant settings for the training run; adjust it as needed before training. Note: this version expects precomputed latents. The model option selects between two preconfigured variants: dit, the standard DiT, and tread, the same model with TREAD routing. The implementation of these changes can be found in dit.py and routing_module.py.
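To make the routing idea concrete, here is a minimal, framework-free sketch of a forward pass with token routing. It is not the code in routing_module.py: tokens are plain lists of floats, blocks are plain functions, and the names routed_forward, route_start, route_end and the uniformly random token selection are illustrative assumptions. A fraction selection_rate of tokens is taken out before block route_start, skips the blocks up to route_end, and is re-inserted unchanged at its original position, so the middle blocks process fewer tokens.

```python
import random
from typing import Callable, List, Optional

Token = List[float]
Block = Callable[[List[Token]], List[Token]]


def routed_forward(
    tokens: List[Token],
    blocks: List[Block],
    route_start: int,
    route_end: int,
    selection_rate: float,
    rng: Optional[random.Random] = None,
) -> List[Token]:
    """Hypothetical sketch of TREAD-style token routing (not the repo's code)."""
    rng = rng or random.Random()
    x = list(tokens)
    for block in blocks[:route_start]:  # all tokens pass the early blocks
        x = block(x)

    n = len(x)
    # Assumption: routed tokens are chosen uniformly at random.
    routed = set(rng.sample(range(n), int(n * selection_rate)))
    kept_idx = [i for i in range(n) if i not in routed]  # keeps original order

    kept = [x[i] for i in kept_idx]
    for block in blocks[route_start:route_end]:  # only kept tokens are computed here
        kept = block(kept)

    # Uni-directional transport: routed tokens re-enter with their early-layer values.
    merged = list(x)
    for i, t in zip(kept_idx, kept):
        merged[i] = t
    for block in blocks[route_end:]:  # all tokens pass the late blocks
        merged = block(merged)
    return merged


# Usage: four blocks that each add 1; half of the tokens skip blocks 1 and 2.
add_one: Block = lambda xs: [[v + 1.0 for v in t] for t in xs]
out = routed_forward([[0.0]] * 4, [add_one] * 4, 1, 3, 0.5, random.Random(0))
print(sorted(t[0] for t in out))  # [2.0, 2.0, 4.0, 4.0]
```

Routed tokens pass through two blocks and kept tokens through all four, while the sequence length and token positions are preserved for the final blocks.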

In our paper, we show that TREAD also works with other architectures. In practice, the routing process needs more care so that it respects the characteristics of the specific architecture, since some architectures (e.g., RWKV, Mamba) have a spatial bias. For simplicity, we only provide code for the Transformer architecture, as it is the most widely used and is robust and easy to work with.

🖼️ Sampling

For most experiments, we use EDM training and sampling to stay consistent with prior work, and we compute FID with the ADM evaluation suite. We provide fid.py to evaluate our models during training against the same reference batches as ADM.
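For reference, FID fits a Gaussian to Inception features of the reference images (mean \mu_r, covariance \Sigma_r) and another to those of the generated images (\mu_g, \Sigma_g), and measures the Fréchet distance between them; lower is better. The symbols below are introduced here for illustration:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```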

💥 Guiding TREAD

TREAD works great during training, but what about inference? It turns out TREAD can also be applied during guided inference, improving performance while reducing FLOPs at the same time! Instead of dropping the class label as in classifier-free guidance (CFG), we guide with a difference in selection rates. Since a model trained with TREAD's selection rate of 0.5 generalizes to other rates, this guidance can be tuned purely at inference time.
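A minimal sketch of what guiding with a selection-rate difference could look like, written by analogy to the CFG update uncond + w * (cond - uncond). This is an assumption, not the formula in rf.py: the hypothetical model takes a selection_rate argument, the "strong" prediction routes fewer tokens, and the "weak" prediction routes more (which also makes that pass cheaper than a full unconditional pass).

```python
from typing import Callable, List

Vector = List[float]
# Hypothetical model signature: (x, t, selection_rate) -> predicted velocity.
Model = Callable[[Vector, float, float], Vector]


def selection_rate_guidance(
    model: Model, x: Vector, t: float, w: float, rate_strong: float, rate_weak: float
) -> Vector:
    """CFG-style extrapolation between two selection rates (illustrative only)."""
    strong = model(x, t, rate_strong)  # fewer tokens routed: assumed better prediction
    weak = model(x, t, rate_weak)      # more tokens routed: weaker but cheaper prediction
    # Push the output away from the weak prediction, as CFG does with the unconditional one.
    return [wk + w * (st - wk) for st, wk in zip(strong, weak)]


# Usage with a toy model whose output shrinks as more tokens are routed.
toy: Model = lambda x, t, r: [(1.0 - r) * xi for xi in x]
print(selection_rate_guidance(toy, [1.0, 2.0], 0.5, 2.0, 0.0, 0.5))  # [1.5, 3.0]
```

With w = 1 the update reduces to the strong prediction, mirroring how CFG with scale 1 reduces to the conditional prediction.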

We demonstrate this in rf.py, which contains minimal flow matching code for training and sampling:

  • sample: normal sampling
  • sample_tread: TREAD sampling 🔥
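For readers new to flow matching, the sampling step itself is just ODE integration. The sketch below is a generic Euler sampler and is not taken from rf.py: it assumes the rectified-flow convention x_t = (1 - t) * x_0 + t * noise, so sampling integrates from t = 1 (noise) to t = 0 (data), and the name euler_sample is hypothetical. Under our assumptions, a TREAD-guided variant would differ only in which velocity function it passes in.

```python
from typing import Callable, List

Vector = List[float]
Velocity = Callable[[Vector, float], Vector]


def euler_sample(velocity: Velocity, noise: Vector, steps: int = 50) -> Vector:
    """Integrate dz/dt = v(z, t) from t=1 (noise) to t=0 (data) with Euler steps.

    Assumes x_t = (1 - t) * x_0 + t * noise, whose target velocity is noise - x_0.
    """
    z = list(noise)
    dt = 1.0 / steps
    for i in range(steps, 0, -1):
        t = i / steps
        v = velocity(z, t)
        z = [zi - dt * vi for zi, vi in zip(z, v)]  # step backwards in time
    return z


# Usage: an oracle velocity for a single data point x0 recovers it exactly
# (up to rounding), because the straight path has constant velocity.
x0 = [0.25, -1.0]
oracle: Velocity = lambda z, t: [(zi - xi) / t for zi, xi in zip(z, x0)]
print(euler_sample(oracle, [1.0, 2.0], steps=10))
```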

🎓 Citation

If you use this codebase or otherwise found our work valuable, please cite our paper:

@article{krause2025tread,
  title={TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training},
  author={Krause, Felix and Phan, Timy and Gui, Ming and Baumann, Stefan Andreas and Hu, Vincent Tao and Ommer, Bj{\"o}rn},
  journal={arXiv preprint arXiv:2501.04765},
  year={2025}
}

Acknowledgements

We thank the authors of the open-source codebases our code is built on, including DiT, MaskDiT, ADM, and EDM.

Owner

  • Name: CompVis - Computer Vision and Learning LMU Munich
  • Login: CompVis
  • Kind: organization
  • Email: assist.mvl@lrz.uni-muenchen.de
  • Location: Germany

Computer Vision and Learning research group at Ludwig Maximilian University of Munich (formerly Computer Vision Group at Heidelberg University)

GitHub Events

Total
  • Issues event: 10
  • Watch event: 76
  • Delete event: 1
  • Issue comment event: 11
  • Member event: 1
  • Push event: 2
  • Pull request event: 1
  • Fork event: 5
  • Create event: 3
Last Year
  • Issues event: 10
  • Watch event: 76
  • Delete event: 1
  • Issue comment event: 11
  • Member event: 1
  • Push event: 2
  • Pull request event: 1
  • Fork event: 5
  • Create event: 3

Issues and Pull Requests

Last synced: 10 months ago

All Time
  • Total issues: 7
  • Total pull requests: 1
  • Average time to close issues: 3 days
  • Average time to close pull requests: 1 day
  • Total issue authors: 7
  • Total pull request authors: 1
  • Average comments per issue: 0.43
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 7
  • Pull requests: 1
  • Average time to close issues: 3 days
  • Average time to close pull requests: 1 day
  • Issue authors: 7
  • Pull request authors: 1
  • Average comments per issue: 0.43
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • WaitHZ (1)
  • ChuxiJ (1)
  • EtaEnding (1)
  • BitPhinix (1)
  • Schmiddo (1)
  • DogyunPark (1)
  • Oguzhanercan (1)
Pull Request Authors
  • jxiong21029 (1)