https://github.com/kuleshov-group/bd3lms

Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
    Organization kuleshov-group has institutional domain (www.cs.cornell.edu)
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (5.7%) to scientific vocabulary
Last synced: 4 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: kuleshov-group
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage: https://m-arriola.com/bd3lms/
  • Size: 1.04 MB
Statistics
  • Stars: 726
  • Watchers: 10
  • Forks: 42
  • Open Issues: 2
  • Releases: 0
Created 12 months ago · Last pushed 8 months ago
Metadata Files
Readme License

README.md

Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models (ICLR 2025 Oral)

By Marianne Arriola, Aaron Gokaslan, Justin T Chiu, Zhihan Yang, Zhixuan Qi, Jiaqi Han, Subham Sekhar Sahoo, Volodymyr Kuleshov

We introduce BD3-LMs, a family of Block Discrete Denoising Diffusion Language Models that achieve SOTA likelihoods among diffusion models and enable generation of arbitrary-length sequences. BD3-LMs combine the strengths of autoregressive and diffusion language models by decomposing a token sequence into blocks and performing discrete diffusion within each block. By tuning the block size, we interpolate between autoregressive and diffusion models, which introduces a trade-off between quality and sample efficiency. We propose a recipe for building effective BD3-LMs that includes an efficient training algorithm, estimators of gradient variance, and data-driven noise schedules to minimize the variance.
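One way to picture this block decomposition is as an attention pattern: the likelihood factorizes over blocks, so each block is conditioned on all earlier blocks while tokens inside a block are denoised jointly. The Python sketch below is illustrative only and is not taken from the repository. `block_causal_mask` is a hypothetical helper that lets each token attend to every token in its own block and in all earlier blocks. It models a single clean sequence; the mask used during BD3-LM training, which also has to account for the noised copy of each block, is more involved.

```python
from typing import List


def block_causal_mask(seq_len: int, block_size: int) -> List[List[bool]]:
    """Simplified block-causal mask (hypothetical helper, not the repo's API).

    mask[i][j] is True when query position i may attend to key position j,
    i.e. when j lies in the same block as i or in an earlier block.
    """
    if block_size < 1 or seq_len % block_size != 0:
        raise ValueError("seq_len must be a positive multiple of block_size")
    return [
        [(j // block_size) <= (i // block_size) for j in range(seq_len)]
        for i in range(seq_len)
    ]


if __name__ == "__main__":
    # Prints:
    # xx....
    # xx....
    # xxxx..
    # xxxx..
    # xxxxxx
    # xxxxxx
    for row in block_causal_mask(seq_len=6, block_size=2):
        print("".join("x" if allowed else "." for allowed in row))
```

Setting `block_size=1` yields the strictly causal (lower-triangular) mask of an autoregressive model, while `block_size=seq_len` allows full bidirectional attention as in a standard diffusion model, mirroring the interpolation described above.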

In this repo, we provide:
  • The BD3-LM framework
    1. Block-autoregressive likelihood parameterization
    2. Data-driven noise schedules to reduce training variance
    3. Arbitrary-length discrete diffusion samplers
  • Baseline implementations
    1. Autoregressive model [AR]
    2. Score Entropy Based Discrete Diffusion [SEDD]
    3. Masked Diffusion Language Model [MDLM]
    4. Semi-autoregressive Simplex-based Diffusion Language Model [SSD-LM] (supports sample generation only)

Code Organization

  1. main.py: Routines for training and evaluation
  2. noise_schedule.py: Noise schedules
  3. diffusion.py: Forward/reverse diffusion
  4. dataloader.py: Dataloaders
  5. utils.py: LR scheduler, logging, fsspec handling
  6. models/: Network architectures. Supports DiT and AR transformer
  7. configs/: Config files for datasets/models/noise schedules/LR schedules
  8. scripts/: Shell scripts for training/evaluation
    • train/: Training scripts (LM1B, OWT)
    • ppl/: Likelihood evaluation on the pretraining set (LM1B, OWT)
    • zs_ppl/: Zero-shot likelihood evaluation on GPT2 benchmark datasets
    • gen_ppl/: Sample quality (generative perplexity under GPT2)
    • var_len/: Arbitrary-length sequence generation
  9. ssd-lm/: SSD-LM codebase
    • run_generate_text_batch.sh: Generates SSD-LM samples
    • report_genppl.py: Reports generative perplexity of SSD-LM samples

Getting Started

To get started, create a conda environment containing the required dependencies.

```bash
conda create --name bd3lm python=3.9
conda activate bd3lm
pip install -r requirements.txt
```

While BD3-LMs don't require FlashAttention, evaluating the MDLM baselines requires flash-attn==2.5.6.

Create the following directories to store saved models and Slurm logs:

```bash
mkdir outputs watch_folder logs sample_logs
```

Then run the training as a batch job:

```bash
sbatch scripts/train/train_owt_bd3lm.sh
```

Checkpoints

We have uploaded BD3-LMs trained on OpenWebText with block sizes 4, 8, and 16 for 1M training steps to HuggingFace 🤗: kuleshov-group/bd3-lms. These BD3-LMs are fine-tuned from an MDLM checkpoint trained on OpenWebText for 850K gradient updates. We release this pretraining checkpoint on HuggingFace: kuleshov-group/bd3lm-owt-block_size1024-pretrain.

The MDLM baseline is also available on HuggingFace: kuleshov-group/mdlm-owt. The AR and SEDD baselines trained on OpenWebText are available in this Google Drive folder.

For arbitrary-length sequence generation, we compare with AR, SEDD, MDLM (whose arbitrary-length generation is supported as an inference-only technique and does not feature a training objective), and SSD-LM. To generate sequences longer than the training context size (fixed at 1024 tokens for OWT), we retrained AR and MDLM from Sahoo et al. without artificially injecting BOS/EOS tokens into the context. We also provide these checkpoints on HuggingFace: kuleshov-group/mdlm-noeos-owt, kuleshov-group/sedd-noeos-owt, kuleshov-group/ar-noeos-owt.
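If you script against the released BD3-LM checkpoints, a small helper can map a block size to a HuggingFace repo id. This is a hypothetical sketch, not part of the repository: the id pattern `kuleshov-group/bd3lm-owt-block_size{N}` is inferred from the `eval.checkpoint_path` values in the example commands and from the pretraining checkpoint name above, and `load_bd3lm` assumes the checkpoints can be loaded through `transformers`' `AutoModelForMaskedLM` with `trust_remote_code=True`, which this README does not confirm.

```python
# Block sizes for which OWT-trained BD3-LM checkpoints are listed above.
RELEASED_BLOCK_SIZES = (4, 8, 16)
HF_ORG = "kuleshov-group"


def bd3lm_checkpoint_id(block_size: int) -> str:
    """Return the assumed HuggingFace repo id for an OWT-trained BD3-LM."""
    if block_size not in RELEASED_BLOCK_SIZES:
        raise ValueError(
            f"no released BD3-LM for block_size={block_size}; "
            f"choose one of {RELEASED_BLOCK_SIZES}"
        )
    # Naming pattern inferred from eval.checkpoint_path in the README examples.
    return f"{HF_ORG}/bd3lm-owt-block_size{block_size}"


def load_bd3lm(block_size: int):
    """Hypothetical loader; assumes the repo id ships custom modeling code."""
    # Imported lazily so the id helper works without transformers installed.
    from transformers import AutoModelForMaskedLM

    return AutoModelForMaskedLM.from_pretrained(
        bd3lm_checkpoint_id(block_size),
        trust_remote_code=True,
    )
```

For example, `bd3lm_checkpoint_id(8)` returns `"kuleshov-group/bd3lm-owt-block_size8"`, and an unreleased size such as 32 fails fast instead of triggering a failed download.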

Reproducing Experiments

Below, we describe the steps required to reproduce the experiments in the paper. Throughout, the main entry point for running experiments is the main.py script. We also provide sample Slurm scripts for launching pre-training and downstream fine-tuning experiments in the scripts/ directory.

Generate Arbitrary-Length Sequences

To generate arbitrary-length sequences, set mode=sample_eval. Example scripts are provided in scripts/var_len/var_len*.sh. Here's an example script using BD3-LM:

HuggingFace model

```bash
BLOCK_SIZE=4 # 4, 8, 16
LENGTH=2048 # arbitrary; needs to be a multiple of the block size

python -u main.py \
    loader.eval_batch_size=1 \
    model=small \
    algo=bd3lm \
    algo.T=5000 \
    algo.backbone=hf_dit \
    data=openwebtext-split \
    model.length=$LENGTH \
    block_size=$BLOCK_SIZE \
    wandb=null \
    mode=sample_eval \
    eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${BLOCK_SIZE} \
    model.attn_backend=sdpa \
    sampling.nucleus_p=0.9 \
    sampling.kv_cache=true \
    sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize${BLOCK_SIZE}
```

Local checkpoint

```bash
BLOCK_SIZE=4 # 4, 8, 16
LENGTH=2048 # arbitrary; needs to be a multiple of the block size

python -u main.py \
    loader.eval_batch_size=1 \
    model=small \
    algo=bd3lm \
    algo.T=5000 \
    data=openwebtext-split \
    model.length=$LENGTH \
    block_size=$BLOCK_SIZE \
    wandb=null \
    mode=sample_eval \
    eval.checkpoint_path=/path/to/checkpoint/bd3lm-owt-block_size${BLOCK_SIZE} \
    model.attn_backend=sdpa \
    sampling.nucleus_p=0.9 \
    sampling.kv_cache=true \
    sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize${BLOCK_SIZE}
```
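Arbitrary-length generation proceeds one block at a time: each new block starts fully masked and is denoised conditioned on the blocks already generated, after which it becomes context for the next block. The toy Python sketch below illustrates only that control flow and is not the repository's sampler. The denoiser is a hypothetical stand-in that draws random token ids, and `MASK`, `VOCAB_SIZE`, `steps_per_block`, and the "reveal an even share per step" rule are illustrative assumptions. The real sampler uses the trained network together with options such as `sampling.nucleus_p` and `sampling.kv_cache`.

```python
import random
from typing import Callable, List

MASK = -1           # hypothetical mask-token id
VOCAB_SIZE = 50257  # GPT-2 vocabulary size, used here only for illustration

Denoiser = Callable[[List[int], List[int], random.Random], List[int]]


def toy_denoiser(context: List[int], block: List[int], rng: random.Random) -> List[int]:
    """Stand-in for the network: proposes a random token for every block position."""
    return [rng.randrange(VOCAB_SIZE) for _ in block]


def sample_blockwise(length: int, block_size: int, steps_per_block: int,
                     denoiser: Denoiser = toy_denoiser, seed: int = 0) -> List[int]:
    """Generate `length` tokens block by block with a masked-denoising loop."""
    if length % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    if steps_per_block < 1:
        raise ValueError("steps_per_block must be at least 1")
    rng = random.Random(seed)
    sequence: List[int] = []
    for _ in range(length // block_size):
        block = [MASK] * block_size  # every new block starts fully masked
        for step in range(steps_per_block):
            masked = [i for i, tok in enumerate(block) if tok == MASK]
            if not masked:
                break
            # Reveal an even share of the still-masked positions; the last
            # step reveals everything that remains masked.
            remaining_steps = steps_per_block - step
            n_reveal = -(-len(masked) // remaining_steps)  # ceiling division
            proposal = denoiser(sequence, block, rng)  # conditioned on earlier blocks
            for i in rng.sample(masked, n_reveal):
                block[i] = proposal[i]
        sequence.extend(block)  # the clean block becomes context for the next one
    return sequence


if __name__ == "__main__":
    tokens = sample_blockwise(length=2048, block_size=4, steps_per_block=2)
    print(len(tokens), MASK in tokens)
```

Because generation only ever conditions on a growing prefix of clean blocks, nothing in this loop ties `length` to a fixed context size; the multiple-of-block-size check mirrors the constraint on `LENGTH` in the commands above.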

Likelihood Evaluation

To compute test perplexity, use mode=ppl_eval. Example scripts are provided in scripts/ppl/eval_owt_*.sh. Here's an example evaluation script on OpenWebText:

```bash
BLOCK_SIZE=4 # 4, 8, 16

python -u main.py \
    loader.eval_batch_size=16 \
    model=small \
    algo=bd3lm \
    algo.backbone=hf_dit \
    data=openwebtext-split \
    data.insert_valid_special=False \
    model.length=1024 \
    model.attn_backend=flex \
    block_size=${BLOCK_SIZE} \
    eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${BLOCK_SIZE} \
    wandb=null \
    mode=ppl_eval > logs/bd3lm_owt_block_size${BLOCK_SIZE}.log
```
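Perplexity is the exponential of the average per-token negative log-likelihood. For diffusion-based models the per-token value is typically a variational bound (an NELBO), so the resulting perplexity should be read as an upper bound. The minimal sketch below is a hypothetical helper, not the repository's metric code, and assumes the NLLs are given in nats.

```python
import math
from typing import Sequence


def perplexity(token_nlls: Sequence[float]) -> float:
    """exp(mean NLL), where token_nlls are per-token negative log-likelihoods in nats."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


if __name__ == "__main__":
    # Every token assigned probability 1/4 gives NLL = ln 4 per token, so the
    # perplexity is 4 up to floating-point rounding.
    print(perplexity([math.log(4.0)] * 10))
```

The average is taken over all evaluated tokens before exponentiating; averaging per-sequence perplexities instead would weight short and long sequences differently and give a different number.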

Training Pipeline

To train BD3-LMs, use mode=train (the default mode). Example scripts are provided in scripts/train/train_owt*.sh. Here's an example training script on OpenWebText:

```bash
BLOCK_SIZE=4 # we recommend 4, 8, or 16; must be a factor of the context length
PRETRAIN_CKPT=kuleshov-group/bd3lm-owt-block_size1024-pretrain # to train from scratch, set to null

python -u main.py \
    loader.global_batch_size=512 \
    loader.eval_global_batch_size=512 \
    loader.batch_size=16 \
    loader.eval_batch_size=16 \
    model=small \
    algo=bd3lm \
    algo.clip_search_widths=[0.5,0.6,0.7,0.8,0.9] \
    data=openwebtext-split \
    model.length=1024 \
    block_size=$BLOCK_SIZE \
    wandb.name=bd3lm-owt-block_size${BLOCK_SIZE} \
    mode=train \
    model.attn_backend=flex \
    training.resample=True \
    training.from_pretrained=$PRETRAIN_CKPT
```

The arguments `loader.batch_size` and `loader.eval_batch_size` control the batch size per GPU. If `loader.batch_size * num_gpus` is less than `global_batch_size`, PyTorch Lightning will resort to gradient accumulation. You can also launch a training job on Slurm using the command: `sbatch scripts/train/train_owt_bd3lm.sh`.
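The relationship between these batch-size arguments is simple arithmetic. The sketch below is a hypothetical helper, not the repository's code; it assumes the global batch is split evenly across all GPUs on all nodes and that any shortfall is made up by accumulating gradients over several micro-batches before each optimizer step.

```python
def accumulation_steps(global_batch_size: int, batch_size_per_gpu: int,
                       num_gpus: int, num_nodes: int = 1) -> int:
    """Micro-batches each GPU accumulates before one optimizer step."""
    per_step = batch_size_per_gpu * num_gpus * num_nodes
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"batch_size * num_gpus * num_nodes = {per_step}"
        )
    return global_batch_size // per_step


if __name__ == "__main__":
    # Settings from the example command: global batch 512, 16 sequences per GPU.
    # Prints 4 for 8 GPUs, 2 for 16 GPUs, and 1 for 32 GPUs.
    for gpus in (8, 16, 32):
        print(gpus, "GPUs ->", accumulation_steps(512, 16, gpus), "accumulation step(s)")
```

Raising on a non-divisible configuration, rather than rounding, keeps the effective batch size exactly equal to the requested global batch size.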

Acknowledgements

This repository was built off of MDLM and SEDD.

Citation

@inproceedings{arriola2025block,
  title={Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models},
  author={Marianne Arriola and Aaron Gokaslan and Justin T Chiu and Zhihan Yang and Zhixuan Qi and Jiaqi Han and Subham Sekhar Sahoo and Volodymyr Kuleshov},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://arxiv.org/abs/2503.09573}
}

Owner

  • Name: Kuleshov Group @ Cornell Tech
  • Login: kuleshov-group
  • Kind: organization

Research group at Cornell focused on machine learning, generative models, AI for science

GitHub Events

Total
  • Issues event: 57
  • Watch event: 676
  • Delete event: 3
  • Issue comment event: 50
  • Push event: 26
  • Pull request event: 10
  • Fork event: 42
  • Create event: 2
Last Year
  • Issues event: 57
  • Watch event: 676
  • Delete event: 3
  • Issue comment event: 50
  • Push event: 26
  • Pull request event: 10
  • Fork event: 42
  • Create event: 2

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 30
  • Total pull requests: 4
  • Average time to close issues: 5 days
  • Average time to close pull requests: 2 days
  • Total issue authors: 25
  • Total pull request authors: 3
  • Average comments per issue: 0.83
  • Average comments per pull request: 0.25
  • Merged pull requests: 3
  • Bot issues: 0
  • Bot pull requests: 2
Past Year
  • Issues: 30
  • Pull requests: 4
  • Average time to close issues: 5 days
  • Average time to close pull requests: 2 days
  • Issue authors: 25
  • Pull request authors: 3
  • Average comments per issue: 0.83
  • Average comments per pull request: 0.25
  • Merged pull requests: 3
  • Bot issues: 0
  • Bot pull requests: 2
Top Authors
Issue Authors
  • Wiselnn570 (3)
  • yuecao0119 (2)
  • Lucas-Wye (2)
  • TediousBoredom (2)
  • liyanjun711 (1)
  • automenta (1)
  • weirayao (1)
  • chengshuang18 (1)
  • tauChang (1)
  • yyujinxin (1)
  • GitHUB-ZYD (1)
  • Facico (1)
  • jacobyhsi (1)
  • edenavr555 (1)
  • FWXT (1)
Pull Request Authors
  • dependabot[bot] (2)
  • s-sahoo (1)
  • mapmeld (1)
Top Labels
Issue Labels
Pull Request Labels
dependencies (2) python (2)