brushnet

[ECCV 2024] The official implementation of paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"

https://github.com/tencentarc/brushnet

Science Score: 64.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org, scholar.google
  • Committers with academic emails
    1 of 4 committers (25.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.2%) to scientific vocabulary

Keywords

diffusion diffusion-models eccv eccv2024 image-inpainting text-to-image
Last synced: 4 months ago

Repository

[ECCV 2024] The official implementation of paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"

Basic Info
Statistics
  • Stars: 1,603
  • Watchers: 32
  • Forks: 134
  • Open Issues: 53
  • Releases: 0
Topics
diffusion diffusion-models eccv eccv2024 image-inpainting text-to-image
Created almost 2 years ago · Last pushed about 1 year ago
Metadata Files
Readme Contributing License Code of conduct Citation

README.md

BrushNet

This repository contains the implementation of the ECCV2024 paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"

Keywords: Image Inpainting, Diffusion Models, Image Generation

Xuan Ju<sup>1,2</sup>, Xian Liu<sup>1,2</sup>, Xintao Wang<sup>1</sup>, Yuxuan Bian<sup>2</sup>, Ying Shan<sup>1</sup>, Qiang Xu<sup>2*</sup>
<sup>1</sup>ARC Lab, Tencent PCG <sup>2</sup>The Chinese University of Hong Kong <sup>*</sup>Corresponding Author

Project Page | Arxiv | Data | Video | Hugging Face Demo


Update Log

  • [2024/12/17] BrushEdit is released: an efficient, white-box, free-form image editing tool powered by LLM agents and an all-in-one inpainting model.
  • [2024/12/17] BrushNetX (a stronger BrushNet) models are released.

TODO

  • [x] Release training and inference code
  • [x] Release checkpoint (sdv1.5)
  • [x] Release checkpoint (sdxl). Sadly, we only had a V100 for training this checkpoint, which allows a batch size of just 1 at a slow speed. The current ckpt was trained for only a small number of steps and thus performs poorly. Fortunately, yuanhang has volunteered to help train a better version. Please stay tuned, and thanks to yuanhang for his effort!
  • [x] Release evaluation code
  • [x] Release gradio demo
  • [x] Release comfyui demo. Thanks to nullquant (ConfyUI-BrushNet) and kijai (ComfyUI-BrushNet-Wrapper) for helping!
  • [x] Release training data. Thanks to random123123 for helping!
  • [x] We used BrushNet to participate in the CVPR2024 GenAI Media Generation Challenge Workshop and won the top prize! The solution is provided in InstructionGuidedEditing.
  • [x] Release a new version of the checkpoint (sdxl).

Method Overview

BrushNet is a diffusion-based text-guided image inpainting model that can be plugged into any pre-trained diffusion model. Our architectural design incorporates two key insights: (1) dividing the masked image features and noisy latent reduces the model's learning load, and (2) leveraging dense per-pixel control over the entire pre-trained model enhances its suitability for image inpainting tasks. More analysis can be found in the main paper.
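As a rough illustration of this decomposition, here is a toy PyTorch sketch (not the repo's actual classes): a trainable branch consumes the concatenated noisy latent, masked-image latent, and mask, and its zero-initialized output is added into a frozen base block, giving dense per-layer control.

```python
import torch
import torch.nn as nn

# Toy illustration of the dual-branch idea: the frozen base UNet processes the
# noisy latent, while a trainable BrushNet branch sees the masked-image latent
# and the mask, and injects features back into the base model layer by layer.
# All module shapes here are simplified placeholders, not the repo's classes.

class ToyBaseBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x, injected=None):
        h = self.conv(x)
        return h if injected is None else h + injected  # dense per-layer control

class ToyBrushNetBranch(nn.Module):
    def __init__(self, ch):
        super().__init__()
        # input: noisy latent (ch) + masked-image latent (ch) + 1-channel mask
        self.proj = nn.Conv2d(2 * ch + 1, ch, 3, padding=1)
        self.zero_conv = nn.Conv2d(ch, ch, 1)  # zero-initialized, ControlNet-style
        nn.init.zeros_(self.zero_conv.weight)
        nn.init.zeros_(self.zero_conv.bias)

    def forward(self, noisy, masked_latent, mask):
        h = self.proj(torch.cat([noisy, masked_latent, mask], dim=1))
        return self.zero_conv(h)  # feature to add into the frozen base block

ch = 4
base, branch = ToyBaseBlock(ch), ToyBrushNetBranch(ch)
noisy = torch.randn(1, ch, 64, 64)
masked_latent = torch.randn(1, ch, 64, 64)
mask = torch.rand(1, 1, 64, 64)
out = base(noisy, injected=branch(noisy, masked_latent, mask))
print(out.shape)  # torch.Size([1, 4, 64, 64])
```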

Getting Started

Environment Requirement

BrushNet has been implemented and tested with PyTorch 1.12.1 and Python 3.9.

Clone the repo:

git clone https://github.com/TencentARC/BrushNet.git

We recommend first using conda to create a virtual environment, then installing PyTorch following the official instructions. For example:

conda create -n diffusers python=3.9 -y
conda activate diffusers
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

Then, you can install diffusers (implemented in this repo) with:

pip install -e .

After that, you can install the required packages through:

cd examples/brushnet/
pip install -r requirements.txt
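A quick sanity check that the environment matches the pinned versions above:

```python
import torch
import diffusers

print(torch.__version__)           # expect 1.12.1+cu116 from the install command above
print(torch.cuda.is_available())   # should print True on a CUDA machine
print(diffusers.__version__)       # the editable install from this repository
```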

Data Download

Dataset

You can download BrushData and BrushBench here (as well as the EditBench we re-processed), which are used for training and testing BrushNet. By downloading the data, you agree to the terms and conditions of the license. The data structure should look like this:

|-- data
    |-- BrushData
        |-- 00200.tar
        |-- 00201.tar
        |-- ...
    |-- BrushBench
        |-- images
        |-- mapping_file.json
    |-- EditBench
        |-- images
        |-- mapping_file.json

Note: we only provide part of BrushData on Google Drive due to space limits. random123123 has kindly uploaded the full dataset to Hugging Face here. Thanks for his help!
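The BrushData shards are plain .tar files; one way to peek at their contents is the webdataset library. This is only an inspection sketch under the assumption that the shards follow the WebDataset layout: the per-sample field names are not documented here, so it just prints whatever keys a sample carries.

```python
import webdataset as wds  # pip install webdataset; assumed reader for the .tar shards

dataset = wds.WebDataset("data/BrushData/00200.tar")
sample = next(iter(dataset))
# discover which fields (image, caption, masks, ...) each sample holds
print(sorted(sample.keys()))
```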

Checkpoints

Checkpoints of BrushNet can be downloaded from here. The ckpt folder contains

  • BrushNet pretrained checkpoints for Stable Diffusion v1.5 (segmentation_mask_brushnet_ckpt and random_mask_brushnet_ckpt)
  • a pretrained Stable Diffusion v1.5 checkpoint (e.g., realisticVisionV60B1_v51VAE from Civitai). You can use `scripts/convert_original_stable_diffusion_to_diffusers.py` to process other models downloaded from Civitai.
  • BrushNet pretrained checkpoints for Stable Diffusion XL (segmentation_mask_brushnet_ckpt_sdxl_v1 and random_mask_brushnet_ckpt_sdxl_v0). A better version will be released shortly by yuanhang. Please stay tuned!
  • a pretrained Stable Diffusion XL checkpoint (e.g., juggernautXL_juggernautX from Civitai). You can use `StableDiffusionXLPipeline.from_single_file("path of safetensors").save_pretrained("path to save", safe_serialization=False)` to process other models downloaded from Civitai, as in the sketch below.
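For example, converting a Civitai SDXL safetensors file into the diffusers folder layout (the paths below are placeholders):

```python
from diffusers import StableDiffusionXLPipeline

# placeholder paths: point these at your downloaded safetensors file and output folder
pipe = StableDiffusionXLPipeline.from_single_file("path/to/downloaded_model.safetensors")
pipe.save_pretrained("data/ckpt/my_sdxl_checkpoint", safe_serialization=False)
```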

The data structure should look like this:

|-- data
    |-- BrushData
    |-- BrushBench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- segmentation_mask_brushnet_ckpt_sdxl_v0
        |-- random_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt_sdxl_v0
        |-- ...

segmentation_mask_brushnet_ckpt and segmentation_mask_brushnet_ckpt_sdxl_v0 contain checkpoints trained on BrushData, which has a segmentation prior (masks share the shape of objects). random_mask_brushnet_ckpt and random_mask_brushnet_ckpt_sdxl provide more general checkpoints for random mask shapes.

Running Scripts

Training

You can train with segmentation masks using the following script:

```
# sd v1.5
accelerate launch examples/brushnet/train_brushnet.py \
  --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
  --output_dir runs/logs/brushnet_segmentation_mask \
  --train_data_dir data/BrushData \
  --resolution 512 \
  --learning_rate 1e-5 \
  --train_batch_size 2 \
  --tracker_project_name brushnet \
  --report_to tensorboard \
  --resume_from_checkpoint latest \
  --validation_steps 300 \
  --checkpointing_steps 10000

# sdxl
accelerate launch examples/brushnet/train_brushnet_sdxl.py \
  --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
  --output_dir runs/logs/brushnet_sdxl_segmentation_mask \
  --train_data_dir data/BrushData \
  --resolution 1024 \
  --learning_rate 1e-5 \
  --train_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --tracker_project_name brushnet \
  --report_to tensorboard \
  --resume_from_checkpoint latest \
  --validation_steps 300 \
  --checkpointing_steps 10000
```

To use a custom dataset, process your own data into the BrushData format and revise --train_data_dir accordingly.
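If you pack your data into .tar shards, a hypothetical sketch with webdataset's TarWriter could look like the following; the "jpg"/"json" field names are illustrative rather than the documented BrushData schema, so match them to whatever fields the released shards actually use (see the inspection snippet in the Dataset section).

```python
import json
import webdataset as wds

# Hypothetical packing sketch: field names ("jpg", "json") are illustrative,
# not the documented BrushData schema.
samples = [("images/000001.jpg", "a cat sitting on a sofa")]

with wds.TarWriter("data/MyData/00000.tar") as writer:
    for i, (image_path, caption) in enumerate(samples):
        with open(image_path, "rb") as f:
            image_bytes = f.read()
        writer.write({
            "__key__": f"{i:06d}",
            "jpg": image_bytes,
            "json": json.dumps({"caption": caption}).encode("utf-8"),
        })
```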

You can train with random masks using the following script (by adding --random_mask):

```
# sd v1.5
accelerate launch examples/brushnet/train_brushnet.py \
  --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
  --output_dir runs/logs/brushnet_random_mask \
  --train_data_dir data/BrushData \
  --resolution 512 \
  --learning_rate 1e-5 \
  --train_batch_size 2 \
  --tracker_project_name brushnet \
  --report_to tensorboard \
  --resume_from_checkpoint latest \
  --validation_steps 300 \
  --random_mask

# sdxl
accelerate launch examples/brushnet/train_brushnet_sdxl.py \
  --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
  --output_dir runs/logs/brushnet_sdxl_random_mask \
  --train_data_dir data/BrushData \
  --resolution 1024 \
  --learning_rate 1e-5 \
  --train_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --tracker_project_name brushnet \
  --report_to tensorboard \
  --resume_from_checkpoint latest \
  --validation_steps 300 \
  --checkpointing_steps 10000 \
  --random_mask
```

Inference

You can run inference with the script:

```
# sd v1.5
python examples/brushnet/test_brushnet.py

# sdxl
python examples/brushnet/test_brushnet_sdxl.py
```

Since BrushNet is trained on LAION, it can only guarantee performance in general scenarios. We recommend training on your own data (e.g., product exhibition, virtual try-on) if you have high-quality industrial application requirements. We would also appreciate it if you would like to contribute your trained model!
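For reference, test_brushnet.py builds an inpainting pipeline from this repo's diffusers fork. A minimal sketch of that flow is below; the class names, checkpoint paths, and call signature are assumptions based on the pipeline under src/diffusers/pipelines/brushnet/, so treat test_brushnet.py as the authoritative version.

```python
import torch
from PIL import Image
from diffusers import BrushNetModel, StableDiffusionBrushNetPipeline  # assumed names from this fork

# assumed checkpoint layout from the "Checkpoints" section above
brushnet = BrushNetModel.from_pretrained(
    "data/ckpt/segmentation_mask_brushnet_ckpt", torch_dtype=torch.float16
)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    "data/ckpt/realisticVisionV60B1_v51VAE", brushnet=brushnet, torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("my_image.png")  # placeholder inputs
mask_image = Image.open("my_mask.png")
result = pipe(
    prompt="a beautiful cake on the table",
    image=init_image,
    mask=mask_image,  # argument name assumed; check the pipeline's __call__ signature
).images[0]
result.save("output.png")
```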

You can also run inference through the gradio demo:

```
# sd v1.5
python examples/brushnet/app_brushnet.py
```

Evaluation

You can evaluate using the script:

python examples/brushnet/evaluate_brushnet.py \
  --brushnet_ckpt_path data/ckpt/segmentation_mask_brushnet_ckpt \
  --image_save_path runs/evaluation_result/BrushBench/brushnet_segmask/inside \
  --mapping_file data/BrushBench/mapping_file.json \
  --base_dir data/BrushBench \
  --mask_key inpainting_mask

The --mask_key indicates which kind of mask to use, inpainting_mask for inside inpainting and outpainting_mask for outside inpainting. The evaluation results (images and metrics) will be saved in --image_save_path.

Note that you need to disable the NSFW detector in src/diffusers/pipelines/brushnet/pipeline_brushnet.py#1261 to get correct evaluation results. Moreover, we find that different machines may generate different images, so we provide the results from our machine here.

Cite Us

@misc{ju2024brushnet,
  title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion},
  author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
  year={2024},
  eprint={2403.06976},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

Acknowledgement

Our code is modified from diffusers; thanks to all the contributors!

Owner

  • Name: ARC Lab, Tencent PCG
  • Login: TencentARC
  • Kind: organization
  • Email: arc@tencent.com

Citation (CITATION.cff)

cff-version: 1.2.0
title: 'Diffusers: State-of-the-art diffusion models'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Patrick
    family-names: von Platen
  - given-names: Suraj
    family-names: Patil
  - given-names: Anton
    family-names: Lozhkov
  - given-names: Pedro
    family-names: Cuenca
  - given-names: Nathan
    family-names: Lambert
  - given-names: Kashif
    family-names: Rasul
  - given-names: Mishig
    family-names: Davaadorj
  - given-names: Thomas
    family-names: Wolf
repository-code: 'https://github.com/huggingface/diffusers'
abstract: >-
  Diffusers provides pretrained diffusion models across
  multiple modalities, such as vision and audio, and serves
  as a modular toolbox for inference and training of
  diffusion models.
keywords:
  - deep-learning
  - pytorch
  - image-generation
  - hacktoberfest
  - diffusion
  - text2image
  - image2image
  - score-based-generative-modeling
  - stable-diffusion
  - stable-diffusion-diffusers
license: Apache-2.0
version: 0.12.1

GitHub Events

Total
  • Issues event: 15
  • Watch event: 339
  • Issue comment event: 16
  • Push event: 1
  • Pull request event: 3
  • Fork event: 36
Last Year
  • Issues event: 15
  • Watch event: 339
  • Issue comment event: 16
  • Push event: 1
  • Pull request event: 3
  • Fork event: 36

Committers

Last synced: 8 months ago

All Time
  • Total Commits: 17
  • Total Committers: 4
  • Avg Commits per committer: 4.25
  • Development Distribution Score (DDS): 0.235
Past Year
  • Commits: 6
  • Committers: 3
  • Avg Commits per committer: 2.0
  • Development Distribution Score (DDS): 0.333
Top Committers
  • juxuan27 (j****7@g****m): 13 commits
  • yuanhangio (y****0@g****m): 2 commits
  • liyaowei-stu (y****l@s****n): 1 commit
  • Ikko Eltociear Ashimine (e****r@g****m): 1 commit

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 77
  • Total pull requests: 5
  • Average time to close issues: 15 days
  • Average time to close pull requests: 22 days
  • Total issue authors: 66
  • Total pull request authors: 4
  • Average comments per issue: 2.23
  • Average comments per pull request: 0.2
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 17
  • Pull requests: 1
  • Average time to close issues: 22 days
  • Average time to close pull requests: about 3 hours
  • Issue authors: 17
  • Pull request authors: 1
  • Average comments per issue: 0.47
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • CharlesGong12 (3)
  • henbucuoshanghai (3)
  • PlutoQyl (2)
  • huangjun12 (2)
  • cs-mshah (2)
  • Shuvo001 (2)
  • yunniw2001 (2)
  • Tramac (2)
  • dydxdt (2)
  • Arcxml (2)
  • junsukha (1)
  • woshiodoman (1)
  • AGI0102 (1)
  • dengyuebj (1)
  • Chevolier (1)
Pull Request Authors
  • yuanhangio (4)
  • xduzhangjiayu (2)
  • liyaowei-stu (2)
  • eltociear (1)
Top Labels
Issue Labels
bug (20)
Pull Request Labels

Dependencies

.github/actions/setup-miniconda/action.yml actions
  • actions/cache v2 composite
.github/workflows/benchmark.yml actions
  • actions/checkout v3 composite
  • actions/upload-artifact v2 composite
.github/workflows/build_docker_images.yml actions
  • actions/checkout v3 composite
  • docker/build-push-action v3 composite
  • docker/login-action v2 composite
  • slackapi/slack-github-action 6c661ce58804a1a20f6dc5fbee7f0381b469e001 composite
.github/workflows/build_documentation.yml actions
.github/workflows/build_pr_documentation.yml actions
.github/workflows/nightly_tests.yml actions
  • ./.github/actions/setup-miniconda * composite
  • actions/checkout v3 composite
  • actions/upload-artifact v2 composite
.github/workflows/pr_dependency_test.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.github/workflows/pr_flax_dependency_test.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.github/workflows/pr_test_fetcher.yml actions
  • actions/checkout v3 composite
  • actions/upload-artifact v3 composite
  • actions/upload-artifact v2 composite
.github/workflows/pr_test_peft_backend.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.github/workflows/pr_tests.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
  • actions/upload-artifact v2 composite
.github/workflows/pr_torch_dependency_test.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.github/workflows/push_tests.yml actions
  • actions/checkout v3 composite
  • actions/upload-artifact v2 composite
.github/workflows/push_tests_fast.yml actions
  • actions/checkout v3 composite
  • actions/upload-artifact v2 composite
.github/workflows/push_tests_mps.yml actions
  • ./.github/actions/setup-miniconda * composite
  • actions/checkout v3 composite
  • actions/upload-artifact v2 composite
.github/workflows/stale.yml actions
  • actions/checkout v2 composite
  • actions/setup-python v1 composite
.github/workflows/typos.yml actions
  • actions/checkout v3 composite
  • crate-ci/typos v1.12.4 composite
.github/workflows/upload_pr_documentation.yml actions
docker/diffusers-flax-cpu/Dockerfile docker
  • ubuntu 20.04 build
docker/diffusers-flax-tpu/Dockerfile docker
  • ubuntu 20.04 build
docker/diffusers-onnxruntime-cpu/Dockerfile docker
  • ubuntu 20.04 build
docker/diffusers-onnxruntime-cuda/Dockerfile docker
  • nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
docker/diffusers-pytorch-compile-cuda/Dockerfile docker
  • nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
docker/diffusers-pytorch-cpu/Dockerfile docker
  • ubuntu 20.04 build
docker/diffusers-pytorch-cuda/Dockerfile docker
  • nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
docker/diffusers-pytorch-xformers-cuda/Dockerfile docker
  • nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
examples/advanced_diffusion_training/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • peft ==0.7.0
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/brushnet/requirements.txt pypi
  • Pillow ==9.5.0
  • accelerate ==0.20.3
  • datasets *
  • ftfy *
  • hpsv2 *
  • image-reward *
  • imgaug *
  • open-clip-torch *
  • opencv-python *
  • tensorboard *
  • torchmetrics *
  • torchvision *
  • transformers >=4.25.1
examples/consistency_distillation/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
  • webdataset *
examples/controlnet/requirements.txt pypi
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/controlnet/requirements_flax.txt pypi
  • Jinja2 *
  • datasets *
  • flax *
  • ftfy *
  • optax *
  • tensorboard *
  • torch *
  • torchvision *
  • transformers >=4.25.1
examples/controlnet/requirements_sdxl.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
  • wandb *
examples/custom_diffusion/requirements.txt pypi
  • Jinja2 *
  • accelerate *
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/dreambooth/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • peft ==0.7.0
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/dreambooth/requirements_flax.txt pypi
  • Jinja2 *
  • flax *
  • ftfy *
  • optax *
  • tensorboard *
  • torch *
  • torchvision *
  • transformers >=4.25.1
examples/dreambooth/requirements_sdxl.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • peft ==0.7.0
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/instruct_pix2pix/requirements.txt pypi
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/kandinsky2_2/text_to_image/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/colossalai/requirement.txt pypi
  • Jinja2 *
  • diffusers *
  • ftfy *
  • tensorboard *
  • torch *
  • torchvision *
  • transformers *
examples/research_projects/consistency_training/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/diffusion_dpo/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • peft *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
  • wandb *
examples/research_projects/dreambooth_inpaint/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • diffusers ==0.9.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.21.0
examples/research_projects/intel_opts/textual_inversion/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • intel_extension_for_pytorch >=1.13
  • tensorboard *
  • torchvision *
  • transformers >=4.21.0
examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt pypi
  • accelerate *
  • ftfy *
  • modelcards *
  • neural-compressor *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.0
examples/research_projects/lora/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/multi_subject_dreambooth/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/multi_subject_dreambooth_inpainting/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • datasets >=2.16.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
  • wandb >=0.16.1
examples/research_projects/multi_token_textual_inversion/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/multi_token_textual_inversion/requirements_flax.txt pypi
  • Jinja2 *
  • flax *
  • ftfy *
  • optax *
  • tensorboard *
  • torch *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/onnxruntime/text_to_image/requirements.txt pypi
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • modelcards *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/onnxruntime/textual_inversion/requirements.txt pypi
  • accelerate >=0.16.0
  • ftfy *
  • modelcards *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt pypi
  • accelerate >=0.16.0
  • datasets *
  • tensorboard *
  • torchvision *
examples/research_projects/realfill/requirements.txt pypi
  • Jinja2 ==3.1.3
  • accelerate ==0.23.0
  • diffusers ==0.20.1
  • ftfy ==6.1.1
  • peft ==0.5.0
  • tensorboard ==2.14.0
  • torch ==2.0.1
  • torchvision >=0.16
  • transformers ==4.36.0
examples/t2i_adapter/requirements.txt pypi
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • safetensors *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
  • wandb *
examples/text_to_image/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • datasets *
  • ftfy *
  • peft ==0.7.0
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/text_to_image/requirements_flax.txt pypi
  • Jinja2 *
  • datasets *
  • flax *
  • ftfy *
  • optax *
  • tensorboard *
  • torch *
  • torchvision *
  • transformers >=4.25.1
examples/text_to_image/requirements_sdxl.txt pypi
  • Jinja2 *
  • accelerate >=0.22.0
  • datasets *
  • ftfy *
  • peft ==0.7.0
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/textual_inversion/requirements.txt pypi
  • Jinja2 *
  • accelerate >=0.16.0
  • ftfy *
  • tensorboard *
  • torchvision *
  • transformers >=4.25.1
examples/textual_inversion/requirements_flax.txt pypi
  • Jinja2 *
  • flax *
  • ftfy *
  • optax *
  • tensorboard *
  • torch *
  • torchvision *
  • transformers >=4.25.1
examples/unconditional_image_generation/requirements.txt pypi
  • accelerate >=0.16.0
  • datasets *
  • torchvision *
examples/wuerstchen/text_to_image/requirements.txt pypi
  • accelerate >=0.16.0
  • bitsandbytes *
  • deepspeed *
  • peft >=0.6.0
  • torchvision *
  • transformers >=4.25.1
  • wandb *
pyproject.toml pypi
setup.py pypi
  • deps *