https://github.com/amazon-science/masked-diffusion-lm
Official implementation for the paper "A Cheaper and Better Diffusion Language Model with Soft-Masked Noise"
Science Score: 26.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found)
- ✓ .zenodo.json file (found)
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low (7.2%)
Repository
Statistics
- Stars: 56
- Watchers: 3
- Forks: 4
- Open Issues: 20
- Releases: 0
Metadata Files
README.md
A Cheaper and Better Diffusion Language Model with Soft-Masked Noise
This is the official implementation of the paper: A Cheaper and Better Diffusion Language Model with Soft-Masked Noise.
Usage
One needs to set up the environment before running the experiments.
Conda Setup:
conda install mpi4py
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install -e improved-diffusion/
pip install -e transformers/
pip install spacy==3.2.4
pip install datasets==1.8.0
pip install huggingface_hub==0.4.0
pip install wandb
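Because several of these packages are pinned to old releases, it can help to confirm that the environment actually picked up the pins before training. The helper below is a hypothetical sketch (not part of the repository) that uses the standard-library `importlib.metadata` module (Python 3.8+) to report any pinned package whose installed version differs. It only covers the three `==` pins above and assumes exact string equality is the right comparison.

```python
from importlib import metadata

# Versions pinned by the pip commands above (hypothetical helper; extend as needed).
PINS = {"spacy": "3.2.4", "datasets": "1.8.0", "huggingface_hub": "0.4.0"}


def installed_version(name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each unsatisfied pin to the version actually found (None = not installed)."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    for name, found in find_mismatches(PINS).items():
        print(f"{name}: expected {PINS[name]}, found {found}")
```

The `lookup` parameter exists so the check can be exercised without touching the real environment, for example by passing a plain `dict.get`.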
Train Masked-Diffusion-LM:
cd improved-diffusion; mkdir diffusion_models;
python scripts/run_train.py --diff_steps 500 --model_arch bert --lr 0.0003 --lr_anneal_steps 400000 --seed 0 --noise_schedule sqrt --in_channel 128 --modality roc --submit no --padding_mode pad --app "--predict_xstart True --training_mode masked-diffuse-lm --roc_train ../datasets/ROCstory " --bsz 64
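The training command selects `--noise_schedule sqrt` with `--diff_steps 500`. Assuming this repository keeps the sqrt schedule from upstream Diffusion-LM, where the cumulative signal level is alpha_bar(t) = 1 - sqrt(t + s) with s = 1e-4 and is discretized into per-step betas the way improved-diffusion does, the schedule can be sketched as below. This is an illustration of the Gaussian noise schedule only; it does not model the soft-masking described in the paper, and the function names are hypothetical.

```python
import math


def sqrt_alpha_bar(t, s=1e-4):
    """Cumulative signal level alpha_bar(t) for t in [0, 1] under the sqrt schedule."""
    return 1.0 - math.sqrt(t + s)


def betas_from_alpha_bar(num_steps, alpha_bar, max_beta=0.999):
    """Per-step betas: beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta."""
    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        a1 = alpha_bar(t1)
        if a1 <= 0.0:
            # Signal already exhausted; treat the step as (almost) pure noise.
            betas.append(max_beta)
            continue
        betas.append(min(1.0 - alpha_bar(t2) / a1, max_beta))
    return betas


if __name__ == "__main__":
    betas = betas_from_alpha_bar(500, sqrt_alpha_bar)  # matches --diff_steps 500
    print(f"first beta {betas[0]:.4f}, last beta {betas[-1]:.4f}")
```

Near t = 1 the sqrt curve drops below zero, so the final steps are clipped to `max_beta`; the clipping is what keeps every beta in a valid range.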
Decode Diffusion-LM:
mkdir generation_outputs
python scripts/batch_decode.py {path-to-diffusion-lm} -1.0 ema
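Before decoding with the `ema` option, it can be useful to check which EMA checkpoint a model directory actually contains. The sketch below assumes checkpoints follow the upstream improved-diffusion naming, `ema_<rate>_<step>.pt` next to `model<step>.pt`; how `batch_decode.py` itself consumes the `ema` argument is not documented here, so both helpers are hypothetical conveniences rather than part of the repository.

```python
import re
from pathlib import Path

# Assumption: improved-diffusion style names, e.g. "ema_0.9999_200000.pt".
EMA_NAME = re.compile(r"^ema_(?P<rate>[0-9.]+)_(?P<step>\d+)\.pt$")


def pick_latest_ema(names):
    """Return the EMA checkpoint name with the highest training step, or None."""
    best_step, best_name = -1, None
    for name in names:
        match = EMA_NAME.match(name)
        if match and int(match["step"]) > best_step:
            best_step, best_name = int(match["step"]), name
    return best_name


def latest_ema_in_dir(model_dir):
    """Apply pick_latest_ema to the regular files inside a model directory."""
    return pick_latest_ema(p.name for p in Path(model_dir).iterdir() if p.is_file())
```

Separating name matching from directory listing keeps the selection logic easy to check without creating files.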
Controllable Text Generation
First, train the classifier used to guide generation (e.g., a syntactic parser):
python train_run.py --experiment e2e-tgt-tree --app "--init_emb {path-to-diffusion-lm} --n_embd {16} --learned_emb yes " --pretrained_model bert-base-uncased --epoch 6 --bsz 10
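Both the training and the classifier commands pass a group of inner options through a single quoted `--app` string, which is easy to break when editing by hand. The hypothetical helper below builds the command as an argv list so the inner options always travel as one argument, and uses `shlex.join` (Python 3.8+) to print a copy-pasteable shell line. The path is a placeholder, and writing the README's `{16}` as `16` assumes the braces only mark a value to fill in.

```python
import shlex


def build_command(script, flags, app_flags):
    """Return an argv list; app_flags are packed into ONE string passed after --app."""
    argv = ["python", script]
    for key, value in flags.items():
        argv += [f"--{key}", str(value)]
    # Mirrors the README's quoting: --app "--init_emb ... --learned_emb yes "
    argv += ["--app", " ".join(f"--{key} {value}" for key, value in app_flags.items())]
    return argv


if __name__ == "__main__":
    argv = build_command(
        "train_run.py",
        {"experiment": "e2e-tgt-tree", "pretrained_model": "bert-base-uncased", "epoch": 6, "bsz": 10},
        # "path/to/diffusion-lm" is a placeholder for {path-to-diffusion-lm}
        {"init_emb": "path/to/diffusion-lm", "n_embd": 16, "learned_emb": "yes"},
    )
    print(shlex.join(argv))  # shell-safe rendering of the same command
```

Passing the list to `subprocess.run(argv, check=True)` would run the command without any shell quoting at all.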
Then, use the trained classifier to guide generation. (Currently, the classifier directory must be updated manually in scripts/infill.py; I will clean this up in the next release.)
python scripts/infill.py --model_path {path-to-diffusion-lm} --eval_task_ 'control_tree' --use_ddim True --notes "tree_adagrad" --eta 1. --verbose pipe
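In Diffusion-LM-style control, the classifier steers generation by taking a few gradient steps on the continuous latent that raise the classifier's log-probability of the target control, balanced against a term that keeps the latent fluent. The sketch below shows that idea with Adagrad on a toy logistic classifier. The classifier, the L2 stand-in for the fluency term, and all names are hypothetical; none of it is taken from scripts/infill.py.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_prob_target(x, w, b):
    """log p(c=1 | x) under a toy logistic classifier (stand-in for the trained parser)."""
    return math.log(sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b))


def guide_latent(x0, w, b, steps=50, lr=0.1, lam=0.1, eps=1e-8):
    """Adagrad ascent on log p(c=1|x) - lam/2 * ||x - x0||^2.

    The L2 term is a hypothetical stand-in for the fluency term that keeps
    the guided latent close to what the diffusion model proposed (x0).
    """
    x = list(x0)
    accum = [0.0] * len(x)
    for _ in range(steps):
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
        for i in range(len(x)):
            # Gradient of log sigmoid(z) is (1 - p) * w_i; the L2 term pulls toward x0.
            g = (1.0 - p) * w[i] - lam * (x[i] - x0[i])
            accum[i] += g * g
            x[i] += lr * g / (math.sqrt(accum[i]) + eps)
    return x


if __name__ == "__main__":
    x0, w, b = [0.0, 0.0], [1.0, -1.0], 0.0
    x = guide_latent(x0, w, b)
    print(f"log p(c|x): {log_prob_target(x0, w, b):.3f} -> {log_prob_target(x, w, b):.3f}")
```

Adagrad's per-coordinate step sizes shrink as gradients accumulate, which keeps each guidance update small relative to the diffusion step it modifies.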
Acknowledgement
Parts of our code are adapted from Diffusion-LM and Transformers.
License
This project is licensed under the Apache-2.0 License.
Owner
- Name: Amazon Science
- Login: amazon-science
- Kind: organization
- Website: https://amazon.science
- Twitter: AmazonScience
- Repositories: 80
- Profile: https://github.com/amazon-science
GitHub Events
Total
- Watch event: 8
- Issue comment event: 1
- Fork event: 1
Last Year
- Watch event: 8
- Issue comment event: 1
- Fork event: 1
Issues and Pull Requests
Last synced: over 1 year ago
All Time
- Total issues: 9
- Total pull requests: 80
- Average time to close issues: N/A
- Average time to close pull requests: 15 days
- Total issue authors: 5
- Total pull request authors: 1
- Average comments per issue: 1.11
- Average comments per pull request: 0.05
- Merged pull requests: 47
- Bot issues: 0
- Bot pull requests: 80
Past Year
- Issues: 1
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 1
- Pull request authors: 0
- Average comments per issue: 0.0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- Thebodyoflake (1)
- baiyuting (1)
- Orange1999 (1)
- ai-agi (1)
- vikyou (1)
Pull Request Authors
- dependabot[bot] (35)
Dependencies
- ubuntu 18.04 build
- nvidia/cuda 10.2-cudnn7-devel-ubuntu18.04 build
- ubuntu 18.04 build
- nvidia/cuda 10.2-cudnn7-devel-ubuntu18.04 build
- google/cloud-sdk slim build
- ubuntu 18.04 build
- nvidia/cuda 10.1-cudnn7-runtime-ubuntu18.04 build
- nvcr.io/nvidia/pytorch 21.07-py3 build
- blobfile *
- conllu * test
- datasets >=1.1.3 test
- nltk * test
- pytest * test
- rouge-score * test
- seqeval * test
- tensorboard * test
- datasets >=1.1.3
- flax >=0.3.5
- jax >=0.2.8
- jaxlib >=0.1.59
- optax >=0.0.9
- datasets >=1.8.0
- flax >=0.3.5
- jax >=0.2.17
- jaxlib >=0.1.68
- optax >=0.0.8
- datasets >=1.1.3
- flax >=0.3.5
- jax >=0.2.8
- jaxlib >=0.1.59
- optax >=0.0.8
- datasets >=1.1.3
- flax >=0.3.5
- jax >=0.2.8
- jaxlib >=0.1.59
- optax >=0.0.8
- datasets >=1.8.0
- flax >=0.3.5
- jax >=0.2.8
- jaxlib >=0.1.59
- optax >=0.0.8
- seqeval *
- flax >=0.3.5
- jax >=0.2.8
- jaxlib >=0.1.59
- optax >=0.0.8
- torch ==1.13.1
- torchvision ==0.10.0
- conllu *
- datasets >=1.1.3
- elasticsearch *
- faiss-cpu *
- fire *
- git-python ==1.0.3
- matplotlib *
- nltk *
- pandas *
- protobuf *
- psutil *
- pytest *
- ray *
- rouge-score *
- sacrebleu *
- scikit-learn *
- sentencepiece *
- seqeval *
- streamlit *
- tensorboard *
- tensorflow_datasets *
- conllu *
- datasets >=1.1.3
- elasticsearch *
- faiss-cpu *
- fire *
- git-python ==1.0.3
- matplotlib *
- nltk *
- pandas *
- protobuf *
- psutil *
- pytest *
- rouge-score *
- sacrebleu *
- scikit-learn *
- sentencepiece *
- seqeval *
- streamlit *
- tensorboard *
- tensorflow_datasets *
- accelerate >=0.5.0 test
- conllu * test
- datasets >=1.13.3 test
- elasticsearch * test
- faiss-cpu * test
- fire * test
- git-python ==1.0.3 test
- jiwer * test
- librosa * test
- matplotlib * test
- nltk * test
- pandas * test
- protobuf * test
- psutil * test
- pytest * test
- rouge-score * test
- sacrebleu >=1.4.12 test
- scikit-learn * test
- sentencepiece * test
- seqeval * test
- streamlit * test
- tensorboard * test
- tensorflow_datasets * test
- torchvision * test
- datasets >=1.14.0
- librosa *
- torch >=1.6
- torchaudio *
- torch >=1.3
- datasets >=1.8.0
- torch >=1.5.0
- torchvision >=0.6.0
- datasets >=1.8.0
- torch >=1.5.0
- torchvision >=0.6.0
- accelerate *
- datasets >=1.8.0
- protobuf *
- sentencepiece *
- torch >=1.3
- accelerate *
- protobuf *
- sentencepiece *
- torch >=1.3
- accelerate *
- datasets >=1.8.0
- torch >=1.3.0
- accelerate >=0.5.0
- datasets >=1.12.0
- librosa *
- torch >=1.5
- torchaudio *
- datasets >=1.13.3
- jiwer *
- librosa *
- torch >=1.5
- torchaudio *
- accelerate *
- datasets >=1.8.0
- nltk *
- protobuf *
- py7zr *
- rouge-score *
- sentencepiece *
- torch >=1.3
- accelerate *
- datasets >=1.8.0
- protobuf *
- scikit-learn *
- scipy *
- sentencepiece *
- torch >=1.3
- protobuf *
- sentencepiece *
- torch >=1.3
- accelerate *
- datasets >=1.8.0
- seqeval *
- torch >=1.3
- accelerate *
- datasets >=1.8.0
- protobuf *
- py7zr *
- sacrebleu >=1.4.12
- sentencepiece *
- torch >=1.3
- transformers ==3.5.1
- transformers ==3.5.1
- nltk *
- py-rouge *
- transformers ==3.5.1
- transformers ==3.5.1
- accelerate ==0.5.1
- datasets ==1.16.0
- huggingface-hub ==0.1.0
- tensorboard ==2.6.0
- torch ==1.13.1
- transformers ==4.15.0
- wandb ==0.12.0
- transformers ==3.5.1
- gitpython ==3.1.30
- psutil ==5.6.6
- scipy >=1.4.1
- tensorboard >=1.14.0
- tensorboardX ==1.8
- transformers *
- transformers >=4.9.2
- torch >=1.9.0
- datasets *
- flax *
- jsonlines *
- sentencepiece *
- wandb *
- flax >=0.3.5
- jax >=0.2.8
- jaxlib >=0.1.59
- optax >=0.0.8
- torch ==1.13.1
- torchvision ==0.10.0
- datasets >=1.1.3
- elasticsearch *
- faiss-cpu *
- streamlit *
- CacheControl ==0.12.6
- Jinja2 >=2.11.3
- MarkupSafe ==1.1.1
- Pillow >=8.1.1
- PyYAML >=5.4
- Pygments >=2.7.4
- QtPy ==1.9.0
- Send2Trash ==1.5.0
- appdirs ==1.4.3
- argon2-cffi ==20.1.0
- async-generator ==1.10
- attrs ==20.2.0
- backcall ==0.2.0
- certifi ==2022.12.7
- cffi ==1.14.2
- chardet ==3.0.4
- click ==7.1.2
- colorama ==0.4.3
- contextlib2 ==0.6.0
- cycler ==0.10.0
- datasets ==1.0.0
- decorator ==4.4.2
- defusedxml ==0.6.0
- dill ==0.3.2
- distlib ==0.3.0
- distro ==1.4.0
- entrypoints ==0.3
- filelock ==3.0.12
- future ==0.18.3
- html5lib ==1.0.1
- idna ==2.8
- ipaddr ==2.2.0
- ipykernel ==5.3.4
- ipython *
- ipython-genutils ==0.2.0
- ipywidgets ==7.5.1
- jedi ==0.17.2
- joblib ==1.1.1
- jsonschema ==3.2.0
- jupyter ==1.0.0
- jupyter-client ==6.1.7
- jupyter-console ==6.2.0
- jupyter-core ==4.11.2
- jupyterlab-pygments ==0.1.1
- kiwisolver ==1.2.0
- lockfile ==0.12.2
- matplotlib ==3.3.1
- mistune ==0.8.4
- msgpack ==0.6.2
- nbclient ==0.5.0
- nbconvert ==6.5.1
- nbformat ==5.0.7
- nest-asyncio ==1.4.0
- notebook ==6.4.12
- numpy ==1.22.0
- opencv-python ==4.4.0.42
- packaging ==20.3
- pandas ==1.1.2
- pandocfilters ==1.4.2
- parso ==0.7.1
- pep517 ==0.8.2
- pexpect ==4.8.0
- pickleshare ==0.7.5
- progress ==1.5
- prometheus-client ==0.8.0
- prompt-toolkit ==3.0.7
- ptyprocess ==0.6.0
- pyaml ==20.4.0
- pyarrow ==1.0.1
- pycparser ==2.20
- pyparsing ==2.4.6
- pyrsistent ==0.16.0
- python-dateutil ==2.8.1
- pytoml ==0.1.21
- pytz ==2020.1
- pyzmq ==19.0.2
- qtconsole ==4.7.7
- regex ==2020.7.14
- requests ==2.22.0
- retrying ==1.3.3
- sacremoses ==0.0.43
- sentencepiece ==0.1.91
- six ==1.14.0
- terminado ==0.8.3
- testpath ==0.4.4
- tokenizers ==0.8.1rc2
- torch ==1.13.1
- torchvision ==0.7.0
- tornado ==6.0.4
- tqdm ==4.48.2
- traitlets *
- urllib3 ==1.26.5
- wcwidth ==0.2.5
- webencodings ==0.5.1
- wget ==3.2
- widgetsnbextension ==3.5.1
- xxhash ==2.0.0
- datasets >=1.1.3
- ltp *
- protobuf *
- sentencepiece *
- h5py >=2.10.0
- knockknock >=0.1.8.1
- numpy >=1.18.2
- scipy >=1.4.1
- torch >=1.4.0
- torch >=1.10
- conllu *
- datasets >=1.1.3
- elasticsearch *
- faiss-cpu *
- fire *
- git-python ==1.0.3
- matplotlib *
- nltk *
- pandas *
- protobuf *
- psutil *
- pytest *
- pytorch-lightning ==1.6.0
- rouge-score *
- sacrebleu *
- scikit-learn *
- sentencepiece *
- seqeval *
- streamlit *
- tensorboard *
- tensorflow_datasets *
- transformers ==3.5.1
- GitPython *
- datasets >=1.0.1
- faiss-cpu >=1.6.3
- psutil >=5.7.0
- pytorch-lightning ==1.6.0
- torch >=1.4.0
- transformers *
- datasets >=1.6.2
- faiss-cpu >=1.7.0
- nvidia-ml-py3 ==7.352.0
- psutil >=5.7.0
- pytorch-lightning ==1.6.0
- ray >=1.3.0
- torch >=1.4.0
- conllu *
- datasets >=1.1.3
- elasticsearch *
- faiss-cpu *
- fire *
- git-python ==1.0.3
- matplotlib *
- nltk *
- pandas *
- protobuf *
- psutil *
- pytest *
- pytorch-lightning ==1.6.0
- rouge-score *
- sacrebleu *
- scikit-learn *
- sentencepiece *
- streamlit *
- tensorboard *
- tensorflow_datasets *
- CacheControl ==0.12.6
- Jinja2 >=2.11.3
- MarkupSafe ==1.1.1
- Pillow >=8.1.1
- PyYAML >=5.4
- Pygments >=2.7.4
- QtPy ==1.9.0
- Send2Trash ==1.5.0
- appdirs ==1.4.3
- argon2-cffi ==20.1.0
- async-generator ==1.10
- attrs ==20.2.0
- backcall ==0.2.0
- certifi ==2022.12.7
- cffi ==1.14.2
- chardet ==3.0.4
- click ==7.1.2
- colorama ==0.4.3
- contextlib2 ==0.6.0
- cycler ==0.10.0
- datasets ==1.0.0
- decorator ==4.4.2
- defusedxml ==0.6.0
- dill ==0.3.2
- distlib ==0.3.0
- distro ==1.4.0
- entrypoints ==0.3
- filelock ==3.0.12
- future ==0.18.3
- html5lib ==1.0.1
- idna ==2.8
- ipaddr ==2.2.0
- ipykernel ==5.3.4
- ipython *
- ipython-genutils ==0.2.0
- ipywidgets ==7.5.1
- jedi ==0.17.2
- joblib ==1.2.0
- jsonschema ==3.2.0
- jupyter ==1.0.0
- jupyter-client ==6.1.7
- jupyter-console ==6.2.0
- jupyter-core ==4.11.2
- jupyterlab-pygments ==0.1.1
- kiwisolver ==1.2.0
- lockfile ==0.12.2
- matplotlib ==3.3.1
- mistune ==0.8.4
- msgpack ==0.6.2
- nbclient ==0.5.0
- nbconvert ==6.5.1
- nbformat ==5.0.7
- nest-asyncio ==1.4.0
- notebook ==6.4.12
- numpy ==1.22.0
- opencv-python ==4.4.0.42
- packaging ==20.3
- pandas ==1.1.2
- pandocfilters ==1.4.2
- parso ==0.7.1
- pep517 ==0.8.2
- pexpect ==4.8.0
- pickleshare ==0.7.5
- progress ==1.5
- prometheus-client ==0.8.0
- prompt-toolkit ==3.0.7
- ptyprocess ==0.6.0
- pyaml ==20.4.0
- pyarrow ==1.0.1
- pycparser ==2.20
- pyparsing ==2.4.6
- pyrsistent ==0.16.0
- python-dateutil ==2.8.1
- pytoml ==0.1.21
- pytz ==2020.1
- pyzmq ==19.0.2
- qtconsole ==4.7.7
- regex ==2020.7.14
- requests ==2.22.0
- retrying ==1.3.3
- sacremoses ==0.0.43
- sentencepiece ==0.1.91
- six ==1.14.0
- terminado ==0.8.3
- testpath ==0.4.4
- tokenizers ==0.8.1rc2
- torch ==1.13.1
- torchvision ==0.7.0
- tornado ==6.0.4
- tqdm ==4.48.2
- traitlets *
- urllib3 ==1.26.5
- wcwidth ==0.2.5
- webencodings ==0.5.1
- wget ==3.2
- widgetsnbextension ==3.5.1
- xxhash ==2.0.0
- datasets *
- jiwer ==2.2.0
- lang-trans ==0.6.0
- librosa ==0.8.0
- torch >=1.5.0
- torchaudio *
- transformers *
- tensorflow >=2.3
- datasets >=1.8.0
- sentencepiece *
- protobuf *
- sentencepiece *
- tensorflow >=2.3
- datasets >=1.4.0
- tensorflow >=2.3.0
- datasets >=1.1.3
- protobuf *
- sentencepiece *
- tensorflow >=2.3
- deps *
- datasets ==1.8.0 test