mdlm

[NeurIPS 2024] Simple and Effective Masked Diffusion Language Model

https://github.com/kuleshov-group/mdlm

Science Score: 62.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
    Organization kuleshov-group has institutional domain (www.cs.cornell.edu)
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (8.5%) to scientific vocabulary

Keywords

diffusion-language-models diffusion-models language-model text
Last synced: 4 months ago

Repository

[NeurIPS 2024] Simple and Effective Masked Diffusion Language Model

Basic Info
  • Host: GitHub
  • Owner: kuleshov-group
  • License: apache-2.0
  • Language: Python
  • Default Branch: master
  • Homepage: https://s-sahoo.com/mdlm/
  • Size: 108 KB
Statistics
  • Stars: 401
  • Watchers: 13
  • Forks: 54
  • Open Issues: 0
  • Releases: 0
Topics
diffusion-language-models diffusion-models language-model text
Created over 1 year ago · Last pushed 7 months ago
Metadata Files
Readme License Citation

README.md

Simple and Effective Masked Diffusion Language Models (NeurIPS 2024)

By Subham Sekhar Sahoo, Marianne Arriola, Yair Schiff, Aaron Gokaslan, Edgar Marroquin, Justin T Chiu, Alexander Rush, Volodymyr Kuleshov

[Badges: arXiv · Open In Colab · YouTube · Open In Studio]

[Update April 14, 2025: An improved implementation is available here: DUO GitHub repo.]

[Update Jun 3, 2025: MDMs with KV caching: Eso-LMs GitHub repo.]

[Graphical abstract]

We introduce MDLM, a masked discrete diffusion language model featuring a novel (SUBS)titution-based parameterization that simplifies the absorbing-state diffusion loss to a mixture of classical masked language modeling losses. In doing so, we achieve SOTA perplexity among diffusion models on LM1B and OpenWebText, while achieving zero-shot perplexity competitive with SOTA AR models on numerous datasets. We provide a demo in this Colab notebook (or Open In Studio) and a video tutorial here:

[Video tutorial (YouTube)]
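To make the SUBS simplification concrete, here is a minimal, hypothetical sketch of the resulting objective under a log-linear schedule ($\alpha_t = 1 - t$): the continuous-time NELBO reduces to a time-weighted masked-LM cross-entropy over the masked positions. This is an illustration only, not the repository's diffusion.py.

```python
import torch
import torch.nn.functional as F

def masked_diffusion_loss(denoiser, x0, mask_id, eps=1e-3):
    """Illustrative (hypothetical) continuous-time masked-diffusion loss.

    Assumes a log-linear schedule alpha_t = 1 - t, so each token is masked
    independently with probability t and the NELBO becomes a 1/t-weighted
    masked language modeling cross-entropy over the masked positions.
    """
    b, l = x0.shape
    t = eps + (1 - eps) * torch.rand(b, device=x0.device)    # t ~ U(eps, 1)
    mask = torch.rand(b, l, device=x0.device) < t[:, None]   # mask w.p. t
    zt = torch.where(mask, torch.full_like(x0, mask_id), x0)

    logits = denoiser(zt)                                     # (b, l, vocab)
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")
    # alpha_t' / (1 - alpha_t) = -1/t for the log-linear schedule.
    return ((mask * ce).sum(dim=-1) / t).mean()
```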

In this repo, we release:
  • The MDLM framework:
    1. SUBStitution-based parameterization
    2. Simplified loss calculation for masked diffusion processes
  • Baseline implementations [Examples]:
    1. An autoregressive model that matches the SOTA AR performance on LM1B.
    2. Score Entropy Based Discrete Diffusion (SEDD).
    3. An efficient implementation of the absorbing-state D3PM that beats the previous SOTA text diffusion model, SEDD, on LM1B.
  • Samplers:
    1. Ancestral sampling as proposed in D3PM.
    2. Analytic sampler as proposed in SEDD.
    3. Our proposed efficient sampler that:
       - makes MDLM ~3-4x faster than existing diffusion models [Example]
       - supports semi-autoregressive (SAR) generation [Example]

Code Organization

  1. main.py: Routines for training and evaluation
  2. noise_schedule.py: Noise schedules
  3. diffusion.py: Forward/reverse diffusion
  4. dataloader.py: Dataloaders
  5. utils.py: LR scheduler, logging, fsspec handling
  6. models/: Denoising network architectures. Supports DiT, AR transformer, and Mamba
  7. configs/: Config files for datasets/denoising networks/noise schedules/LR schedules
  8. scripts/: Shell scripts for training/evaluation

Getting started in this repository

To get started, create a conda environment containing the required dependencies.

```bash
conda env create -f requirements.yaml
conda activate mdlm
```

Create the following directories to store saved models and Slurm logs:

```bash
mkdir outputs
mkdir watch_folder
```

and run the training as a batch job:

```bash
sbatch scripts/train_owt_mdlm.sh
```

Checkpoints

We have uploaded an MDLM model trained on OpenWebText for 1M training steps to the Hugging Face hub 🤗: kuleshov-group/mdlm-owt. Furthermore, we have released the checkpoints for the AR and SEDD baselines trained on OpenWebText in this Google Drive folder.
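The hub checkpoint can typically be loaded with the transformers auto classes. A hedged example, assuming the checkpoint follows the usual custom-code pattern on the Hub and pairs with the GPT-2 tokenizer used for OpenWebText (double-check the model card):

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Assumption: the checkpoint ships custom modeling code on the Hub
# (hence trust_remote_code=True) and was trained with the GPT-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForMaskedLM.from_pretrained(
    "kuleshov-group/mdlm-owt", trust_remote_code=True)
```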

Reproducing Experiments

Below, we describe the steps required to reproduce the experiments in the paper. Throughout, the main entry point for running experiments is the main.py script. We also provide sample Slurm scripts for launching pre-training and downstream fine-tuning experiments in the scripts/ directory.

Generate Samples

The argument sampling.predictor specifies the sampler and takes one of the following values:
  • ddpm_cache: our proposed sampler, ~3-4x faster than the samplers proposed in D3PM and SEDD.
  • ddpm: ancestral sampling as proposed in D3PM.
  • analytic: the analytic sampler proposed in SEDD.

In the following table we report the wall-clock time to generate 64 samples on a single A5000 GPU with batch_size=1. $T$ denotes the time discretization of the reverse process.

|                   | $T=5k$ ($\downarrow$) | $T=10k$ ($\downarrow$) |
|-------------------|-----------------------|------------------------|
| SEDD              | 127.1                 | 229.3                  |
| MDLM + ddpm       | 113.8                 | 206.6                  |
| MDLM + ddpm_cache | 40.1                  | 60.4                   |
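The speedup of ddpm_cache comes from the observation that, with the SUBS parameterization, the denoiser's prediction does not depend on the time step, so steps in which no token gets unmasked can reuse the previous forward pass. A simplified, hypothetical sketch of this caching idea (not the repository's sampler), assuming a log-linear schedule where each still-masked token is revealed with probability 1/step:

```python
import torch

@torch.no_grad()
def ddpm_cache_style_sampler(denoiser, zt, mask_id, num_steps):
    """Hypothetical sketch of the caching trick behind ddpm_cache."""
    cached_probs = None
    for step in range(num_steps, 0, -1):
        if cached_probs is None:
            cached_probs = denoiser(zt).softmax(dim=-1)   # (b, l, vocab)
        # Under a log-linear schedule, each still-masked token is unmasked
        # this step with probability 1/step; many steps change nothing.
        still_masked = zt == mask_id
        unmask = still_masked & (
            torch.rand_like(zt, dtype=torch.float) < 1.0 / step)
        if unmask.any():
            sampled = torch.distributions.Categorical(probs=cached_probs).sample()
            zt = torch.where(unmask, sampled, zt)
            cached_probs = None   # sequence changed: recompute next step
    return zt
```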

To generate samples from a pre-trained model use one of the following commands:

Huggingface model

```bash
python main.py \
  mode=sample_eval \
  eval.checkpoint_path=kuleshov-group/mdlm-owt \
  data=openwebtext-split \
  model.length=1024 \
  sampling.predictor=ddpm_cache \
  sampling.steps=1000 \
  loader.eval_batch_size=1 \
  sampling.num_sample_batches=10 \
  backbone=hf_dit
```

Local checkpoint

```bash
python main.py \
  mode=sample_eval \
  eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
  data=openwebtext-split \
  model.length=1024 \
  sampling.predictor=ddpm_cache \
  sampling.steps=10000 \
  loader.eval_batch_size=1 \
  sampling.num_sample_batches=1 \
  backbone=dit
```

Semi-AR sample generation

MDLM can also generate samples of arbitrary length in a semi-autoregressive (SAR) manner. We generate 200 sequences of 2048 tokens each on a single 3090 GPU and evaluate generative perplexity under a pre-trained GPT-2 model. In the table below we find that, in addition to achieving better generative perplexity, MDLM enables 25-30x faster SAR decoding than SSD-LM.

|                   | Gen. PPL ($\downarrow$) | Sec/Seq ($\downarrow$) |
|-------------------|-------------------------|------------------------|
| SSD-LM            | 35.43                   | 2473.9                 |
| MDLM + ddpm_cache | 27.18                   | 89.3                   |

Gen. PPL: Generation Perplexity, Sec/Seq: Seconds per Sequence

```bash
python main.py \
  mode=sample_eval \
  eval.checkpoint_path=kuleshov-group/mdlm-owt \
  data=openwebtext-split \
  parameterization=subs \
  model.length=1024 \
  sampling.predictor=ddpm_cache \
  sampling.steps=1000 \
  loader.eval_batch_size=1 \
  sampling.num_sample_batches=2 \
  sampling.semi_ar=True \
  sampling.stride_length=512 \
  sampling.num_strides=2 \
  backbone=hf_dit
```
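The sampling.stride_length and sampling.num_strides options drive the sliding-window decoding described above. A rough, hypothetical sketch of the loop (illustration only; generate_block stands for one full diffusion pass that fills in the masked slots of a window):

```python
def semi_ar_generate(generate_block, model_length=1024,
                     stride_length=512, num_strides=2, mask_id=0):
    """Hypothetical sketch of semi-autoregressive (SAR) decoding.

    Each stride keeps the last (model_length - stride_length) generated
    tokens as context and diffuses stride_length fresh masked positions
    appended at the end of the window.
    """
    window = generate_block([mask_id] * model_length)   # fully masked start
    output = list(window)
    for _ in range(num_strides):
        context = window[stride_length:]                 # keep the tail
        window = generate_block(context + [mask_id] * stride_length)
        output.extend(window[-stride_length:])           # newly filled tokens
    return output
```

With model.length=1024, stride_length=512, and num_strides=2 as in the command above, this yields 1024 + 2 × 512 = 2048 tokens per sequence.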

Train

To train MDLM from scratch on OpenWebText use the following command:

```bash
python main.py \
  model=small \
  data=openwebtext-split \
  wandb.name=mdlm-owt \
  parameterization=subs \
  model.length=1024 \
  eval.compute_generative_perplexity=True \
  sampling.steps=1000
```

The arguments loader.batch_size and loader.eval_batch_size allow you to control the global batch size and the batch size per GPU. If loader.batch_size * num_gpus is less than the global batch size, PyTorch Lightning will resort to gradient accumulation. You can also launch a training job on Slurm using the command sbatch scripts/train_owt_mdlm.sh. The Slurm scripts to train the autoregressive and SEDD baselines are scripts/train_lm1b_ar.sh and scripts/train_owt_sedd.sh, respectively.
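As a worked example with hypothetical numbers: with a global batch size of 512, 8 GPUs, and 32 sequences processed per GPU per forward pass, gradients are accumulated over 512 / (8 × 32) = 2 micro-batches before each optimizer step.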

Eval

To compute test perplexity, use mode=ppl_eval. Example scripts are provided in scripts/. An example command for perplexity evaluation on OpenWebText:

```bash
python main.py \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small \
  parameterization=subs \
  backbone=dit \
  model.length=1024 \
  eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
  +wandb.offline=true
```
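Note that for MDLM the reported test perplexity comes from the variational bound, i.e. it is the exponential of the per-token NELBO and therefore an upper bound on the true perplexity (the AR baseline's perplexity is exact). A one-line reminder of the relationship, as a hypothetical helper:

```python
import math

def ppl_from_nelbo(total_nelbo_nats: float, num_tokens: int) -> float:
    # Perplexity = exp(average negative (variational) log-likelihood per token).
    return math.exp(total_nelbo_nats / num_tokens)
```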

Baseline evaluation

We release the checkpoints for the baselines, SEDD and AR, trained on OpenWebText in this Google Drive folder. Download the checkpoints ar.ckpt and sedd.ckpt, and use the following commands to compute test perplexity:

AR

```bash
python main.py \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small-ar \
  parameterization=ar \
  backbone=ar \
  model.length=1024 \
  eval.checkpoint_path=/path/to/checkpoint/ar.ckpt \
  +wandb.offline=true
```

SEDD

```bash
python main.py \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small \
  parameterization=sedd \
  backbone=dit \
  model.length=1024 \
  eval.checkpoint_path=/path/to/checkpoint/sedd.ckpt \
  time_conditioning=True \
  sampling.predictor=analytic \
  +wandb.offline=true
```

Acknowledgements

This repository was built off of SEDD.

Citation

```bibtex
@inproceedings{sahoo2024simple,
  title={Simple and Effective Masked Diffusion Language Models},
  author={Subham Sekhar Sahoo and Marianne Arriola and Aaron Gokaslan and Edgar Mariano Marroquin and Alexander M Rush and Yair Schiff and Justin T Chiu and Volodymyr Kuleshov},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=L4uaAR4ArM}
}
```

Owner

  • Name: Kuleshov Group @ Cornell Tech
  • Login: kuleshov-group
  • Kind: organization

Research group at Cornell focused on machine learning, generative models, and AI for science

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "Simple and Effective Masked Diffusion Language Models"
doi: "10.48550/arXiv.2406.07524"
authors:
  - family-names: "Sahoo"
    given-names: "Subham Sekhar"
  - family-names: "Arriola"
    given-names: "Marianne"
  - family-names: "Schiff"
    given-names: "Yair"
  - family-names: "Gokaslan"
    given-names: "Aaron"
  - family-names: "Marroquin"
    given-names: "Edgar"
  - family-names: "Chiu"
    given-names: "Justin T"
  - family-names: "Rush"
    given-names: "Alexander"
  - family-names: "Kuleshov"
    given-names: "Volodymyr"
date-released: 2024-06-11
version: "arXiv:2406.07524v1"

GitHub Events

Total
  • Issues event: 33
  • Watch event: 253
  • Issue comment event: 21
  • Push event: 9
  • Pull request event: 4
  • Fork event: 45
Last Year
  • Issues event: 33
  • Watch event: 253
  • Issue comment event: 21
  • Push event: 9
  • Pull request event: 4
  • Fork event: 45

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 16
  • Total pull requests: 1
  • Average time to close issues: 15 days
  • Average time to close pull requests: 3 days
  • Total issue authors: 16
  • Total pull request authors: 1
  • Average comments per issue: 0.94
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 16
  • Pull requests: 1
  • Average time to close issues: 15 days
  • Average time to close pull requests: 3 days
  • Issue authors: 16
  • Pull request authors: 1
  • Average comments per issue: 0.94
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • albertotono (4)
  • YuxuanSong (1)
  • yuanzhi-zhu (1)
  • alxmrs (1)
  • ClemChou000 (1)
  • ashwinipokle (1)
  • frederick0329 (1)
  • chaofan520 (1)
  • 241416 (1)
  • cuiheling (1)
  • JoJo0217 (1)
  • enkeejunior1 (1)
  • NamburiSrinath (1)
  • yusowa0716 (1)
  • dirkweissenborn (1)
Pull Request Authors
  • kuleshov (3)
  • Borda (1)
Top Labels
Issue Labels
Pull Request Labels