rwkv-infctx-trainer
RWKV infctx trainer, for training arbitrary context sizes, to 10k and beyond!
Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Committers with academic emails
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (15.0%) to scientific vocabulary
Repository
RWKV infctx trainer, for training arbitrary context sizes, to 10k and beyond!
Basic Info
- Host: GitHub
- Owner: RWKV
- License: apache-2.0
- Language: Jupyter Notebook
- Default Branch: main
- Size: 31.6 MB
Statistics
- Stars: 148
- Watchers: 5
- Forks: 30
- Open Issues: 13
- Releases: 10
Metadata Files
README.md
RWKV Infinite Context trainer
If you are new to RWKV, it is best to first learn more about us via our wiki: https://wiki.rwkv.com/
RWKV trainer with:
- no training context limit (via BPTT)
- DeepSpeed 3
- HF dataset integration
With this implementation you can train on arbitrarily long contexts with (near) constant VRAM consumption: VRAM usage increases by only about 2MB per 1024/2048 tokens in the training sample (depending on your chosen ctx_len, with RWKV 7B as an example), which enables training on sequences of over 1M tokens.
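As a rough back-of-envelope sketch, the ~2MB figure can be turned into an estimate of the extra VRAM a long sample costs. This assumes the increase applies once per `ctx_len`-sized chunk of the sample (1024 or 2048 tokens) for RWKV 7B; `extra_vram_mb` is a hypothetical helper, not part of this repo, and real usage will vary with model size and settings.

```python
import math


def extra_vram_mb(sample_tokens: int, ctx_len: int = 1024, mb_per_chunk: float = 2.0) -> float:
    """Rough extra VRAM (MB) for one training sample, assuming a fixed cost per ctx_len chunk.

    mb_per_chunk=2.0 follows the README's "about 2MB per 1024/2048 tokens" for RWKV 7B;
    it is an approximation, not a measured constant.
    """
    if sample_tokens <= 0 or ctx_len <= 0:
        raise ValueError("sample_tokens and ctx_len must be positive")
    # Number of ctx_len segments the sample is split into (last one may be partial).
    chunks = math.ceil(sample_tokens / ctx_len)
    return chunks * mb_per_chunk


if __name__ == "__main__":
    for ctx_len in (1024, 2048):
        mb = extra_vram_mb(1_048_576, ctx_len=ctx_len)
        print(f"1M-token sample, ctx_len={ctx_len}: ~{mb:.0f} MB extra")
```

Under this assumption, a ~1M-token sample adds roughly 1 to 2 GB, depending on ctx_len.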
The training code has also been heavily refactored to use PyTorch 2.0, Lightning 2.0 and DeepSpeed, and the entry script now relies on LightningCLI. As a result, config-example.yaml contains all the switches: mostly standard ones that Lightning processes by itself, plus new ones for RWKV and the dataset parser.
To use this repo, go into the RWKV-v4neo directory and run:
```sh
python3 lightning_trainer.py fit -c {your_config}.yaml
```
Remember to modify the configuration for your own needs.
See RWKV-v4neo/config-example.yaml for documentation on the various options.
NOTE: Due to the current incomplete implementation (there is no state gradient yet), bptt_truncate is forced to be true.
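The idea behind chunked training can be illustrated with a toy recurrence: a long sequence is processed in `ctx_len`-sized segments, and the recurrent state from one segment is handed to the next, so the forward result matches processing the whole sequence at once while only one segment needs to be held at a time. This is a minimal pure-Python sketch, not the trainer's code: the scalar `decay * state + x` update is a hypothetical stand-in for RWKV's state, and in an autograd framework `bptt_truncate` roughly corresponds to detaching the carried state at each segment boundary so gradients stop there.

```python
from typing import Iterator, List, Sequence, Tuple


def segments(tokens: Sequence[float], ctx_len: int) -> Iterator[Sequence[float]]:
    """Yield consecutive ctx_len-sized slices; the last one may be shorter."""
    for start in range(0, len(tokens), ctx_len):
        yield tokens[start:start + ctx_len]


def run_segment(seg: Sequence[float], state: float, decay: float = 0.9) -> Tuple[List[float], float]:
    """Toy recurrence standing in for an RNN layer: state = decay * state + x."""
    outputs = []
    for x in seg:
        state = decay * state + x
        outputs.append(state)
    return outputs, state


def run_chunked(tokens: Sequence[float], ctx_len: int) -> Tuple[List[float], float]:
    """Process tokens segment by segment, carrying the state across boundaries."""
    state = 0.0
    outputs: List[float] = []
    for seg in segments(tokens, ctx_len):
        # With PyTorch tensors, truncated BPTT would detach here
        # (state = state.detach()) so gradients do not cross segment boundaries.
        seg_out, state = run_segment(seg, state)
        outputs.extend(seg_out)
    return outputs, state


if __name__ == "__main__":
    seq = [float(i % 7) for i in range(10_000)]
    full_out, full_state = run_segment(seq, 0.0)
    chunk_out, chunk_state = run_chunked(seq, ctx_len=1024)
    print(full_out == chunk_out, full_state == chunk_state)  # True True
```

Because the state is passed forward, splitting the sequence does not change the forward computation; the trade-off of truncation is only that gradients do not flow back through earlier segments.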
Environment setup
Note: There is a known issue with CUDA 12.0 and multi-GPU at the time of writing. Upgrade to at least CUDA 12.1 or 12.2, or downgrade to 11.8.
The following venv setup uses conda; modify it for your use case as needed.
```shell
# ninja-build is required for the new trainer
sudo apt-get install ninja-build

# Update conda & its package listings
conda update conda

# Virtual env, with python 3.10
# python 3.11 has issues with torch.compile / H100s
# and if you want to use 3.11, you will need to do a nightly build install
conda create -n rwkv-infctx python=3.11 pip
conda activate rwkv-infctx

# Install pytorch (>=2.1.2)
conda install -y pytorch==2.1.2 torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
python -m pip install lightning==2.1.3 deepspeed==0.12.6

# Currently, for torch.compile + 3.11 to work on some platforms, you will need the nightly build
# if so, you may need to try the following instead - this is considered highly "unstable"
# ---
# conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch-nightly -c nvidia
# python -m pip install lightning==2.0.5 deepspeed==0.10.0

# Verify your pytorch version
python -c "import torch; print(torch.__version__)"

# Install all the other various dependencies
# PS: We use python -m pip, instead of pip directly, as it resolves issues with venv not loading the right pip
python -m pip install datasets transformers
python -m pip install ninja numexpr jsonargparse 'jsonargparse[signatures]'
python -m pip install lm-dataformat ftfy sentencepiece tokenizers wandb

# Optional dependencies, useful for running notebooks, etc
python -m pip install papermill
```
Alternatively, you could use requirements.txt (this may not install pytorch-cuda properly, and has been found to be incompatible with conda environments):
```shell
python3 -m pip install -r requirements.txt
```
Due to issues with DeepSpeed on Windows, only Linux environments are supported. WSL2 on Windows is not recommended, due to heavy performance penalties (DeepSpeed offload cannot be used, and training is ~50% slower).
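Once the environment is set up, a small script like the following can sanity-check it against the notes in this section: it flags a PyTorch older than 2.1.2 and a CUDA 12.0 build (the known multi-GPU issue). This is a sketch under stated assumptions: the minimum version is taken from the "Install pytorch (>=2.1.2)" step, `torch.version.cuda` reports the CUDA version PyTorch was built against (it is `None` on CPU-only builds, and it does not check the host driver), and the helper names are hypothetical.

```python
import importlib.metadata
import re
from typing import Optional, Tuple

MIN_TORCH = (2, 1, 2)  # from the "Install pytorch (>=2.1.2)" step


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse the leading numeric part, e.g. '2.1.2+cu121' -> (2, 1, 2)."""
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def torch_version_ok(version: str) -> bool:
    """True if the installed torch is at least MIN_TORCH."""
    return parse_version(version) >= MIN_TORCH


def cuda_build_ok(cuda: Optional[str]) -> bool:
    """False for CPU-only builds (None) and for CUDA 12.0 (known multi-GPU freezes)."""
    if cuda is None:
        return False
    return parse_version(cuda)[:2] != (12, 0)


def main() -> None:
    try:
        import torch
    except ImportError:
        print("torch is not installed")
        return
    version = str(torch.__version__)
    print("torch", version, "ok" if torch_version_ok(version) else "older than 2.1.2")
    cuda = torch.version.cuda
    print("CUDA build", cuda, "ok" if cuda_build_ok(cuda) else "not recommended (CPU-only or 12.0)")
    for pkg in ("lightning", "deepspeed"):
        try:
            print(pkg, importlib.metadata.version(pkg))
        except importlib.metadata.PackageNotFoundError:
            print(pkg, "not installed")


if __name__ == "__main__":
    main()
```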
Overall training process
- Either init a new model, or download an existing model
  - To initialize a new model, use
    `python3 ./init_model.py --n_layer {number-of-layers} --n_embd {embedding-size} --vocab_size {vocab-size/neox/world} --skip-if-exists ../model/file/path.pth`
- Set up the config.yaml file, customized for your foundation model / finetune use case
- Preload the dataset using
  `python3 preload_datapath.py {your_config}.yaml`
- Start the training process
  `python3 lightning_trainer.py fit -c {your_config}.yaml`
- Export the checkpoint after training is complete with
  `python3 export_checkpoint.py ../path/to/checkpoint/last.ckpt/ ../path/to/export/model.pth`
- Optionally, run the dragon prompt as a quick sanity check
  `python3 dragon_test.py ../path/to/export/model.pth`
In summary, from the trainer directory (e.g. RWKV-v4neo):
```shell
# Initialize the blank model (or download a pretrained model)
python3 init_model.py --n_layer {number-of-layers} --n_embd {embedding-size} --vocab_size {vocab-size/neox/world} --skip-if-exists ../model/file/path.pth

# Preload your dataset
python3 preload_datapath.py {your_config}.yaml

# Run the training process
python3 lightning_trainer.py fit -c {your_config}.yaml

# Export the checkpoint to a model file
python3 export_checkpoint.py ../path/to/checkpoint/last.ckpt/ ../path/to/export/model.pth

# Quickly test the model with the dragon prompt
python3 dragon_test.py ../path/to/export/model.pth

# @TODO, convert the model to bf16 format (instead of the huge fp32 format now)
# for now you will have to use the RWKV pip package to do this with python code:
# https://pypi.org/project/rwkv/
```
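The same sequence can also be driven from Python. This sketch only assembles the command lines from this section (script names and flags are copied from them; the paths, sizes and the `build_commands` / `run_pipeline` helpers are hypothetical placeholders) and runs them with `subprocess.run(..., check=True)` so that a failing step stops the pipeline. By default it is a dry run that only prints the commands.

```python
import shlex
import subprocess
from typing import List


def build_commands(config: str, model_path: str, checkpoint: str, export_path: str,
                   n_layer: int, n_embd: int, vocab_size: str) -> List[List[str]]:
    """Return the README's steps, in order, as argument lists (run from RWKV-v4neo)."""
    return [
        # Optional: skip this step if you downloaded an existing model instead.
        ["python3", "init_model.py", "--n_layer", str(n_layer), "--n_embd", str(n_embd),
         "--vocab_size", vocab_size, "--skip-if-exists", model_path],
        ["python3", "preload_datapath.py", config],
        ["python3", "lightning_trainer.py", "fit", "-c", config],
        ["python3", "export_checkpoint.py", checkpoint, export_path],
        ["python3", "dragon_test.py", export_path],
    ]


def run_pipeline(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print("$", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a step fails


if __name__ == "__main__":
    cmds = build_commands(
        config="my-config.yaml",               # hypothetical config file
        model_path="../model/my-model.pth",    # placeholder paths
        checkpoint="../checkpoint/last.ckpt/",
        export_path="../export/model.pth",
        n_layer=12, n_embd=768, vocab_size="world",  # illustrative sizes only
    )
    run_pipeline(cmds)  # dry run: prints the commands only
```

Keeping the commands as argument lists (rather than one shell string) avoids quoting problems with paths, and `check=True` mirrors running the steps by hand, where you would not continue after a failed preload or training run.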
Examples of configuration files
You can find the following notebooks/examples at the following ...
- full annotation of the various configs at ./RWKV-v4neo/config-example.yaml
- minimal config example at ./RWKV-v4neo/config-example.py
- configuration / notebooks for various dataset use cases here
- @TODO: training-scenario-specific examples

For configuration issues, please review the examples listed above first, before asking questions on Discord.
You can find the training channel on our Discord here: https://discord.com/channels/992359628979568762/992362252269256815
Important notes on infctx lightning trainer
- Ensure your host is not running CUDA 12.0 (use either 11.8, or >=12.1), as this is known to have freeze issues
- When resuming from a checkpoint, the estimated time is inaccurate. See: https://github.com/Lightning-AI/lightning/issues/18220
- Note that some terms are confusing, so here is a quick glossary:
  - a `step` in the progress bar means 1 data sample PER GPU.
  - a classic transformer batch is a `trainer/global_step` in wandb.
  - a `substep` in wandb means a single data sample.
  - `(accumulate_grad_batches * gpu count)` substeps = 1 `trainer/global_step`
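To make the glossary concrete, the relationships can be written out as arithmetic. This is a sketch based only on the definitions in the glossary, assuming the accumulation value is Lightning's `accumulate_grad_batches` trainer option; the GPU and accumulation numbers are illustrative, and the helper names are hypothetical.

```python
def substeps_per_global_step(accumulate_grad_batches: int, gpu_count: int) -> int:
    # Each substep is a single data sample, so one trainer/global_step
    # consumes accumulate_grad_batches samples on every GPU.
    return accumulate_grad_batches * gpu_count


def progress_bar_steps_per_global_step(accumulate_grad_batches: int, gpu_count: int) -> int:
    # A progress-bar step is 1 data sample PER GPU, i.e. gpu_count substeps at once.
    return substeps_per_global_step(accumulate_grad_batches, gpu_count) // gpu_count


if __name__ == "__main__":
    acc, gpus = 4, 8
    print(substeps_per_global_step(acc, gpus))            # 32
    print(progress_bar_steps_per_global_step(acc, gpus))  # 4
```

So with 8 GPUs and an accumulation of 4, wandb logs 32 substeps for every `trainer/global_step`, while the progress bar advances 4 steps.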
Should I use the official RWKV-LM trainer or the infctx trainer?
Generally, if you're training a foundation model from scratch with a fixed context size, and you need the absolute highest throughput across multiple nodes (e.g. 10 nodes filled with A100 servers), the official trainer will perform much better (around 2x faster, depending on the settings).
For nearly all other use cases, such as needing DeepSpeed 3 support or dealing with dynamic datasets, this trainer is much more flexible.
Over time, as we optimize the infctx trainer, the gap to the official trainer should shrink; however, this is not the highest priority (a working infctx > absolute speed).
Some long term architecture goals
- CUDA should be optional
  - Moving forward, this allows us to potentially train (even if at a performance cost) on other architectures like AMD ROCm, TPU, or Apple M1.
- No dependency on the official RWKV pip package
  - This is an intentional choice, to facilitate easy iteration on model architecture in #rwkv-x development, so that the entire train-test-validation cycle for design changes can be done in this repository.
Existing limitations
The following features are not yet supported (some of which may exist in Blink's original repo):
- numpy file dataset
- model weight resizing (init from a smaller to a bigger model)
- helper script to add new tokens to an existing model
- torch compile is NOT supported, as this has been unstable on the nightly build
- LoRA is not yet supported; use https://github.com/blealtan/RWKV-LM-LoRA instead for now
Designated maintainer
@picocreator is the current maintainer of the project; you can ping him on the RWKV Discord if you have any questions about this project.
Credits (for v4neo and v5 code)
- The bulk of the first infctx trainer was originally rewritten by @Blealtan at: https://github.com/Blealtan/RWKV-LM-LoRA/tree/dev-infctx
- RWKV-LM and the original trainer code is credited to @BlinkDL at: https://github.com/BlinkDL/RWKV-LM
- Special credit to @Yuzaboto and @bananaman via our RWKV Discord, whose assistance was crucial in debugging and fixing the repo to work with the RWKV v4 and RWKV v5 code, respectively.
- @picocreator for getting the project feature complete for RWKV mainline release
Special thanks
- PyTorch Lightning team @lantiga and @Adrian via the PyTorch Lightning AI Discord, who assisted in clarifying questions on PyTorch Lightning

This project is intentionally a hard fork, as it has too many changes that conflict with the official RWKV-LM repo.
Owner
- Name: RWKV
- Login: RWKV
- Kind: organization
- Location: United States of America
- Website: https://github.com/BlinkDL/RWKV-LM
- Twitter: RWKV_AI
- Repositories: 11
- Profile: https://github.com/RWKV
RWKV is an RNN with Transformer-level LLM performance. It is hosted as an incubation project in LF AI & Data Foundation.
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "PENG"
    given-names: "Bo"
    orcid: "https://orcid.org/0000-0002-0865-547X"
  - family-names: "Cao"
    given-names: "Huanqi"
    orcid: "https://orcid.org/0000-0002-3870-106X"
  - family-names: "Cheah"
    given-names: "Eugene"
    orcid: "https://orcid.org/0009-0002-8977-475X"
title: "RWKV-LM infctx trainer"
version: 1.0.0
date-released: 2023-08-16
url: "https://github.com/PicoCreator/RWKV-LM-LoRA"
GitHub Events
Total
- Watch event: 17
- Fork event: 2
Last Year
- Watch event: 17
- Fork event: 2
Committers
Last synced: 7 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| @picocreator (Eugene Cheah) | e****n@g****m | 446 |
| PENG Bo | 3****L | 327 |
| BlinkDL | a@a****m | 131 |
| Eugene Cheah (picocreator) | p****r@g****m | 59 |
| SmerkyG | S****G | 57 |
| Nathan | me@n****k | 26 |
| Blealtan Cao | b****n@o****m | 11 |
| diannao | 5****k@o****m | 1 |
| TearGosling | t****u@g****m | 1 |
| Alexander | s****y | 1 |
| root | r****t@n****m | 1 |
Issues and Pull Requests
Last synced: 7 months ago
All Time
- Total issues: 23
- Total pull requests: 81
- Average time to close issues: about 1 month
- Average time to close pull requests: 3 days
- Total issue authors: 12
- Total pull request authors: 10
- Average comments per issue: 1.22
- Average comments per pull request: 0.09
- Merged pull requests: 67
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 2
- Pull requests: 3
- Average time to close issues: N/A
- Average time to close pull requests: 2 minutes
- Issue authors: 2
- Pull request authors: 2
- Average comments per issue: 0.5
- Average comments per pull request: 0.0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- PicoCreator (12)
- cgoxopx (1)
- yynil (1)
- shouldsee (1)
- wooks186 (1)
- diannaojiang (1)
- cahya-wirawan (1)
- h-a-s-k (1)
- petergaoshan (1)
- BrightXiaoHan (1)
- General-Redshift (1)
Pull Request Authors
- PicoCreator (75)
- m8than (12)
- SmerkyG (10)
- BabyChouSr (2)
- diannaojiang (2)
- shiroko98 (2)
- freckletonj (2)
- harrisonvanderbyl (1)
- Bruber92 (1)
- TearGosling (1)
Packages
- Total packages: 1
- Total downloads: unknown
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 10
proxy.golang.org: github.com/rwkv/rwkv-infctx-trainer
- Documentation: https://pkg.go.dev/github.com/rwkv/rwkv-infctx-trainer#section-documentation
- License: apache-2.0
- Latest release: v2.3.0+incompatible (published about 2 years ago)
Dependencies
- actions/checkout v3 composite
- docker/build-push-action v4 composite
- docker/login-action 28218f9b04b4f3f62068d7b6ce6ca5b26e35336c composite
- docker/metadata-action 98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 composite
- docker/setup-buildx-action v2 composite
- sigstore/cosign-installer f3c664df7af409cb4873aa5068053ba9d61a57b6 composite
- nvidia/cuda 11.8.0-cudnn8-devel-ubuntu22.04 build
- datasets ==2.13.1
- deepspeed ==0.10.0
- ftfy ==6.1.1
- jsonargparse ==4.22.1
- lightning ==2.0.5
- lm-dataformat ==0.0.20
- ninja ==1.11.1
- numexpr ==2.8.4
- sentencepiece ==0.1.99
- tokenizers ==0.13.3
- transformers ==4.30.2
- wandb ==0.15.5
- datasets ==2.13.1
- deepspeed ==0.10.0
- ftfy ==6.1.1
- jsonargparse ==4.22.1
- lightning ==2.0.5
- lm-dataformat ==0.0.20
- ninja ==1.11.1
- numexpr ==2.8.4
- sentencepiece ==0.1.99
- tokenizers ==0.13.3
- torch ==2.0.1
- torchaudio *
- torchvision *
- transformers ==4.30.2
- wandb ==0.15.5
- actions/checkout v3 composite
- actions/upload-artifact v3 composite
- actions/checkout v3 composite
- actions/upload-artifact v3 composite
- actions/checkout v3 composite
- actions/upload-artifact v3 composite
- ghcr.io/picocreator/rwkv-lm-lora env-cuda-11-8 build