https://github.com/amazon-science/masked-diffusion-lm

Official implementation for the paper "A Cheaper and Better Diffusion Language Model with Soft-Masked Noise"


Science Score: 26.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (7.2%) to scientific vocabulary
Last synced: 9 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: amazon-science
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 8.16 MB
Statistics
  • Stars: 56
  • Watchers: 3
  • Forks: 4
  • Open Issues: 20
  • Releases: 0
Created about 3 years ago · Last pushed almost 3 years ago
Metadata Files
Readme · Contributing · License · Code of conduct

README.md

A Cheaper and Better Diffusion Language Model with Soft-Masked Noise

This is the official implementation of the paper: A Cheaper and Better Diffusion Language Model with Soft-Masked Noise.


Usage

Set up the environment before running the experiments.

Conda Setup:

conda install mpi4py
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install -e improved-diffusion/
pip install -e transformers/
pip install spacy==3.2.4
pip install datasets==1.8.0
pip install huggingface_hub==0.4.0
pip install wandb
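Before training, it can help to confirm that the pinned versions above are the ones actually installed. The script below is a minimal sketch using only the standard library's `importlib.metadata`; `README_PINS`, `installed_version`, and `find_mismatches` are hypothetical names, and the conda-installed packages (mpi4py, PyTorch) are not checked.

```python
"""Sketch: check that the pip pins from the setup step are installed."""
from importlib import metadata
from typing import Callable, Dict, Optional

# Pins copied from the pip commands in the setup step (hypothetical helper data).
README_PINS: Dict[str, str] = {
    "spacy": "3.2.4",
    "datasets": "1.8.0",
    "huggingface_hub": "0.4.0",
}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each pinned package whose installed version differs to what was found."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in pins.items():
        found = get_version(name)
        if found != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(README_PINS)
    for name, found in problems.items():
        print(f"{name}: expected {README_PINS[name]}, found {found}")
    if not problems:
        print("All pinned versions are installed.")
```

The version lookup is injectable so the comparison logic can be exercised without touching the real environment.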


Train Masked-Diffusion-LM:

cd improved-diffusion; mkdir diffusion_models;

python scripts/run_train.py --diff_steps 500 --model_arch bert --lr 0.0003 --lr_anneal_steps 400000 --seed 0 --noise_schedule sqrt --in_channel 128 --modality roc --submit no --padding_mode pad --app "--predict_xstart True --training_mode masked-diffuse-lm --roc_train ../datasets/ROCstory " --bsz 64
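The training command selects `--noise_schedule sqrt` with `--diff_steps 500`. Assuming this codebase inherits Diffusion-LM's sqrt schedule, where the cumulative signal level is alpha_bar(t) = 1 - sqrt(t + 0.0001) for normalized time t in [0, 1], the sketch below shows how such a schedule is discretized into 500 per-step noise variances. It is an illustration under that assumption, not code taken from the repository.

```python
import math
from typing import Callable, List


def sqrt_alpha_bar(t: float) -> float:
    """Assumed sqrt schedule: cumulative signal level at normalized time t."""
    return 1.0 - math.sqrt(t + 0.0001)


def betas_for_alpha_bar(
    num_steps: int,
    alpha_bar: Callable[[float], float],
    max_beta: float = 0.999,
) -> List[float]:
    """Discretize a continuous alpha_bar(t) into per-step variances beta_t."""
    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clipping keeps the final
        # step (where alpha_bar crosses zero) from producing beta >= 1.
        betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


betas = betas_for_alpha_bar(500, sqrt_alpha_bar)  # corresponds to --diff_steps 500

# Cumulative product of (1 - beta_t): how much of the clean signal survives at step t.
alphas_cumprod: List[float] = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)
```

Under this schedule the per-step variance is largest at the very start and the very end, so much of the signal is destroyed early while the middle steps change the input only gently.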


Decode Diffusion-LM:

mkdir generation_outputs

python scripts/batch_decode.py {path-to-diffusion-lm} -1.0 ema
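The trailing `ema` argument appears to select exponential-moving-average (EMA) weights rather than the raw training weights; improved-diffusion, which this repository builds on, maintains such averaged copies during training, commonly with a rate of 0.9999. The meaning of `-1.0` is not documented here. The sketch below only illustrates the EMA update rule itself on plain Python lists; `update_ema` and `Params` are hypothetical names, not the repository's API.

```python
from typing import Dict, List

# Hypothetical parameter container: name -> flat list of values.
Params = Dict[str, List[float]]


def update_ema(ema: Params, current: Params, rate: float = 0.9999) -> None:
    """In place: ema <- rate * ema + (1 - rate) * current, for every parameter."""
    for name, values in current.items():
        shadow = ema[name]
        for i, value in enumerate(values):
            shadow[i] = rate * shadow[i] + (1.0 - rate) * value


# Usage: after each optimizer step, fold the fresh weights into the running average.
ema_params: Params = {"w": [0.0, 1.0]}
update_ema(ema_params, {"w": [1.0, 1.0]}, rate=0.5)
```

A rate close to 1 makes the averaged weights change slowly, which smooths out step-to-step noise in the raw weights and is why EMA checkpoints are often preferred for sampling.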


Controllable Text Generation

First, train the classifier used to guide the generation (e.g., a syntactic parser):

python train_run.py --experiment e2e-tgt-tree --app "--init_emb {path-to-diffusion-lm} --n_embd {16} --learned_emb yes " --pretrained_model bert-base-uncased --epoch 6 --bsz 10
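Both run_train.py and train_run.py take an `--app` option whose value is one quoted string containing further flags. How the scripts consume that string is not documented here; the sketch below only uses Python's standard `shlex` module, which follows POSIX shell quoting, to show how a shell tokenizes the classifier-training command. This helps when editing the nested quotes. The `{...}` placeholders are kept verbatim.

```python
import shlex

# The classifier-training command above, with its placeholders left as-is.
cmd = (
    "python train_run.py --experiment e2e-tgt-tree "
    '--app "--init_emb {path-to-diffusion-lm} --n_embd {16} --learned_emb yes " '
    "--pretrained_model bert-base-uncased --epoch 6 --bsz 10"
)

argv = shlex.split(cmd)  # how the shell splits the command into arguments
app_value = argv[argv.index("--app") + 1]  # the whole quoted string is one argument
inner_flags = shlex.split(app_value)  # the flags carried inside --app
```

If the quotes around the `--app` value are dropped, the shell passes `--init_emb` and the following words as separate top-level arguments instead.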

Then use the trained classifier to guide generation. (Currently, the classifier directory must be updated by hand in scripts/infill.py. I will clean this up in the next release.)

python scripts/infill.py --model_path {path-to-diffusion-lm} --eval_task_ 'control_tree' --use_ddim True --notes "tree_adagrad" --eta 1. --verbose pipe
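Diffusion-LM-style control steers generation by nudging an intermediate latent along the gradient of a classifier's log-probability while penalizing drift from what the diffusion model itself proposed. The sketch below is a toy illustration of that idea only: a hypothetical logistic classifier stands in for the trained parser, and plain gradient ascent stands in for whatever optimizer scripts/infill.py uses (the `tree_adagrad` note suggests Adagrad, which is not confirmed here). `guide_latent` and `fluency_weight` are hypothetical names.

```python
import math
from typing import List

Vector = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def guide_latent(
    x: Vector,
    w: Vector,
    b: float,
    step_size: float = 0.1,
    fluency_weight: float = 0.01,
    num_steps: int = 3,
) -> Vector:
    """Gradient ascent on log p(c=1 | x) - fluency_weight * ||x - x_orig||^2.

    Toy classifier (hypothetical): p(c=1 | x) = sigmoid(w . x + b).
    """
    x_orig = list(x)
    x = list(x)
    for _ in range(num_steps):
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
        # d/dx log sigmoid(w.x + b) = (1 - p) * w; the penalty pulls x back toward x_orig.
        grad = [
            (1.0 - p) * wi - 2.0 * fluency_weight * (xi - x0)
            for wi, xi, x0 in zip(w, x, x_orig)
        ]
        x = [xi + step_size * g for xi, g in zip(x, grad)]
    return x
```

The fluency term plays the role of the diffusion model's own likelihood: without it, the latent could drift to regions the classifier likes but that no longer decode to fluent text.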

Acknowledgement

Part of our code is adapted from Diffusion-LM and Transformers.

License

This project is licensed under the Apache-2.0 License.

Owner

  • Name: Amazon Science
  • Login: amazon-science
  • Kind: organization

GitHub Events

Total
  • Watch event: 8
  • Issue comment event: 1
  • Fork event: 1
Last Year
  • Watch event: 8
  • Issue comment event: 1
  • Fork event: 1

Issues and Pull Requests

Last synced: over 1 year ago

All Time
  • Total issues: 9
  • Total pull requests: 80
  • Average time to close issues: N/A
  • Average time to close pull requests: 15 days
  • Total issue authors: 5
  • Total pull request authors: 1
  • Average comments per issue: 1.11
  • Average comments per pull request: 0.05
  • Merged pull requests: 47
  • Bot issues: 0
  • Bot pull requests: 80
Past Year
  • Issues: 1
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 1
  • Pull request authors: 0
  • Average comments per issue: 0.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • Thebodyoflake (1)
  • baiyuting (1)
  • Orange1999 (1)
  • ai-agi (1)
  • vikyou (1)
Pull Request Authors
  • dependabot[bot] (35)
Top Labels
Issue Labels
Pull Request Labels
dependencies (35)

Dependencies

transformers/docker/transformers-cpu/Dockerfile docker
  • ubuntu 18.04 build
transformers/docker/transformers-gpu/Dockerfile docker
  • nvidia/cuda 10.2-cudnn7-devel-ubuntu18.04 build
transformers/docker/transformers-pytorch-cpu/Dockerfile docker
  • ubuntu 18.04 build
transformers/docker/transformers-pytorch-gpu/Dockerfile docker
  • nvidia/cuda 10.2-cudnn7-devel-ubuntu18.04 build
transformers/docker/transformers-pytorch-tpu/Dockerfile docker
  • google/cloud-sdk slim build
transformers/docker/transformers-tensorflow-cpu/Dockerfile docker
  • ubuntu 18.04 build
transformers/docker/transformers-tensorflow-gpu/Dockerfile docker
  • nvidia/cuda 10.1-cudnn7-runtime-ubuntu18.04 build
transformers/examples/research_projects/quantization-qdqbert/Dockerfile docker
  • nvcr.io/nvidia/pytorch 21.07-py3 build
improved-diffusion/setup.py pypi
  • blobfile *
transformers/examples/flax/_tests_requirements.txt pypi
  • conllu * test
  • datasets >=1.1.3 test
  • nltk * test
  • pytest * test
  • rouge-score * test
  • seqeval * test
  • tensorboard * test
transformers/examples/flax/language-modeling/requirements.txt pypi
  • datasets >=1.1.3
  • flax >=0.3.5
  • jax >=0.2.8
  • jaxlib >=0.1.59
  • optax >=0.0.9
transformers/examples/flax/question-answering/requirements.txt pypi
  • datasets >=1.8.0
  • flax >=0.3.5
  • jax >=0.2.17
  • jaxlib >=0.1.68
  • optax >=0.0.8
transformers/examples/flax/summarization/requirements.txt pypi
  • datasets >=1.1.3
  • flax >=0.3.5
  • jax >=0.2.8
  • jaxlib >=0.1.59
  • optax >=0.0.8
transformers/examples/flax/text-classification/requirements.txt pypi
  • datasets >=1.1.3
  • flax >=0.3.5
  • jax >=0.2.8
  • jaxlib >=0.1.59
  • optax >=0.0.8
transformers/examples/flax/token-classification/requirements.txt pypi
  • datasets >=1.8.0
  • flax >=0.3.5
  • jax >=0.2.8
  • jaxlib >=0.1.59
  • optax >=0.0.8
  • seqeval *
transformers/examples/flax/vision/requirements.txt pypi
  • flax >=0.3.5
  • jax >=0.2.8
  • jaxlib >=0.1.59
  • optax >=0.0.8
  • torch ==1.13.1
  • torchvision ==0.10.0
transformers/examples/legacy/pytorch-lightning/requirements.txt pypi
  • conllu *
  • datasets >=1.1.3
  • elasticsearch *
  • faiss-cpu *
  • fire *
  • git-python ==1.0.3
  • matplotlib *
  • nltk *
  • pandas *
  • protobuf *
  • psutil *
  • pytest *
  • ray *
  • rouge-score *
  • sacrebleu *
  • scikit-learn *
  • sentencepiece *
  • seqeval *
  • streamlit *
  • tensorboard *
  • tensorflow_datasets *
transformers/examples/legacy/seq2seq/requirements.txt pypi
  • conllu *
  • datasets >=1.1.3
  • elasticsearch *
  • faiss-cpu *
  • fire *
  • git-python ==1.0.3
  • matplotlib *
  • nltk *
  • pandas *
  • protobuf *
  • psutil *
  • pytest *
  • rouge-score *
  • sacrebleu *
  • scikit-learn *
  • sentencepiece *
  • seqeval *
  • streamlit *
  • tensorboard *
  • tensorflow_datasets *
transformers/examples/pytorch/_tests_requirements.txt pypi
  • accelerate >=0.5.0 test
  • conllu * test
  • datasets >=1.13.3 test
  • elasticsearch * test
  • faiss-cpu * test
  • fire * test
  • git-python ==1.0.3 test
  • jiwer * test
  • librosa * test
  • matplotlib * test
  • nltk * test
  • pandas * test
  • protobuf * test
  • psutil * test
  • pytest * test
  • rouge-score * test
  • sacrebleu >=1.4.12 test
  • scikit-learn * test
  • sentencepiece * test
  • seqeval * test
  • streamlit * test
  • tensorboard * test
  • tensorflow_datasets * test
  • torchvision * test
transformers/examples/pytorch/audio-classification/requirements.txt pypi
  • datasets >=1.14.0
  • librosa *
  • torch >=1.6
  • torchaudio *
transformers/examples/pytorch/benchmarking/requirements.txt pypi
  • torch >=1.3
transformers/examples/pytorch/image-classification/requirements.txt pypi
  • datasets >=1.8.0
  • torch >=1.5.0
  • torchvision >=0.6.0
transformers/examples/pytorch/image-pretraining/requirements.txt pypi
  • datasets >=1.8.0
  • torch >=1.5.0
  • torchvision >=0.6.0
transformers/examples/pytorch/language-modeling/requirements.txt pypi
  • accelerate *
  • datasets >=1.8.0
  • protobuf *
  • sentencepiece *
  • torch >=1.3
transformers/examples/pytorch/multiple-choice/requirements.txt pypi
  • accelerate *
  • protobuf *
  • sentencepiece *
  • torch >=1.3
transformers/examples/pytorch/question-answering/requirements.txt pypi
  • accelerate *
  • datasets >=1.8.0
  • torch >=1.3.0
transformers/examples/pytorch/speech-pretraining/requirements.txt pypi
  • accelerate >=0.5.0
  • datasets >=1.12.0
  • librosa *
  • torch >=1.5
  • torchaudio *
transformers/examples/pytorch/speech-recognition/requirements.txt pypi
  • datasets >=1.13.3
  • jiwer *
  • librosa *
  • torch >=1.5
  • torchaudio *
transformers/examples/pytorch/summarization/requirements.txt pypi
  • accelerate *
  • datasets >=1.8.0
  • nltk *
  • protobuf *
  • py7zr *
  • rouge-score *
  • sentencepiece *
  • torch >=1.3
transformers/examples/pytorch/text-classification/requirements.txt pypi
  • accelerate *
  • datasets >=1.8.0
  • protobuf *
  • scikit-learn *
  • scipy *
  • sentencepiece *
  • torch >=1.3
transformers/examples/pytorch/text-generation/requirements.txt pypi
  • protobuf *
  • sentencepiece *
  • torch >=1.3
transformers/examples/pytorch/token-classification/requirements.txt pypi
  • accelerate *
  • datasets >=1.8.0
  • seqeval *
  • torch >=1.3
transformers/examples/pytorch/translation/requirements.txt pypi
  • accelerate *
  • datasets >=1.8.0
  • protobuf *
  • py7zr *
  • sacrebleu >=1.4.12
  • sentencepiece *
  • torch >=1.3
transformers/examples/research_projects/adversarial/requirements.txt pypi
  • transformers ==3.5.1
transformers/examples/research_projects/bert-loses-patience/requirements.txt pypi
  • transformers ==3.5.1
transformers/examples/research_projects/bertabs/requirements.txt pypi
  • nltk *
  • py-rouge *
  • transformers ==3.5.1
transformers/examples/research_projects/bertology/requirements.txt pypi
  • transformers ==3.5.1
transformers/examples/research_projects/codeparrot/requirements.txt pypi
  • accelerate ==0.5.1
  • datasets ==1.16.0
  • huggingface-hub ==0.1.0
  • tensorboard ==2.6.0
  • torch ==1.13.1
  • transformers ==4.15.0
  • wandb ==0.12.0
transformers/examples/research_projects/deebert/requirements.txt pypi
  • transformers ==3.5.1
transformers/examples/research_projects/distillation/requirements.txt pypi
  • gitpython ==3.1.30
  • psutil ==5.6.6
  • scipy >=1.4.1
  • tensorboard >=1.14.0
  • tensorboardX ==1.8
  • transformers *
transformers/examples/research_projects/fsner/requirements.txt pypi
  • transformers >=4.9.2
transformers/examples/research_projects/fsner/setup.py pypi
  • torch >=1.9.0
transformers/examples/research_projects/jax-projects/big_bird/requirements.txt pypi
  • datasets *
  • flax *
  • jsonlines *
  • sentencepiece *
  • wandb *
transformers/examples/research_projects/jax-projects/hybrid_clip/requirements.txt pypi
  • flax >=0.3.5
  • jax >=0.2.8
  • jaxlib >=0.1.59
  • optax >=0.0.8
  • torch ==1.13.1
  • torchvision ==0.10.0
transformers/examples/research_projects/longform-qa/requirements.txt pypi
  • datasets >=1.1.3
  • elasticsearch *
  • faiss-cpu *
  • streamlit *
transformers/examples/research_projects/lxmert/requirements.txt pypi
  • CacheControl ==0.12.6
  • Jinja2 >=2.11.3
  • MarkupSafe ==1.1.1
  • Pillow >=8.1.1
  • PyYAML >=5.4
  • Pygments >=2.7.4
  • QtPy ==1.9.0
  • Send2Trash ==1.5.0
  • appdirs ==1.4.3
  • argon2-cffi ==20.1.0
  • async-generator ==1.10
  • attrs ==20.2.0
  • backcall ==0.2.0
  • certifi ==2022.12.7
  • cffi ==1.14.2
  • chardet ==3.0.4
  • click ==7.1.2
  • colorama ==0.4.3
  • contextlib2 ==0.6.0
  • cycler ==0.10.0
  • datasets ==1.0.0
  • decorator ==4.4.2
  • defusedxml ==0.6.0
  • dill ==0.3.2
  • distlib ==0.3.0
  • distro ==1.4.0
  • entrypoints ==0.3
  • filelock ==3.0.12
  • future ==0.18.3
  • html5lib ==1.0.1
  • idna ==2.8
  • ipaddr ==2.2.0
  • ipykernel ==5.3.4
  • ipython *
  • ipython-genutils ==0.2.0
  • ipywidgets ==7.5.1
  • jedi ==0.17.2
  • joblib ==1.1.1
  • jsonschema ==3.2.0
  • jupyter ==1.0.0
  • jupyter-client ==6.1.7
  • jupyter-console ==6.2.0
  • jupyter-core ==4.11.2
  • jupyterlab-pygments ==0.1.1
  • kiwisolver ==1.2.0
  • lockfile ==0.12.2
  • matplotlib ==3.3.1
  • mistune ==0.8.4
  • msgpack ==0.6.2
  • nbclient ==0.5.0
  • nbconvert ==6.5.1
  • nbformat ==5.0.7
  • nest-asyncio ==1.4.0
  • notebook ==6.4.12
  • numpy ==1.22.0
  • opencv-python ==4.4.0.42
  • packaging ==20.3
  • pandas ==1.1.2
  • pandocfilters ==1.4.2
  • parso ==0.7.1
  • pep517 ==0.8.2
  • pexpect ==4.8.0
  • pickleshare ==0.7.5
  • progress ==1.5
  • prometheus-client ==0.8.0
  • prompt-toolkit ==3.0.7
  • ptyprocess ==0.6.0
  • pyaml ==20.4.0
  • pyarrow ==1.0.1
  • pycparser ==2.20
  • pyparsing ==2.4.6
  • pyrsistent ==0.16.0
  • python-dateutil ==2.8.1
  • pytoml ==0.1.21
  • pytz ==2020.1
  • pyzmq ==19.0.2
  • qtconsole ==4.7.7
  • regex ==2020.7.14
  • requests ==2.22.0
  • retrying ==1.3.3
  • sacremoses ==0.0.43
  • sentencepiece ==0.1.91
  • six ==1.14.0
  • terminado ==0.8.3
  • testpath ==0.4.4
  • tokenizers ==0.8.1rc2
  • torch ==1.13.1
  • torchvision ==0.7.0
  • tornado ==6.0.4
  • tqdm ==4.48.2
  • traitlets *
  • urllib3 ==1.26.5
  • wcwidth ==0.2.5
  • webencodings ==0.5.1
  • wget ==3.2
  • widgetsnbextension ==3.5.1
  • xxhash ==2.0.0
transformers/examples/research_projects/mlm_wwm/requirements.txt pypi
  • datasets >=1.1.3
  • ltp *
  • protobuf *
  • sentencepiece *
transformers/examples/research_projects/movement-pruning/requirements.txt pypi
  • h5py >=2.10.0
  • knockknock >=0.1.8.1
  • numpy >=1.18.2
  • scipy >=1.4.1
  • torch >=1.4.0
transformers/examples/research_projects/onnx/summarization/requirements.txt pypi
  • torch >=1.10
transformers/examples/research_projects/pplm/requirements.txt pypi
  • conllu *
  • datasets >=1.1.3
  • elasticsearch *
  • faiss-cpu *
  • fire *
  • git-python ==1.0.3
  • matplotlib *
  • nltk *
  • pandas *
  • protobuf *
  • psutil *
  • pytest *
  • pytorch-lightning ==1.6.0
  • rouge-score *
  • sacrebleu *
  • scikit-learn *
  • sentencepiece *
  • seqeval *
  • streamlit *
  • tensorboard *
  • tensorflow_datasets *
  • transformers ==3.5.1
transformers/examples/research_projects/rag/requirements.txt pypi
  • GitPython *
  • datasets >=1.0.1
  • faiss-cpu >=1.6.3
  • psutil >=5.7.0
  • pytorch-lightning ==1.6.0
  • torch >=1.4.0
  • transformers *
transformers/examples/research_projects/rag-end2end-retriever/requirements.txt pypi
  • datasets >=1.6.2
  • faiss-cpu >=1.7.0
  • nvidia-ml-py3 ==7.352.0
  • psutil >=5.7.0
  • pytorch-lightning ==1.6.0
  • ray >=1.3.0
  • torch >=1.4.0
transformers/examples/research_projects/seq2seq-distillation/requirements.txt pypi
  • conllu *
  • datasets >=1.1.3
  • elasticsearch *
  • faiss-cpu *
  • fire *
  • git-python ==1.0.3
  • matplotlib *
  • nltk *
  • pandas *
  • protobuf *
  • psutil *
  • pytest *
  • pytorch-lightning ==1.6.0
  • rouge-score *
  • sacrebleu *
  • scikit-learn *
  • sentencepiece *
  • streamlit *
  • tensorboard *
  • tensorflow_datasets *
transformers/examples/research_projects/visual_bert/requirements.txt pypi
  • CacheControl ==0.12.6
  • Jinja2 >=2.11.3
  • MarkupSafe ==1.1.1
  • Pillow >=8.1.1
  • PyYAML >=5.4
  • Pygments >=2.7.4
  • QtPy ==1.9.0
  • Send2Trash ==1.5.0
  • appdirs ==1.4.3
  • argon2-cffi ==20.1.0
  • async-generator ==1.10
  • attrs ==20.2.0
  • backcall ==0.2.0
  • certifi ==2022.12.7
  • cffi ==1.14.2
  • chardet ==3.0.4
  • click ==7.1.2
  • colorama ==0.4.3
  • contextlib2 ==0.6.0
  • cycler ==0.10.0
  • datasets ==1.0.0
  • decorator ==4.4.2
  • defusedxml ==0.6.0
  • dill ==0.3.2
  • distlib ==0.3.0
  • distro ==1.4.0
  • entrypoints ==0.3
  • filelock ==3.0.12
  • future ==0.18.3
  • html5lib ==1.0.1
  • idna ==2.8
  • ipaddr ==2.2.0
  • ipykernel ==5.3.4
  • ipython *
  • ipython-genutils ==0.2.0
  • ipywidgets ==7.5.1
  • jedi ==0.17.2
  • joblib ==1.2.0
  • jsonschema ==3.2.0
  • jupyter ==1.0.0
  • jupyter-client ==6.1.7
  • jupyter-console ==6.2.0
  • jupyter-core ==4.11.2
  • jupyterlab-pygments ==0.1.1
  • kiwisolver ==1.2.0
  • lockfile ==0.12.2
  • matplotlib ==3.3.1
  • mistune ==0.8.4
  • msgpack ==0.6.2
  • nbclient ==0.5.0
  • nbconvert ==6.5.1
  • nbformat ==5.0.7
  • nest-asyncio ==1.4.0
  • notebook ==6.4.12
  • numpy ==1.22.0
  • opencv-python ==4.4.0.42
  • packaging ==20.3
  • pandas ==1.1.2
  • pandocfilters ==1.4.2
  • parso ==0.7.1
  • pep517 ==0.8.2
  • pexpect ==4.8.0
  • pickleshare ==0.7.5
  • progress ==1.5
  • prometheus-client ==0.8.0
  • prompt-toolkit ==3.0.7
  • ptyprocess ==0.6.0
  • pyaml ==20.4.0
  • pyarrow ==1.0.1
  • pycparser ==2.20
  • pyparsing ==2.4.6
  • pyrsistent ==0.16.0
  • python-dateutil ==2.8.1
  • pytoml ==0.1.21
  • pytz ==2020.1
  • pyzmq ==19.0.2
  • qtconsole ==4.7.7
  • regex ==2020.7.14
  • requests ==2.22.0
  • retrying ==1.3.3
  • sacremoses ==0.0.43
  • sentencepiece ==0.1.91
  • six ==1.14.0
  • terminado ==0.8.3
  • testpath ==0.4.4
  • tokenizers ==0.8.1rc2
  • torch ==1.13.1
  • torchvision ==0.7.0
  • tornado ==6.0.4
  • tqdm ==4.48.2
  • traitlets *
  • urllib3 ==1.26.5
  • wcwidth ==0.2.5
  • webencodings ==0.5.1
  • wget ==3.2
  • widgetsnbextension ==3.5.1
  • xxhash ==2.0.0
transformers/examples/research_projects/wav2vec2/requirements.txt pypi
  • datasets *
  • jiwer ==2.2.0
  • lang-trans ==0.6.0
  • librosa ==0.8.0
  • torch >=1.5.0
  • torchaudio *
  • transformers *
transformers/examples/tensorflow/benchmarking/requirements.txt pypi
  • tensorflow >=2.3
transformers/examples/tensorflow/language-modeling/requirements.txt pypi
  • datasets >=1.8.0
  • sentencepiece *
transformers/examples/tensorflow/multiple-choice/requirements.txt pypi
  • protobuf *
  • sentencepiece *
  • tensorflow >=2.3
transformers/examples/tensorflow/question-answering/requirements.txt pypi
  • datasets >=1.4.0
  • tensorflow >=2.3.0
transformers/examples/tensorflow/text-classification/requirements.txt pypi
  • datasets >=1.1.3
  • protobuf *
  • sentencepiece *
  • tensorflow >=2.3
transformers/setup.py pypi
  • deps *
transformers/tests/sagemaker/scripts/pytorch/requirements.txt pypi
  • datasets ==1.8.0 test
transformers/examples/research_projects/fsner/pyproject.toml pypi
transformers/pyproject.toml pypi
transformers/tests/sagemaker/scripts/tensorflow/requirements.txt pypi
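The README pins datasets==1.8.0, while some bundled example requirements ask for newer releases (for example, transformers/examples/pytorch/_tests_requirements.txt lists datasets >=1.13.3). The helper below is a minimal sketch for checking a pinned version against the simple constraint forms used in this list. It handles only plain dotted versions and the `*`, `==`, and `>=` forms; for anything more, the `packaging` library's specifier support is the safer choice. `parse_version` and `satisfies` are hypothetical names.

```python
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Parse a plain dotted release such as '1.13.3' (no rc/dev suffixes)."""
    return tuple(int(part) for part in v.split("."))


def satisfies(installed: str, constraint: str) -> bool:
    """Check a version against one constraint from the list: '*', '==X', or '>=X'."""
    constraint = constraint.replace(" ", "")
    if constraint == "*":
        return True
    if constraint.startswith("=="):
        # Note: '1.10' and '1.10.0' compare unequal here; packaging normalizes them.
        return parse_version(installed) == parse_version(constraint[2:])
    if constraint.startswith(">="):
        return parse_version(installed) >= parse_version(constraint[2:])
    raise ValueError(f"unsupported constraint: {constraint!r}")
```

Comparing integer tuples rather than strings matters here: as strings, "1.8.0" would sort after "1.13.3".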