brushnet-public
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org, scholar.google
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (14.5%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: vamseev
- License: other
- Language: Python
- Default Branch: main
- Size: 26.7 MB
Statistics
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 1
- Releases: 0
Metadata Files
README.md
BrushNet
This repository contains the implementation of the paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion".
Keywords: Image Inpainting, Diffusion Models, Image Generation
Xuan Ju (1,2), Xian Liu (1,2), Xintao Wang (1), Yuxuan Bian (2), Ying Shan (1), Qiang Xu (2)
(1) ARC Lab, Tencent PCG; (2) The Chinese University of Hong Kong; *Corresponding Author
🌐 Project Page | 📜 arXiv | 🗄️ Data | 📹 Video | 🤗 Hugging Face Demo
TODO
- [x] Release training and inference code
- [x] Release checkpoint (sdv1.5)
- [ ] Release checkpoint (sdxl)
- [x] Release evaluation code
- [x] Release gradio demo
🛠️ Method Overview
BrushNet is a diffusion-based, text-guided image inpainting model that can be plugged into any pre-trained diffusion model. Its architecture rests on two key insights: (1) separating the masked image features from the noisy latent reduces the model's learning load, and (2) applying dense per-pixel control over the entire pre-trained model makes it better suited to image inpainting. More analysis can be found in the main paper.
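To make the two insights more concrete, here is a toy NumPy sketch; it is an illustration under stated assumptions, not the BrushNet code. It assumes Stable Diffusion v1.5 latent shapes (4 channels at 1/8 of the image resolution), an extra branch whose input concatenates the noisy latent, the masked-image latent and a downsampled mask, and ControlNet-style zero-initialized 1×1 convolutions that add branch features to the frozen base model. The helper names, the strided mask downsampling, the small channel count and the single injection point are hypothetical simplifications chosen for brevity.

```python
import numpy as np

LATENT_CHANNELS = 4  # assumed: SD v1.5 VAE latent channels
SCALE = 8            # assumed: VAE spatial downsampling factor


def downsample_mask(mask: np.ndarray, factor: int = SCALE) -> np.ndarray:
    """Downsample an (H, W) binary mask to latent resolution by strided sampling.

    Simplification for illustration; a real pipeline may interpolate differently.
    """
    return mask[::factor, ::factor][np.newaxis].astype(np.float32)  # (1, H/8, W/8)


def branch_input(noisy_latent, masked_image_latent, mask):
    """Concatenate noisy latent, masked-image latent and mask along the channel axis."""
    return np.concatenate([noisy_latent, masked_image_latent, downsample_mask(mask)], axis=0)


class ZeroConv1x1:
    """Hypothetical 1x1 convolution whose weights and bias start at zero."""

    def __init__(self, in_ch: int, out_ch: int):
        self.weight = np.zeros((out_ch, in_ch), dtype=np.float32)
        self.bias = np.zeros((out_ch,), dtype=np.float32)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # A 1x1 conv is a per-pixel linear map over channels: (C_in, H, W) -> (C_out, H, W).
        return np.einsum("oc,chw->ohw", self.weight, x) + self.bias[:, None, None]


def inject(base_feature, branch_feature, zero_conv, scale=1.0):
    """Add the branch feature, passed through the zero conv, to the frozen base feature."""
    return base_feature + scale * zero_conv(branch_feature)


rng = np.random.default_rng(0)
noisy = rng.standard_normal((LATENT_CHANNELS, 64, 64)).astype(np.float32)
masked = rng.standard_normal((LATENT_CHANNELS, 64, 64)).astype(np.float32)
mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 1  # 1 marks the region to inpaint (convention assumed)

x = branch_input(noisy, masked, mask)
print(x.shape)  # (9, 64, 64)

channels = 8  # arbitrary small feature width for the toy example
base_feat = rng.standard_normal((channels, 64, 64)).astype(np.float32)
branch_feat = rng.standard_normal((channels, 64, 64)).astype(np.float32)
out = inject(base_feat, branch_feat, ZeroConv1x1(channels, channels))
print(np.array_equal(out, base_feat))  # True
```

Because the injection convolution starts at zero, attaching the branch initially leaves the base model's features untouched, which is one way to read the plug-and-play property; training then learns how much per-pixel guidance to add.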

🚀 Getting Started
Environment Requirement 🌍
BrushNet has been implemented and tested on PyTorch 1.12.1 with Python 3.9.
Clone the repo:
git clone https://github.com/TencentARC/BrushNet.git
We recommend first creating a virtual environment with conda and then installing PyTorch following the official instructions. For example:
conda create -n diffusers python=3.9 -y
conda activate diffusers
python -m pip install --upgrade pip
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
Then, install diffusers (the modified version included in this repo) with:
pip install -e .
After that, install the required packages:
cd examples/brushnet/
pip install -r requirements.txt
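After installation, a quick sanity check can confirm that the installed versions match the ones reported as tested above. This is a minimal sketch using only the standard library's importlib.metadata; the `check_environment` and `installed_version` helpers are hypothetical and not part of the repository, and other versions may work too.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions the README reports as tested; other versions may also work.
TESTED = {"torch": "1.12.1", "torchvision": "0.13.1", "torchaudio": "0.12.1"}
TESTED_PYTHON = (3, 9)


def installed_version(package: str):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_environment(tested=TESTED):
    """Compare installed package versions with the tested ones.

    Returns a dict mapping package -> (installed, tested, matches).
    """
    report = {}
    for pkg, want in tested.items():
        have = installed_version(pkg)
        # pip reports e.g. "1.12.1+cu116"; ignore the local build tag when comparing.
        matches = have is not None and have.split("+")[0] == want
        report[pkg] = (have, want, matches)
    return report


if __name__ == "__main__":
    print("python", sys.version.split()[0], "tested:", ".".join(map(str, TESTED_PYTHON)))
    for pkg, (have, want, ok) in check_environment().items():
        print(f"{pkg:12s} installed={have} tested={want} {'OK' if ok else 'MISMATCH'}")
```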
Data Download ⬇️
Dataset
You can download BrushData and BrushBench here (as well as EditBench, which we re-processed); they are used for training and testing BrushNet. By downloading the data, you agree to the terms and conditions of the license. The data should be organized as follows:
|-- data
    |-- BrushData
        |-- 00200.tar
        |-- 00201.tar
        |-- ...
    |-- BrushBench
        |-- images
        |-- mapping_file.json
    |-- EditBench
        |-- images
        |-- mapping_file.json
Note: We only provide part of BrushData due to storage limits. Please email juxuan.27@gmail.com if you need the full dataset.
Checkpoints
Checkpoints of BrushNet can be downloaded from here. The ckpt folder contains our pretrained BrushNet checkpoints and a pretrained Stable Diffusion checkpoint (e.g., realisticVisionV60B1_v51VAE from Civitai). You can use `scripts/convert_original_stable_diffusion_to_diffusers.py` to convert other models downloaded from Civitai. The resulting structure should look like this:
|-- data
    |-- BrushData
    |-- BrushBench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt
        |-- ...
segmentation_mask_brushnet_ckpt contains checkpoints trained on BrushData with a segmentation prior (the masks have the same shape as the objects). random_mask_brushnet_ckpt provides a more general checkpoint for randomly shaped masks.
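Before training or evaluation, it can help to verify that the `data` folder matches the expected layout. The sketch below checks a handful of representative paths based on the directory trees above and the paths used by the evaluation command; the helper names are hypothetical, the list is not exhaustive, and it assumes you run it from the repository root.

```python
from pathlib import Path

# Representative entries based on the directory trees in this README.
EXPECTED = [
    "data/BrushData",
    "data/BrushBench/images",
    "data/BrushBench/mapping_file.json",
    "data/EditBench/images",
    "data/EditBench/mapping_file.json",
    "data/ckpt/realisticVisionV60B1_v51VAE/model_index.json",
    "data/ckpt/segmentation_mask_brushnet_ckpt",
    "data/ckpt/random_mask_brushnet_ckpt",
]


def missing_paths(root, expected=EXPECTED):
    """Return the expected relative paths that do not exist under root."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


def count_shards(root):
    """Count the .tar shards in data/BrushData (0 if the folder is absent)."""
    shard_dir = Path(root) / "data" / "BrushData"
    return len(list(shard_dir.glob("*.tar"))) if shard_dir.is_dir() else 0


if __name__ == "__main__":
    missing = missing_paths(".")
    print("missing:", missing or "none")
    print("BrushData shards:", count_shards("."))
```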
🏃🏼 Running Scripts
Training 🤯
To train with segmentation masks, run:
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_segmentationmask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300
To use a custom dataset, convert your data to the BrushData format and point --train_data_dir to it.
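To match the BrushData format, it helps to first inspect a released shard and see which files make up one sample. Assuming the shards follow the webdataset convention (files sharing the same basename prefix, up to the first dot, form one sample), the standard-library sketch below groups a shard's files into samples. The helper names are hypothetical, and the actual field names should be read from a real shard rather than assumed.

```python
import sys
import tarfile
from collections import defaultdict


def sample_key(member_name: str) -> tuple:
    """Split 'dir/000123.jpg' into ('dir/000123', 'jpg').

    Following the webdataset convention, the key is everything before the
    first dot of the file name; the rest is the field extension.
    """
    directory, _, filename = member_name.rpartition("/")
    stem, _, ext = filename.partition(".")
    key = f"{directory}/{stem}" if directory else stem
    return key, ext


def group_samples(shard_path: str) -> dict:
    """Group the regular files of a tar shard into samples: key -> sorted extensions."""
    samples = defaultdict(list)
    with tarfile.open(shard_path) as tar:
        for member in tar.getmembers():
            if member.isfile():
                key, ext = sample_key(member.name)
                samples[key].append(ext)
    return {k: sorted(v) for k, v in samples.items()}


if __name__ == "__main__" and len(sys.argv) > 1:
    groups = group_samples(sys.argv[1])  # e.g. data/BrushData/00200.tar
    print(len(groups), "samples; first:", next(iter(groups.items()), None))
```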
To train with random masks, add --random_mask:
accelerate launch examples/brushnet/train_brushnet.py \
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
--output_dir runs/logs/brushnet_randommask \
--train_data_dir data/BrushData \
--resolution 512 \
--learning_rate 1e-5 \
--train_batch_size 2 \
--tracker_project_name brushnet \
--report_to tensorboard \
--resume_from_checkpoint latest \
--validation_steps 300 \
--random_mask
Inference 📜
To run inference, use:
python examples/brushnet/test_brushnet.py
Since BrushNet is trained on LAION, its performance is only guaranteed in general scenarios. If you have demanding industrial applications (e.g., product exhibition, virtual try-on), we recommend training on your own high-quality data. We would also appreciate it if you contributed your trained models!
You can also run inference through the Gradio demo:
python examples/brushnet/app_brushnet.py
Evaluation 📏
You can evaluate using the script:
python examples/brushnet/evaluate_brushnet.py \
--brushnet_ckpt_path data/ckpt/segmentation_mask_brushnet_ckpt \
--image_save_path runs/evaluation_result/BrushBench/brushnet_segmask/inside \
--mapping_file data/BrushBench/mapping_file.json \
--base_dir data/BrushBench \
--mask_key inpainting_mask
--mask_key selects the type of mask: inpainting_mask for inside inpainting and outpainting_mask for outside inpainting. The evaluation results (images and metrics) are saved to --image_save_path.
Note that you need to disable the NSFW detector in src/diffusers/pipelines/brushnet/pipeline_brushnet.py#1261 to obtain correct evaluation results. Moreover, we found that different machines may generate different images, so we provide the results from our machine here.
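Because only --mask_key (and the output folder) changes between inside and outside inpainting, both evaluations can be scripted together. This is a minimal sketch; the helper names are hypothetical, and the `outside` output folder simply mirrors the `inside` path from the command above as an assumption.

```python
import shlex
import subprocess

# Maps an evaluation setting to the --mask_key value described above.
MASK_KEYS = {"inside": "inpainting_mask", "outside": "outpainting_mask"}


def eval_command(region, ckpt="data/ckpt/segmentation_mask_brushnet_ckpt",
                 out_root="runs/evaluation_result/BrushBench/brushnet_segmask"):
    """Build the evaluate_brushnet.py argv for 'inside' or 'outside' inpainting."""
    return [
        "python", "examples/brushnet/evaluate_brushnet.py",
        "--brushnet_ckpt_path", ckpt,
        "--image_save_path", f"{out_root}/{region}",
        "--mapping_file", "data/BrushBench/mapping_file.json",
        "--base_dir", "data/BrushBench",
        "--mask_key", MASK_KEYS[region],
    ]


def run_all(dry_run=True):
    """Print each evaluation command; execute it only when dry_run is False."""
    for region in MASK_KEYS:
        cmd = eval_command(region)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```

Calling `run_all()` only prints the two commands; pass `dry_run=False` to execute them from the repository root.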
🤝🏼 Cite Us
@misc{ju2024brushnet,
title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion},
author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
year={2024},
eprint={2403.06976},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
💖 Acknowledgement
Our code is based on diffusers; thanks to all its contributors!
Owner
- Login: vamseev
- Kind: user
- Repositories: 1
- Profile: https://github.com/vamseev
Citation (CITATION.cff)
cff-version: 1.2.0
title: 'Diffusers: State-of-the-art diffusion models'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Patrick
    family-names: von Platen
  - given-names: Suraj
    family-names: Patil
  - given-names: Anton
    family-names: Lozhkov
  - given-names: Pedro
    family-names: Cuenca
  - given-names: Nathan
    family-names: Lambert
  - given-names: Kashif
    family-names: Rasul
  - given-names: Mishig
    family-names: Davaadorj
  - given-names: Thomas
    family-names: Wolf
repository-code: 'https://github.com/huggingface/diffusers'
abstract: >-
  Diffusers provides pretrained diffusion models across
  multiple modalities, such as vision and audio, and serves
  as a modular toolbox for inference and training of
  diffusion models.
keywords:
  - deep-learning
  - pytorch
  - image-generation
  - hacktoberfest
  - diffusion
  - text2image
  - image2image
  - score-based-generative-modeling
  - stable-diffusion
  - stable-diffusion-diffusers
license: Apache-2.0
version: 0.12.1
Dependencies
- actions/cache v2 composite
- actions/checkout v3 composite
- actions/upload-artifact v2 composite
- actions/checkout v3 composite
- docker/build-push-action v3 composite
- docker/login-action v2 composite
- slackapi/slack-github-action 6c661ce58804a1a20f6dc5fbee7f0381b469e001 composite
- ./.github/actions/setup-miniconda * composite
- actions/checkout v3 composite
- actions/upload-artifact v2 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- actions/checkout v3 composite
- actions/upload-artifact v3 composite
- actions/upload-artifact v2 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- actions/upload-artifact v2 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- actions/checkout v3 composite
- actions/upload-artifact v2 composite
- actions/checkout v3 composite
- actions/upload-artifact v2 composite
- ./.github/actions/setup-miniconda * composite
- actions/checkout v3 composite
- actions/upload-artifact v2 composite
- actions/checkout v2 composite
- actions/setup-python v1 composite
- actions/checkout v3 composite
- crate-ci/typos v1.12.4 composite
- ubuntu 20.04 build
- ubuntu 20.04 build
- ubuntu 20.04 build
- nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
- nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
- ubuntu 20.04 build
- nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
- nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Pillow ==9.5.0
- accelerate ==0.20.3
- clip *
- datasets *
- ftfy *
- gradio ==3.50.0
- hpsv2 *
- image-reward *
- imgaug *
- open-clip-torch *
- opencv-python *
- segment_anything *
- tensorboard *
- torchmetrics *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- webdataset *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- datasets *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- diffusers *
- ftfy *
- tensorboard *
- torch *
- torchvision *
- transformers *
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate >=0.16.0
- diffusers ==0.9.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.21.0
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- intel_extension_for_pytorch >=1.13
- tensorboard *
- torchvision *
- transformers >=4.21.0
- accelerate *
- ftfy *
- modelcards *
- neural-compressor *
- tensorboard *
- torchvision *
- transformers >=4.25.0
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- datasets >=2.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb >=0.16.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- ftfy *
- modelcards *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- ftfy *
- modelcards *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- tensorboard *
- torchvision *
- Jinja2 ==3.1.3
- accelerate ==0.23.0
- diffusers ==0.20.1
- ftfy ==6.1.1
- peft ==0.5.0
- tensorboard ==2.14.0
- torch ==2.0.1
- torchvision >=0.16
- transformers ==4.36.0
- accelerate >=0.16.0
- datasets *
- ftfy *
- safetensors *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- datasets *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.22.0
- datasets *
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- torchvision *
- accelerate >=0.16.0
- bitsandbytes *
- deepspeed *
- peft >=0.6.0
- torchvision *
- transformers >=4.25.1
- wandb *
- deps *