zigma

A PyTorch implementation of the paper "ZigMa: A DiT-Style Mamba-based Diffusion Model" (ECCV 2024)

https://github.com/compvis/zigma

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org, scholar.google
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.8%) to scientific vocabulary

Keywords

diffusion-models flow-matching mamba state-space-model stochastic-interpolant zigma
Last synced: 4 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: CompVis
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage: https://taohu.me/zigma
  • Size: 30.8 MB
Statistics
  • Stars: 312
  • Watchers: 11
  • Forks: 21
  • Open Issues: 6
  • Releases: 0
Created almost 2 years ago · Last pushed 10 months ago
Metadata Files
Readme License Citation

README.md

ZigMa: A DiT-style Zigzag Mamba Diffusion Model (ECCV 2024)

ECCV 2024

Oral talk at the ICML 2024 Workshop on Long Context Foundation Models (LCFM)

This repository is the official implementation of the paper "ZigMa: A DiT-style Zigzag Mamba Diffusion Model" (ECCV 2024).

Vincent Tao Hu, Stefan Andreas Baumann, Ming Gui, Olga Grebenkova, Pingchuan Ma, Johannes Schusterbauer, Björn Ommer

We present ZigMa, a scanning scheme that follows a zigzag pattern, considering both spatial continuity and parameter efficiency. We further adapt this scheme to video, separating the reasoning between spatial and temporal dimensions, thus achieving efficient parameter utilization. Our design allows for greater incorporation of inductive bias for non-1D data and improves parameter efficiency in diffusion models.
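To make the idea of a spatially continuous scan concrete, here is a minimal sketch of one row-wise zigzag (snake) ordering over a grid of patch tokens. It is an illustration under stated assumptions, not the repository's implementation: the function names `zigzag_order`, `inverse_permutation`, and `is_spatially_continuous` are hypothetical, and the paper's scheme uses several scan paths rather than this single one.

```python
def zigzag_order(height, width):
    """Return flat patch indices visited in a row-wise zigzag (snake) order.

    Hypothetical helper: even rows run left to right, odd rows right to left,
    so each token is followed by a direct grid neighbour.
    """
    order = []
    for row in range(height):
        cols = range(width) if row % 2 == 0 else range(width - 1, -1, -1)
        order.extend(row * width + col for col in cols)
    return order


def inverse_permutation(order):
    """Return the permutation that restores the original row-major layout."""
    inverse = [0] * len(order)
    for position, index in enumerate(order):
        inverse[index] = position
    return inverse


def is_spatially_continuous(order, width):
    """Check that consecutive tokens are direct neighbours on the grid."""
    for a, b in zip(order, order[1:]):
        (row_a, col_a), (row_b, col_b) = divmod(a, width), divmod(b, width)
        if abs(row_a - row_b) + abs(col_a - col_b) != 1:
            return False
    return True


order = zigzag_order(3, 4)
print(order)  # [0, 1, 2, 3, 7, 6, 5, 4, 8, 9, 10, 11]

# Rearrange tokens for the sequence model, then restore the grid layout.
tokens = [f"p{i}" for i in range(12)]
scanned = [tokens[i] for i in order]
inverse = inverse_permutation(order)
restored = [scanned[inverse[i]] for i in range(len(tokens))]
```

A plain row-major scan jumps from the end of one row to the start of the next, whereas the snake order keeps every consecutive pair adjacent, which is the spatial-continuity property the abstract refers to.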

🎓 Citation

Please cite our paper:

```bibtex
@InProceedings{hu2024zigma,
  title={ZigMa: A DiT-style Zigzag Mamba Diffusion Model},
  author={Vincent Tao Hu and Stefan Andreas Baumann and Ming Gui and Olga Grebenkova and Pingchuan Ma and Johannes Schusterbauer and Björn Ommer},
  booktitle = {ECCV},
  year={2024}
}
```

✅ Updates

  • May 24th, 2024: 🚀🚀🚀 New checkpoints for the FacesHQ1024, landscape1024, and Churches256 datasets.
  • April 6th, 2024: Support for FP16 training, the checkpoint function, and torch.compile, for better memory utilization and speed.
  • April 2nd, 2024: Main code released.

Quick Demo

```python
import torch
from model_zigma import ZigMa

img_dim = 32
in_channels = 3

model = ZigMa(
    in_channels=in_channels,
    embed_dim=640,
    depth=18,
    img_dim=img_dim,
    patch_size=1,
    has_text=True,
    d_context=768,
    n_context_token=77,
    device="cuda",
    scan_type="zigzagN8",
    use_pe=2,
)

# Random inputs: a batch of 10 images, timesteps, and text-context embeddings.
x = torch.rand(10, in_channels, img_dim, img_dim).to("cuda")
t = torch.rand(10).to("cuda")
_context = torch.rand(10, 77, 768).to("cuda")
o = model(x, t, y=_context)
print(o.shape)
```

Improved Training Performance

In comparison to the original implementation, we implement a selection of training speed acceleration and memory saving features, including gradient checkpointing and torch.compile:

| torch.compile | gradient checkpointing | training speed | memory |
| :-----------: | :--------------------: | :------------: | :----: |
| ❌ | ❌ | 1.05 iters/sec | 18G |
| ❌ | ✔ | 0.93 iters/sec | 9G |
| ✔ | ❌ | 1.8 iters/sec | 18G |

torch.compile is applied to the indexing operations.

🚀 Training

CelebaMM256

Sweep-2, 1GPU

```bash
accelerate launch --num_processes 1 --num_machines 1 --mixed_precision fp16 \
  train_acc.py model=sweep2_b1 use_latent=1 data=celebamm256_uncond ckpt_every=10_000 \
  data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=8 note=_
```

Zigzag-8, 1GPU

```bash
CUDA_VISIBLE_DEVICES=4 accelerate launch --num_processes 1 --num_machines 1 --mixed_precision fp16 \
  --main_process_ip 127.0.0.1 --main_process_port 8868 \
  train_acc.py model=zigzag8_b1 use_latent=1 data=celebamm256_uncond ckpt_every=10_000 \
  data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=4 note=_
```

UCF101

Baseline, multi-GPU

```bash
CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch --num_processes 4 --num_machines 1 --multi_gpu --mixed_precision fp16 \
  --main_process_ip 127.0.0.1 --main_process_port 8868 \
  train_acc.py model=3d_sweep2_b2 use_latent=1 data=ucf101 ckpt_every=10_000 \
  data.sample_fid_n=20_0 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=4 note=_
```

Factorized 3D Zigzag: sst, multi-GPU

```bash
CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch --num_processes 4 --num_machines 1 --multi_gpu --mixed_precision fp16 \
  --main_process_ip 127.0.0.1 --main_process_port 8868 \
  train_acc.py model=3d_zigzag8sst_b2 use_latent=1 data=ucf101 ckpt_every=10_000 \
  data.sample_fid_n=20_0 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=4 note=_
```

🚀 Sampling

FacesHQ 1024

You can download the model directly from the Hugging Face model repository (taohu/zigma), or from a Python script:

```python
from huggingface_hub import hf_hub_download

# Download one checkpoint from the Hugging Face model repo into ./checkpoints
hf_hub_download(
    repo_id="taohu/zigma",
    filename="faceshq1024_0090000.pt",
    local_dir="./checkpoints",
)
```

| Dataset | Checkpoint | Model | Data |
|---|---|---|---|
| FacesHQ1024 | faceshq1024_0090000.pt | model=s1024_zigzag8_b2_old | data=facehq_1024 |
| landscape1024 | landscape1024_0210000.pt | model=s1024_zigzag8_b2_old | data=landscapehq_1024 |
| Churches256 | churches256_0280000.pt | model=zigzag8_b1_pe2 | data=churches256 |
| Coco256 | zigzagN8_b1_pe2_coco14_bs48_0400000.pt | model=zigzag8_b1_pe2 | data=coco14 (31.0) |

1GPU sampling

```bash
CUDA_VISIBLE_DEVICES="2" accelerate launch --num_processes 1 --num_machines 1 \
  sample_acc.py model=s1024_zigzag8_b2_old use_latent=1 data=facehq_1024 ckpt_every=10_000 \
  data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=8 \
  sample_mode=ODE likelihood=0 num_fid_samples=5_000 sample_debug=0 \
  ckpt=checkpoints/faceshq1024_0060000.pt
```

The sampled images are saved both to wandb (disable with use_wandb=False) and to the directory samples/.

🛠️ Environment Preparation

cuda==11.8, python==3.11, torch==2.2.0, gcc==11.3 (for the SSM environment)

python=3.11 is needed to support torch.compile for the time being; see https://github.com/pytorch/pytorch/issues/120233#issuecomment-2041472137

```bash
conda create -n zigma python=3.11
conda activate zigma
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
pip install torchdiffeq matplotlib h5py timm diffusers accelerate loguru blobfile ml_collections wandb
pip install hydra-core opencv-python torch-fidelity webdataset einops pytorch_lightning
pip install torchmetrics --upgrade
pip install opencv-python causal-conv1d
cd dis_causal_conv1d && pip install -e . && cd ..
cd dis_mamba && pip install -e . && cd ..
pip install moviepy imageio  # needed by wandb.Video()
pip install scikit-learn --upgrade
pip install transformers==4.36.2
pip install numpy-hilbert-curve  # (optional) for generating the Hilbert path
pip install av  # (optional) for UCF101 frame extraction
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers  # for FDD metrics
```

Installing Mamba can take considerable effort. If you encounter problems, the issues in the Mamba repository may be very helpful.

Create a file under the directory ./config/wandb/default.yaml:

```yaml
key: YOUR_WANDB_KEY
entity: YOUR_ENTITY
project: YOUR_PROJECT_NAME
```

Q&A

📷 Dataset Preparation

Due to privacy issues, we cannot share the dataset here. We use the MM-CelebA-HQ-Dataset from https://github.com/IIGROUP/MM-CelebA-HQ-Dataset and organize it in the webdataset format to enable scalable multi-GPU training.

Webdataset Format:

- image: image.jpg # ranging from [-1,1], shape should be [3,256,256]
- latent: img_feature256.npy # latent feature for latent generation, shape should be [4,32,32]

The datasets we use include:

- MM-CelebA-HQ for 256 and 512 resolution training
- FacesHQ1024 for 1024 resolution
- UCF101 for 16x256x256 resolution
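The shapes listed above imply some simple arithmetic that is worth spelling out. The sketch below assumes an 8x spatial downsampling factor, inferred only from the stated 256-pixel image and 32x32 latent, and uses hypothetical helper names (`to_signed_unit_range`, `latent_shape`, `num_tokens`); it does not reproduce the repository's data pipeline.

```python
def to_signed_unit_range(pixel):
    """Map an 8-bit pixel value in [0, 255] to the [-1, 1] range."""
    return pixel / 127.5 - 1.0


def latent_shape(image_size, downsample_factor=8, latent_channels=4):
    """Latent shape for a square image, assuming an 8x downsampling encoder."""
    side = image_size // downsample_factor
    return (latent_channels, side, side)


def num_tokens(side, patch_size=1):
    """Sequence length after splitting a square grid into square patches."""
    return (side // patch_size) ** 2


print(to_signed_unit_range(0), to_signed_unit_range(255))  # -1.0 1.0
print(latent_shape(256))             # (4, 32, 32)
print(num_tokens(32, patch_size=1))  # 1024
```

With an image dimension of 32 and a patch size of 1, as in the Quick Demo, each sample becomes a sequence of 1024 tokens. Under the same assumed factor, a 1024-pixel image would give a 128x128 latent and 16384 tokens, which is where an efficient long-sequence backbone starts to matter.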

🎫 License

This work is licensed under the Apache License, Version 2.0 (as defined in the LICENSE).

By downloading and using the code and model you agree to the terms in the LICENSE.

Owner

  • Name: CompVis - Computer Vision and Learning LMU Munich
  • Login: CompVis
  • Kind: organization
  • Email: assist.mvl@lrz.uni-muenchen.de
  • Location: Germany

Computer Vision and Learning research group at Ludwig Maximilian University of Munich (formerly Computer Vision Group at Heidelberg University)

Citation (CITATION.cff)

@InProceedings{hu2024zigma,
      title={ZigMa: A DiT-style Zigzag Mamba Diffusion Model},
      author={Vincent Tao Hu and Stefan Andreas Baumann and Ming Gui and Olga Grebenkova and Pingchuan Ma and Johannes Fischer and Björn Ommer},
      booktitle = {ECCV},
      year={2024}
}

GitHub Events

Total
  • Issues event: 5
  • Watch event: 61
  • Issue comment event: 4
  • Push event: 5
  • Fork event: 6
Last Year
  • Issues event: 5
  • Watch event: 61
  • Issue comment event: 4
  • Push event: 5
  • Fork event: 6

Committers

Last synced: 8 months ago

All Time
  • Total Commits: 78
  • Total Committers: 2
  • Avg Commits per committer: 39.0
  • Development Distribution Score (DDS): 0.013
Past Year
  • Commits: 41
  • Committers: 1
  • Avg Commits per committer: 41.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Tao Hu d****o 77
Stefan Baumann s****n@o****m 1
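
The Development Distribution Score values above are consistent with a common definition: one minus the share of commits made by the most active committer. The sketch below assumes that definition (the listing does not state its formula) and uses a hypothetical function name.

```python
def development_distribution_score(commits_per_committer):
    """Assumed definition: 1 - (top committer's commits / total commits)."""
    total = sum(commits_per_committer)
    if total == 0:
        return 0.0
    return 1 - max(commits_per_committer) / total


# All time: 77 commits by one committer and 1 by another (78 in total).
print(round(development_distribution_score([77, 1]), 3))  # 0.013

# Past year: a single committer with 41 commits.
print(development_distribution_score([41]))  # 0.0
```

A score near 0 means almost all commits come from one person; it rises toward 1 as work is spread across more committers.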

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 26
  • Total pull requests: 1
  • Average time to close issues: 12 days
  • Average time to close pull requests: about 24 hours
  • Total issue authors: 23
  • Total pull request authors: 1
  • Average comments per issue: 1.42
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 7
  • Pull requests: 0
  • Average time to close issues: 14 days
  • Average time to close pull requests: N/A
  • Issue authors: 6
  • Pull request authors: 0
  • Average comments per issue: 1.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • huangjch526 (3)
  • UWong-cmyk (2)
  • lqniunjunlper (1)
  • Mr-Harry (1)
  • rutuja1409 (1)
  • DonMuv (1)
  • EndingCredits (1)
  • shengzhang90 (1)
  • 66ling66 (1)
  • Yaziwel (1)
  • Hiccupwzy (1)
  • xiaoxuesheng180 (1)
  • yyNoBug (1)
  • nemonameless (1)
  • awesomeNabi (1)
Pull Request Authors
  • stefan-baumann (2)