staged-training
Staged Training for Transformer Language Models
Statistics
- Stars: 32
- Watchers: 5
- Forks: 2
- Open Issues: 2
- Releases: 0
Metadata Files
README.md
staged-training
In our paper Staged Training for Transformer Language Models, we propose a staged training setup that begins with a small model and incrementally increases the amount of compute used for training by applying a "growth operator" to increase the model depth and width. By initializing each stage with the output of the previous one, the training process effectively re-uses the compute from prior stages and becomes more efficient.
This repository releases the code for the growth operators, along with the evaluation scripts.
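As a rough illustration of the staged setup described above, the toy sketch below treats the training state as a plain dictionary and grows the depth between two stages. Every name in it (`train`, `grow_depth`, and the dictionary keys) is hypothetical and is not the repository's API. The duplication rule is only a stand-in for the operators described in the paper, which also cover width and the learning rate schedule.

```python
import copy


def train(state, steps):
    """Stand-in for real training: only advances the step counter."""
    new_state = copy.deepcopy(state)
    new_state["step"] += steps
    return new_state


def grow_depth(state):
    """Toy depth operator: duplicate every layer and its optimizer slot.

    Hypothetical illustration of "training state in, training state out";
    this is NOT the operator from the paper.
    """
    new_state = copy.deepcopy(state)
    new_state["layers"] = [
        copy.deepcopy(layer) for layer in state["layers"] for _ in range(2)
    ]
    new_state["optimizer"] = [
        copy.deepcopy(slot) for slot in state["optimizer"] for _ in range(2)
    ]
    return new_state


# A made-up two-layer "model" with one optimizer slot per layer.
state = {
    "layers": [{"w": [0.1, 0.2]}, {"w": [0.3, 0.4]}],
    "optimizer": [{"momentum": [0.0, 0.0]}, {"momentum": [0.0, 0.0]}],
    "step": 0,
}

state = train(state, steps=100)  # stage 1: small model
state = grow_depth(state)        # grow, starting from the stage-1 state
state = train(state, steps=100)  # stage 2: larger model continues training

print(len(state["layers"]), state["step"])  # 4 200
```

The point of the sketch is the interface: the operator receives the whole state, including the optimizer entries and the step counter, so the next stage resumes rather than restarts.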
Setup
The scripts in this repository require Python 3.7 or newer.
Once you have a suitable Python environment, first install PyTorch v1.9.0 according to the official instructions. Then run:
pip install -r requirements.txt
Growth Operator
Our growth operators (width/depth) each take as input the entire training state (including model parameters, optimizer state, learning rate schedule, etc.) and output a new training state from which training continues.
See scripts/cheatsheet.txt for more examples of how to use the corresponding scripts.
For example, you can apply the width operator with:
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \
--save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
--gpu_count -1 \
--model gpt2 \
--tokenizer gpt2 \
--batch_size 4 \
--grad_accum 32 \
--lr 0.002006911598778545 \
--warmup_steps 3000 \
--train_steps 250000 \
--val_every 50 \
--val_batches 50 \
--fp16 \
--seqlen 1024 \
--log_rate 10 \
--num_workers 4 \
--size GPT2_large_div2_width \
--random \
--resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-xxx.ckpt \
--doubling weights
Or the depth operator with:
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \
--save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
--gpu_count -1 \
--model gpt2 \
--tokenizer gpt2 \
--batch_size 4 \
--grad_accum 32 \
--lr 0.002006911598778545 \
--warmup_steps 3000 \
--train_steps 250000 \
--val_every 50 \
--val_batches 50 \
--fp16 \
--seqlen 1024 \
--log_rate 10 \
--num_workers 4 \
--size GPT2_large_div2_depth \
--random \
--resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt \
--doubling layers
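The `bs512` in the save prefixes above appears to be the effective batch size. Here is a small sketch of the arithmetic, assuming that `--batch_size` is per GPU, that `--gpu_count -1` uses every device listed in `CUDA_VISIBLE_DEVICES`, and that gradients are accumulated over `--grad_accum` steps before each optimizer update. These are assumptions about the script, not something stated in this README, and the helper name is hypothetical.

```python
def effective_batch_size(per_gpu_batch, grad_accum, num_gpus):
    """Sequences contributing to one optimizer step, under the stated assumptions."""
    return per_gpu_batch * grad_accum * num_gpus


# Values copied from the example commands above.
visible_devices = "0,1,2,3"  # CUDA_VISIBLE_DEVICES
num_gpus = len(visible_devices.split(","))

sequences = effective_batch_size(per_gpu_batch=4, grad_accum=32, num_gpus=num_gpus)
tokens = sequences * 1024  # --seqlen 1024

print(sequences)  # 512
print(tokens)     # 524288
```

If you run on a different number of GPUs, adjusting `--grad_accum` so that the product stays the same should keep the effective batch size unchanged under these assumptions.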
Evaluation
Use evaluation/eval_wikitext.py or evaluation/eval_lambada.py to evaluate GPT-2 on one of the supported datasets. For example:
python evaluation/eval_wikitext.py
Or using Docker:
docker build -t evaluation:latest .
docker run --rm --gpus all evaluation:latest evaluation/eval_wikitext.py
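This README does not say which metrics the evaluation scripts report. Language models are commonly scored on WikiText by perplexity, so the generic sketch below shows how perplexity follows from per-token log-probabilities. The function and the input values are illustrative assumptions and are not taken from evaluation/eval_wikitext.py.

```python
import math


def perplexity(token_log_probs):
    """Return exp of the mean negative log-likelihood per token (natural log)."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# A model that assigns probability 0.25 to every token has perplexity close to 4.
uniform = [math.log(0.25)] * 8
print(perplexity(uniform))
```

Lower is better: a perplexity of about 4 means the model is, on average, as uncertain as a uniform choice among four tokens.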
Reference
If you use staged training in your research or wish to refer to the baseline results published here,
please use the following BibTeX entry.
@misc{shen2022staged,
title={Staged Training for Transformer Language Models},
author={Sheng Shen and Pete Walsh and Kurt Keutzer and Jesse Dodge and Matthew Peters and Iz Beltagy},
year={2022},
eprint={2203.06211},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Owner
- Name: AI2
- Login: allenai
- Kind: organization
- Email: ai2-info@allenai.org
- Location: Seattle, WA
- Website: http://www.allenai.org
- Repositories: 454
- Profile: https://github.com/allenai
Citation (CITATION.cff)
# YAML 1.2
---
cff-version: "1.2.0"
title: "Staged Training for Transformer Language Models"
license: "Apache-2.0"
message: "If you use staged training in your research or wish to refer to the baseline results published here, please cite using this metadata."
repository-code: "https://github.com/allenai/staged-training"
authors:
  - affiliation: "University of California, Berkeley"
    family-names: Shen
    given-names: Sheng
  - affiliation: "Allen Institute for Artificial Intelligence"
    family-names: Walsh
    given-names: Pete
  - affiliation: "University of California, Berkeley"
    family-names: Keutzer
    given-names: Kurt
  - affiliation: "Allen Institute for Artificial Intelligence"
    family-names: Dodge
    given-names: Jesse
  - affiliation: "Allen Institute for Artificial Intelligence"
    family-names: Peters
    given-names: Matthew
  - affiliation: "Allen Institute for Artificial Intelligence"
    family-names: Beltagy
    given-names: Iz
preferred-citation:
  type: "article"
  title: "Staged Training for Transformer Language Models"
  doi: "10.48550/arXiv.2203.06211"
  url: "https://arxiv.org/abs/2203.06211"
  year: 2022
  authors:
    - affiliation: "University of California, Berkeley"
      family-names: Shen
      given-names: Sheng
    - affiliation: "Allen Institute for Artificial Intelligence"
      family-names: Walsh
      given-names: Pete
    - affiliation: "University of California, Berkeley"
      family-names: Keutzer
      given-names: Kurt
    - affiliation: "Allen Institute for Artificial Intelligence"
      family-names: Dodge
      given-names: Jesse
    - affiliation: "Allen Institute for Artificial Intelligence"
      family-names: Peters
      given-names: Matthew
    - affiliation: "Allen Institute for Artificial Intelligence"
      family-names: Beltagy
      given-names: Iz
GitHub Events
Total
- Watch event: 2
Last Year
- Watch event: 2
Dependencies
- black ==21.7b0 development
- flake8 ==3.9.2 development
- mypy ==0.910 development
- pytest ==6.2.4 development
- accelerate ==0.3.0
- click ==7.1.2
- click-help-colors ==0.9.1
- datasets ==1.9.0
- more-itertools ==8.8.0
- numpy *
- pytorch-lightning ==1.3
- sentencepiece ==0.1.96
- tensorboardX *
- test-tube ==0.7.5
- tokenizers ==0.10.3
- torch ==1.9.0
- tqdm ==4.61.2
- transformers ==4.8.2