https://github.com/bytedance/hybrid-sd
Science Score: 23.0%
This score indicates how likely this project is to be science-related, based on the following indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: found
- ○ .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (12.0%) to scientific vocabulary
Repository
Basic Info
Statistics
- Stars: 19
- Watchers: 4
- Forks: 0
- Open Issues: 1
- Releases: 0
Metadata Files
README.md
Introduction
Hybrid SD is a novel framework designed for edge-cloud collaborative inference of Stable Diffusion Models. By integrating the superior large models on cloud servers and efficient small models on edge devices, Hybrid SD achieves state-of-the-art parameter efficiency on edge devices with competitive visual quality.
Installation
```bash
conda create -n hybrid_sd python=3.9.2
conda activate hybrid_sd
pip install -r requirements.txt
```
Pretrained Models
We provide the following pretrained models:
- Our pruned U-Net (224M): hybrid-sd-224m
- Our tiny VAE: hybrid-sd-tinyvae, and its SDXL version: hybrid-sd-tinyvae-xl. Additionally, we provide decoder-pruned versions (20%+ speedup): hybrid-sd-small-vae for SD1.5 and hybrid-sd-small-vae-xl for SDXL. Visual results can be found in Results.
- SD-v1.4 and our pruned LCM (224M): hybrid-sd-v1-4-lcm-224
Hybrid Inference
SD Models
To use hybrid SD for inference, launch scripts/hybrid_sd/hybird_sd.sh and specify the large and small models. For hybrid inference with SDXL models, use scripts/hybrid_sd/hybird_sdxl.sh instead.
Optional arguments
- PATH_MODEL_LARGE: the large model path.
- PATH_MODEL_SMALL: the small model path.
- --step: the number of steps assigned to each model (e.g., "10,15" means the first 10 steps run on the large model and the last 15 steps run on the small model).
- --seed: the random seed.
- --img_sz: the image size.
- --prompts_file: a .txt file containing the prompts.
- --output_dir: the output directory for saving generated images.
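To illustrate how a `--step` value such as "10,15" splits the denoising steps between the two models, here is a minimal sketch. The function names `parse_step_split` and `assign_models` are hypothetical and are not taken from the repository; the scripts may parse the argument differently.

```python
from typing import List, Sequence


def parse_step_split(step_arg: str) -> List[int]:
    """Parse a comma-separated value such as "10,15" into step counts."""
    counts = [int(part) for part in step_arg.split(",")]
    if any(count < 0 for count in counts):
        raise ValueError("step counts must be non-negative")
    return counts


def assign_models(step_arg: str,
                  model_names: Sequence[str] = ("large", "small")) -> List[str]:
    """Return, for each denoising step in order, which model handles it.

    Assumption: the first count belongs to the large model and the
    second to the small model, as described for --step above.
    """
    counts = parse_step_split(step_arg)
    if len(counts) != len(model_names):
        raise ValueError("expected one step count per model")
    schedule: List[str] = []
    for name, count in zip(model_names, counts):
        schedule.extend([name] * count)
    return schedule


schedule = assign_models("10,15")
# 25 steps in total: step index 9 is the last large-model step,
# and step index 10 is the first small-model step.
print(len(schedule), schedule[9], schedule[10])
```

Under this reading, the total number of sampling steps is the sum of the two counts, and the handoff from the large model to the small model happens exactly once.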
Latent Consistency Models (LCMs)
To use hybrid SD with LCMs, launch scripts/hybrid_sd/hybird_lcm.sh and specify the large and small models. You also need to pass TEACHER_MODEL_PATH, which is used to load the VAE, tokenizer, and text encoder.
Evaluation on MS-COCO Benchmark
Evaluate hybrid inference with SD Models on MS-COCO 2014 30K.
```bash
bash scripts/hybrid_sd/generate_dpm_eval.sh
```
Evaluate hybrid inference with LCMs on MS-COCO 2014 30K.
```bash
bash scripts/hybrid_sd/generate_lcm_eval.sh
```
Training
Pruning U-Net
```bash
# Prune the U-Net using significance scores.
bash scripts/prunesd/prunetiny.sh
# Fine-tune the pruned U-Net.
bash scripts/prunesd/kdfinetunetiny.sh
```
Following BK-SDM, we use the dataset preprocessed212k.
Training our lightweight VAE
```bash
bash scripts/optimize_vae/train_tinyvae.sh
```
Note
- We use datasets from [Laion_aesthetics_5plus_1024_33M](https://huggingface.co/datasets/MuhammadHanif/Laion_aesthetics_5plus_1024_33M).
- We optimize the VAE with LPIPS loss and adversarial loss.
- We adopt the discriminator from StyleGAN-T along with several data augmentation and degradation techniques for VAE enhancement.
Training LCMs
Train accelerated latent consistency models (LCMs) using the following scripts.
- Distilling SD models to LCMs
```bash
bash scripts/hybrid_sd/lcm_t2i_sd.sh
```
- Distilling pruned SD models to LCMs
```bash
bash scripts/hybrid_sd/lcm_t2i_tiny.sh
```
Results
Hybrid SDXL Inference
VAEs
Our tiny VAE vs. TAESD
Our VAE shows better visual quality and finer detail than TAESD. It also achieves better FID scores than TAESD on the MSCOCO 2017 5K dataset.
Our small VAE vs. Baseline
| Model (fp16) | Latency on V100 (ms) | GPU Memory (MiB) |
|---|:---:|:---:|
| SDXL baseline vae | 802.7 | 19203 |
| SDXL small vae (Ours) | 611.8 | 17469 |
| SDXL tiny vae (Ours) | 61.1 | 8017 |
| SD1.5 baseline vae | 186.6 | 12987 |
| SD1.5 small vae (Ours) | 135.6 | 9087 |
| SD1.5 tiny vae (Ours) | 16.4 | 6929 |
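The relative savings implied by the table can be worked out directly from its numbers. The sketch below computes the percentage reduction in latency and GPU memory of each variant against its baseline; the dictionary layout and the helper name `reduction_pct` are illustrative choices, not part of the repository.

```python
# (latency in ms, GPU memory in MiB), copied from the table above.
MEASUREMENTS = {
    "SDXL": {
        "baseline": (802.7, 19203),
        "small": (611.8, 17469),
        "tiny": (61.1, 8017),
    },
    "SD1.5": {
        "baseline": (186.6, 12987),
        "small": (135.6, 9087),
        "tiny": (16.4, 6929),
    },
}


def reduction_pct(baseline: float, value: float) -> float:
    """Percentage by which `value` is lower than `baseline`."""
    return 100.0 * (baseline - value) / baseline


for family, variants in MEASUREMENTS.items():
    base_latency, base_memory = variants["baseline"]
    for name in ("small", "tiny"):
        latency, memory = variants[name]
        print(
            f"{family} {name} vae: "
            f"latency -{reduction_pct(base_latency, latency):.1f}%, "
            f"memory -{reduction_pct(base_memory, memory):.1f}%"
        )
```

For the small VAEs this gives a latency reduction of roughly 24% (SDXL) and 27% (SD1.5), in line with the 20%+ speedup stated for the decoder-pruned models.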
Acknowledgments
- CompVis, Runway, and Stability AI for the pioneering research on Stable Diffusion.
- Diffusers, BK-SDM, TAESD for their valuable contributions.
Citation
If you find our work helpful, please cite it!
@article{yan2024hybrid,
title={Hybrid SD: Edge-Cloud Collaborative Inference for Stable Diffusion Models},
author={Yan, Chenqian and Liu, Songwei and Liu, Hongjian and Peng, Xurui and Wang, Xiaojian and Chen, Fangming and Fu, Lean and Mei, Xing},
journal={arXiv preprint arXiv:2408.06646},
year={2024}
}
License
This project is licensed under the Apache-2.0 License.
Owner
- Name: Bytedance Inc.
- Login: bytedance
- Kind: organization
- Location: Singapore
- Website: https://opensource.bytedance.com
- Twitter: ByteDanceOSS
- Repositories: 255
- Profile: https://github.com/bytedance
GitHub Events
Total
- Issues event: 5
- Watch event: 29
- Issue comment event: 1
- Push event: 3
- Public event: 1
- Fork event: 1
Last Year
- Issues event: 5
- Watch event: 29
- Issue comment event: 1
- Push event: 3
- Public event: 1
- Fork event: 1
Issues and Pull Requests
Last synced: 10 months ago
All Time
- Total issues: 2
- Total pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Total issue authors: 2
- Total pull request authors: 0
- Average comments per issue: 0.0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 2
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 2
- Pull request authors: 0
- Average comments per issue: 0.0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- wuxing610 (1)
- flynnamy (1)
- DRJYYDS (1)
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels
Dependencies
- click ==8.1.3
- dill ==0.3.6
- glfw ==2.2.0
- imageio ==2.9.0
- imageio-ffmpeg ==0.4.3
- matplotlib ==3.4.2
- ninja ==1.11.1
- numpy ==1.24.1
- open-clip-torch ==2.10.1
- pillow ==8.3.1
- psutil ==5.4.7
- pyopengl ==3.1.5
- requests ==2.26.0
- scipy ==1.10.0
- setuptools ==58.0.4
- tensorboard ==2.11.2
- timm ==0.6.12
- tqdm ==4.62.2
- webdataset ==0.2.31
- accelerate *
- addict *
- clean-fid *
- datasets ==2.18.0
- decorator *
- diffusers ==0.27.0
- einops *
- matplotlib *
- open-clip-torch ==2.24.0
- peft ==0.2.0
- prefetch_generator *
- protobuf ==3.20.3
- pygments *
- pynvml *
- pytorch-fid ==0.3.0
- thop ==0.1.1.post2209072238
- timm *
- torch ==2.1.1
- torch-fidelity ==0.3.0
- torchvision *
- tqdm *
- transformers ==4.27.4
- wandb *
- wcwidth *
- webdataset *
- xformers ==0.0.23
- yapf *
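Several packages appear twice in the list above with different constraints (for example, open-clip-torch at ==2.10.1 and ==2.24.0), which suggests the list merges more than one requirements file. The following sketch flags such packages. The input is a hand-copied subset of the list, and the helper name `find_conflicts` is hypothetical.

```python
from collections import defaultdict
from typing import Dict, List

# A subset of the dependency list above, one "name constraint" pair per line.
DEPENDENCIES = """\
matplotlib ==3.4.2
open-clip-torch ==2.10.1
timm ==0.6.12
tqdm ==4.62.2
webdataset ==0.2.31
matplotlib *
open-clip-torch ==2.24.0
timm *
tqdm *
webdataset *
torch ==2.1.1
"""


def find_conflicts(text: str) -> Dict[str, List[str]]:
    """Return packages that are listed with more than one constraint."""
    specs = defaultdict(set)
    for line in text.splitlines():
        line = line.strip().lstrip("-").strip()
        if not line:
            continue
        name, _, spec = line.partition(" ")
        specs[name.lower()].add(spec.strip() or "*")
    return {name: sorted(found) for name, found in specs.items()
            if len(found) > 1}


for package, constraints in find_conflicts(DEPENDENCIES).items():
    print(f"{package}: {', '.join(constraints)}")
```

A package pinned in one place and unconstrained (`*`) in another is usually harmless, whereas two different exact pins cannot both be satisfied in a single environment.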