sparke-diffusers
[arXiv] Official implementation of "SPARKE: Scalable Prompt-Aware Diversity Guidance in Diffusion Models via RKE Score" for enhancing diversity of diffusion models.
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Committers with academic emails
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.0%) to scientific vocabulary
Keywords
Repository
Basic Info
- Host: GitHub
- Owner: mjalali
- License: apache-2.0
- Language: Python
- Default Branch: main
- Homepage: https://mjalali.github.io/SPARKE/
- Size: 8.39 MB
Statistics
- Stars: 2
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Topics
Metadata Files
README.md
SPARKE Diffusers: Improving the Diversity of Diffusion Models in Diffusers
SPARKE: Scalable Prompt-Aware Diversity Guidance in Diffusion Models via RKE Score
Overview
This repository contains the official implementation of SPARKE (Scalable Prompt-Aware Diversity Guidance in Diffusion Models via RKE Score), a method for improving diversity in prompt-guided diffusion models. SPARKE introduces conditional entropy-guided sampling that dynamically adapts to semantically similar prompts and supports scalable generation across modern text-to-image architectures.
Project Webpage: https://mjalali.github.io/SPARKE
Abstract
Diffusion models have demonstrated exceptional performance in high-fidelity image synthesis and prompt-based generation. However, achieving sufficient diversity—particularly within semantically similar prompts—remains a critical challenge. Prior methods use diversity metrics as guidance signals, but often neglect prompt awareness or computational scalability.
In this work, we propose SPARKE: Scalable Prompt-Aware Diversity Guidance in Diffusion Models via RKE Score. SPARKE leverages conditional entropy to guide the sampling process with respect to prompt-localized diversity. By employing Conditional Latent RKE Score Guidance, we reduce the computational complexity from $\mathcal{O}(n^3)$ to $\mathcal{O}(n)$, enabling efficient large-scale generation. We integrate SPARKE into several popular diffusion pipelines and demonstrate improved diversity without additional inference overhead.
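For intuition about the RKE score underlying this guidance, here is a minimal NumPy sketch (not from this repository) of an order-2 Rényi kernel entropy of a Gaussian kernel matrix; its exponential behaves like an effective mode count, so larger values mean more diverse samples. The names `gaussian_kernel` and `rke_score` are illustrative, not this repo's API, and the sketch is the naive O(n²)-memory computation, not the paper's O(n) scheme.

```python
import numpy as np

def gaussian_kernel(x, sigma):
    """Gaussian kernel matrix K[i, j] = exp(-||x_i - x_j||^2 / (2 sigma^2))."""
    sq = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2.0 * sigma ** 2))

def rke_score(x, sigma=1.0):
    """Order-2 Renyi kernel entropy: -log ||K/n||_F^2.

    For n identical samples the score is 0 (one mode); for n
    well-separated samples it approaches log(n) (n modes).
    """
    n = x.shape[0]
    K = gaussian_kernel(x, sigma) / n
    return float(-np.log((K ** 2).sum()))
```

A guidance signal can then be the gradient of such a score with respect to the current latent, nudging each new sample away from previously generated ones.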
Supported Pipelines
The following diffusers pipelines have been extended with SPARKE guidance:
| Pipeline Type | Implementation |
|------------------------------------------|---------------------------------------------------|
| Stable Diffusion v1.5 | SPARKEGuidedStableDiffusionPipeline |
| Stable Diffusion v2.1 | SPARKEGuidedStableDiffusionPipeline |
| Stable Diffusion XL | SPARKEGuidedStableDiffusionXLPipeline |
| ControlNet (SD v1.5 + OpenPose) | SPARKEGuidedStableDiffusionControlNetPipeline |
| ControlNet (SDXL + OpenPose) | SPARKEGuidedStableDiffusionXLControlNetPipeline |
| PixArt-Sigma (XL) | SPARKEGuidedPixArtSigmaPipeline |
Each pipeline supports both entropy-based and kernel-based guidance (e.g., Vendi, RKE, Conditional RKE) in a prompt-aware and scalable fashion.
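Of the kernel-based scores named above, the Vendi score is the easiest to state: the exponential of the von Neumann entropy of the eigenvalues of the normalized kernel matrix K/n. The sketch below is illustrative only (the `vendi_score` name and Gaussian kernel choice are assumptions, not this repository's implementation):

```python
import numpy as np

def vendi_score(x, sigma=1.0):
    """Vendi score: exp of the entropy of the eigenvalues of K/n.

    Equals 1 when all samples are identical and n when all n
    samples are mutually dissimilar under the kernel.
    """
    n = x.shape[0]
    sq = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
    K = np.exp(-sq / (2.0 * sigma ** 2)) / n
    lam = np.linalg.eigvalsh(K)
    lam = lam[lam > 1e-12]  # drop numerically-zero eigenvalues
    return float(np.exp(-(lam * np.log(lam)).sum()))
```

The eigendecomposition makes this O(n³) in the number of samples, which is the cost the RKE-based guidance in this project is designed to avoid.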
Installation
- Clone this repository and install the requirements:

```bash
git clone https://github.com/mjalali/sparke-diffusers.git
cd sparke-diffusers/sparke_diffusers
pip install -r requirements.txt
```
Usage
You can directly import and use the SPARKE-enabled pipelines:
```python
pipe = get_diffusion_pipeline(name='sdxl')
image = pipe(
    prompt="a photorealistic portrait of a man with freckles",
    guidance_scale=7.5,
    criteria='vscore_clip',
    algorithm='cond-rke',
    criteria_guidance_scale=0.4,
    num_inference_steps=50,
    kernel='gaussian',
    sigma_image=0.8,
    sigma_text=0.35,
    guidance_freq=10,
    use_latents_for_guidance=True,
    regularize=False,
    regions_list=['face'],
).images[0]
image.save("output.jpg")
```
Bibtex Citation
To cite this work, please use the following BibTeX entries:
SPARKE Diversity Guidance:
```bibtex
@article{jalali2025sparke,
  author  = {Mohammad Jalali and Haoyu Lei and Amin Gohari and Farzan Farnia},
  title   = {SPARKE: Scalable Prompt-Aware Diversity Guidance in Diffusion Models via RKE Score},
  journal = {arXiv preprint arXiv:2506.10173},
  year    = {2025},
  url     = {https://arxiv.org/abs/2506.10173},
}
```
RKE Score:
```bibtex
@inproceedings{jalali2023rke,
  author    = {Jalali, Mohammad and Li, Cheuk Ting and Farnia, Farzan},
  booktitle = {Advances in Neural Information Processing Systems},
  pages     = {9931--9943},
  title     = {An Information-Theoretic Evaluation of Generative Models in Learning Multi-modal Distributions},
  url       = {https://openreview.net/forum?id=PdZhf6PiAb},
  volume    = {36},
  year      = {2023}
}
```
Owner
- Name: Mohammad Jalali
- Login: mjalali
- Kind: user
- Website: https://mjalali.github.io/
- Repositories: 1
- Profile: https://github.com/mjalali
Citation (CITATION.cff)
cff-version: 1.2.0
title: 'Diffusers: State-of-the-art diffusion models'
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Patrick
family-names: von Platen
- given-names: Suraj
family-names: Patil
- given-names: Anton
family-names: Lozhkov
- given-names: Pedro
family-names: Cuenca
- given-names: Nathan
family-names: Lambert
- given-names: Kashif
family-names: Rasul
- given-names: Mishig
family-names: Davaadorj
- given-names: Dhruv
family-names: Nair
- given-names: Sayak
family-names: Paul
- given-names: Steven
family-names: Liu
- given-names: William
family-names: Berman
- given-names: Yiyi
family-names: Xu
- given-names: Thomas
family-names: Wolf
repository-code: 'https://github.com/huggingface/diffusers'
abstract: >-
Diffusers provides pretrained diffusion models across
multiple modalities, such as vision and audio, and serves
as a modular toolbox for inference and training of
diffusion models.
keywords:
- deep-learning
- pytorch
- image-generation
- hacktoberfest
- diffusion
- text2image
- image2image
- score-based-generative-modeling
- stable-diffusion
- stable-diffusion-diffusers
license: Apache-2.0
version: 0.12.1
GitHub Events
Total
- Push event: 2
Last Year
- Push event: 2
Committers
Last synced: 8 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Mohammad Jalali | m****9@g****m | 11 |
| Mohammad Jalali | m****i@M****l | 2 |
Issues and Pull Requests
Last synced: 8 months ago
All Time
- Total issues: 0
- Total pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Total issue authors: 0
- Total pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels
Dependencies
- ubuntu 20.04 build
- nvidia/cuda 12.1.0-runtime-ubuntu20.04 build
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.31.0
- ftfy *
- peft >=0.11.1
- sentencepiece *
- tensorboard *
- torchvision *
- transformers >=4.41.2
- Jinja2 *
- accelerate >=0.31.0
- decord >=0.6.0
- ftfy *
- imageio-ffmpeg *
- peft >=0.11.1
- sentencepiece *
- tensorboard *
- torchvision *
- transformers >=4.41.2
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- webdataset *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- datasets *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- SentencePiece *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.31.0
- ftfy *
- peft >=0.11.1
- sentencepiece *
- tensorboard *
- torchvision *
- transformers >=4.41.2
- Jinja2 *
- accelerate >=1.0.0
- ftfy *
- peft >=0.14.0
- sentencepiece *
- tensorboard *
- torchvision *
- transformers >=4.47.0
- Jinja2 *
- accelerate >=0.31.0
- ftfy *
- peft ==0.11.1
- sentencepiece *
- tensorboard *
- torchvision *
- transformers >=4.41.2
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- accelerate ==1.2.0
- peft >=0.14.0
- torch *
- torchvision *
- transformers ==4.47.0
- wandb *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- huggingface-hub >=0.26.2
- Pillow *
- accelerate >=0.16.0
- bitsandbytes *
- datasets *
- huggingface_hub *
- lpips *
- numpy *
- packaging *
- taming_transformers *
- torch *
- torchvision *
- tqdm *
- transformers *
- wandb *
- xformers *
- Jinja2 *
- diffusers *
- ftfy *
- tensorboard *
- torch *
- torchvision *
- transformers *
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- peft *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- accelerate *
- datasets *
- peft *
- torchvision *
- transformers *
- wandb *
- webdataset *
- Jinja2 *
- accelerate >=0.16.0
- diffusers ==0.9.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.21.0
- Jinja2 *
- accelerate >=0.16.0
- diffusers *
- fairscale *
- ftfy *
- scipy *
- tensorboard *
- timm *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- intel_extension_for_pytorch >=1.13
- tensorboard *
- torchvision *
- transformers >=4.21.0
- accelerate *
- ftfy *
- modelcards *
- neural-compressor *
- tensorboard *
- torchvision *
- transformers >=4.25.0
- accelerate *
- ip_adapter *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- datasets *
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- datasets >=2.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb >=0.16.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- ftfy *
- modelcards *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- ftfy *
- modelcards *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- tensorboard *
- torchvision *
- SentencePiece *
- controlnet-aux *
- datasets *
- torchvision *
- transformers *
- Jinja2 *
- accelerate >=0.16.0
- datasets >=2.19.1
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 ==3.1.5
- accelerate ==0.23.0
- diffusers ==0.20.1
- ftfy ==6.1.1
- peft ==0.5.0
- tensorboard ==2.14.0
- torch ==2.2.0
- torchvision >=0.16
- transformers ==4.38.0
- accelerate >=0.16.0
- bitsandbytes *
- deepspeed *
- peft >=0.6.0
- torchvision *
- transformers >=4.25.1
- wandb *
- aiohttp *
- fastapi *
- prometheus-fastapi-instrumentator >=7.0.0
- prometheus_client >=0.18.0
- py-consul *
- sentencepiece *
- torch *
- transformers ==4.46.1
- uvicorn *
- aiohappyeyeballs ==2.4.3
- aiohttp ==3.10.10
- aiosignal ==1.3.1
- annotated-types ==0.7.0
- anyio ==4.6.2.post1
- attrs ==24.2.0
- certifi ==2024.8.30
- charset-normalizer ==3.4.0
- click ==8.1.7
- fastapi ==0.115.3
- filelock ==3.16.1
- frozenlist ==1.5.0
- fsspec ==2024.10.0
- h11 ==0.14.0
- huggingface-hub ==0.26.1
- idna ==3.10
- jinja2 ==3.1.4
- markupsafe ==3.0.2
- mpmath ==1.3.0
- multidict ==6.1.0
- networkx ==3.4.2
- numpy ==2.1.2
- packaging ==24.1
- prometheus-client ==0.21.0
- prometheus-fastapi-instrumentator ==7.0.0
- propcache ==0.2.0
- py-consul ==1.5.3
- pydantic ==2.9.2
- pydantic-core ==2.23.4
- pyyaml ==6.0.2
- regex ==2024.9.11
- requests ==2.32.3
- safetensors ==0.4.5
- sentencepiece ==0.2.0
- sniffio ==1.3.1
- starlette ==0.41.0
- sympy ==1.13.3
- tokenizers ==0.20.1
- torch ==2.4.1
- tqdm ==4.66.5
- transformers ==4.46.1
- typing-extensions ==4.12.2
- urllib3 ==2.2.3
- uvicorn ==0.32.0
- yarl ==1.16.0
- accelerate >=0.16.0
- datasets *
- ftfy *
- safetensors *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- wandb *
- Jinja2 *
- accelerate >=0.16.0
- datasets >=2.19.1
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- datasets *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.22.0
- datasets *
- ftfy *
- peft ==0.7.0
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- accelerate >=0.16.0
- ftfy *
- tensorboard *
- torchvision *
- transformers >=4.25.1
- Jinja2 *
- flax *
- ftfy *
- optax *
- tensorboard *
- torch *
- torchvision *
- transformers >=4.25.1
- accelerate >=0.16.0
- datasets *
- torchvision *
- accelerate >=0.16.0
- datasets *
- numpy *
- tensorboard *
- timm *
- torchvision *
- tqdm *
- transformers >=4.25.1
- deps *