online-subspace-descent
This repo is based on https://github.com/jiaweizzhao/GaLore
Science Score: 54.0%
This score indicates how likely this project is to be science-related, based on the following indicators:
- ✓ CITATION.cff file: found CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (10.4%) to scientific vocabulary
Repository
Basic Info
Statistics
- Stars: 29
- Watchers: 2
- Forks: 1
- Open Issues: 1
- Releases: 0
Metadata Files
README.md
Online Subspace Descent
This repo contains a PyTorch implementation of Memory-Efficient LLM Training with Online Subspace Descent, a follow-up to the GaLore algorithm proposed in GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection.
Recently, a wide range of memory-efficient LLM training algorithms have gained substantial popularity. These methods leverage the low-rank structure of gradients to project optimizer states into a subspace using a projection matrix found by singular value decomposition (SVD). However, convergence of these algorithms is highly dependent on the update rules of their projection matrix. In this work, we provide the first convergence guarantee for arbitrary update rules of the projection matrix. This guarantee is generally applicable to optimizers that can be analyzed with Hamiltonian Descent, including most common ones such as LION and Adam. Inspired by our theoretical understanding, we propose Online Subspace Descent, a new family of subspace descent optimizers without SVD. Instead of updating the projection matrix with eigenvectors, Online Subspace Descent updates the projection matrix with online PCA. Online Subspace Descent is flexible and introduces only minimal overhead to training. We demonstrate that, for the task of pretraining LLaMA models ranging from 60M to 1B parameters on the C4 dataset, Online Subspace Descent achieves lower perplexity than state-of-the-art low-rank training methods across different settings and narrows the gap with full-rank baselines.
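As a minimal sketch of the core idea (not the repo's implementation; the function name `subspace_adam_step` and all shapes are illustrative), projecting optimizer states into a rank-r subspace means the Adam moments live in an (r, d) space instead of the full (n, d) weight space:

```python
import numpy as np

def subspace_adam_step(W, G, P, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step with moments kept in a rank-r subspace.

    W: (n, d) weights, G: (n, d) gradient,
    P: (n, r) projection with orthonormal columns,
    m, v: (r, d) first/second moments stored in the subspace.
    """
    R = P.T @ G                     # project gradient down: (r, d), r << n
    m = b1 * m + (1 - b1) * R       # subspace first moment
    v = b2 * v + (1 - b2) * R**2    # subspace second moment
    update = m / (np.sqrt(v) + eps)
    W = W - lr * (P @ update)       # project the update back to full rank
    return W, m, v

# Optimizer states are (r, d) instead of (n, d): for rank 128 on a
# 4096-row weight matrix that is a ~32x reduction in optimizer memory.
rng = np.random.default_rng(0)
n, d, r = 64, 32, 4
W = rng.standard_normal((n, d))
G = rng.standard_normal((n, d))
P, _ = np.linalg.qr(rng.standard_normal((n, r)))  # any orthonormal P works here
m = np.zeros((r, d))
v = np.zeros((r, d))
W2, m2, v2 = subspace_adam_step(W, G, P, m, v)
```

GaLore refreshes P with an SVD of the gradient every `update_proj_gap` steps; Online Subspace Descent instead updates P continuously with online PCA.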
Motivation
Online PCA has far lower overhead than computing a full SVD-based PCA, especially as model tensor shapes get larger.
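To make the overhead argument concrete, here is a hedged sketch of an online PCA update using Oja's rule, a standard online PCA algorithm (the paper updates the projection by descending a reconstruction loss; Oja's rule is a representative stand-in, and all names and constants below are illustrative):

```python
import numpy as np

def oja_update(P, G, eta=0.01):
    """One Oja's-rule step: nudge the columns of P toward the top
    eigenvectors of G @ G.T without ever computing an SVD.

    P: (n, r) current projection, G: (n, d) fresh gradient.
    Cost is a few O(n*d*r) matmuls vs O(n*d*min(n, d)) for a full SVD.
    """
    C_P = G @ (G.T @ P)                    # covariance applied to P: (n, r)
    P = P + eta * (C_P - P @ (P.T @ C_P))  # Oja's rule keeps P near-orthonormal
    P, _ = np.linalg.qr(P)                 # re-orthonormalize for stability
    return P

rng = np.random.default_rng(0)
n, d, r = 128, 64, 8
U, _ = np.linalg.qr(rng.standard_normal((n, r)))   # hidden low-rank subspace
P0, _ = np.linalg.qr(rng.standard_normal((n, r)))  # random initial projection
P = P0.copy()
for _ in range(300):
    G = U @ rng.standard_normal((r, d))    # stream of gradients in span(U)
    P = oja_update(P, G)

def recon_err(Q):
    # How much of U is missed by the projector Q @ Q.T
    return np.linalg.norm(U - Q @ (Q.T @ U))
```

After the loop, `recon_err(P)` is far below `recon_err(P0)`: the projection tracks the gradients' dominant subspace using only matrix multiplies, one fresh gradient at a time.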
Installation
Install experiment dependencies:
```bash
pip install -r requirements.txt
```
Usage and Comparison with GaLore
In comparison with GaLore, Online Subspace Descent exposes arguments for trying out different optimizers for the weight update and the projection update:
| Weight Optimizer | Arg |
|------------------|-----------------|
| AdamW | --optimizer galore_adamw |
| Adafactor | --optimizer galore_adafactor |
| Lion | --optimizer galore_lion |
| AdamW8bit | --optimizer galore_adamw8bit_per_layer |

| Projection Optimizer | Arg |
|----------------------|-----------------|
| AdamW | --proj_type continuous |
| Adam8bit | --proj_type continuous_adam8bit |
| Lion | --proj_type continuous_lion |
| Adafactor | --proj_type continuous_adafactor |
| SGD | --proj_type continuous_sgd |
| Random | --proj_type random |
You can also control the sequence length by adding --max_length {sequence_length} to your launch command; it defaults to 256.
Benchmark: Pre-Training LLaMA on C4 dataset
torchrun_main.py is the main script for training LLaMA models on C4.
For example, to train a 60M model on C4, run:
```bash
# LLaMA-60M, Online-Subspace-Descent-Adam, 1 A100, 1 Node
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_60m.json \
    --lr 0.01 \
    --galore_scale 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer galore_adamw \
    --proj_type continuous
```
Train 7B model with a single GPU with 24GB memory
To train a 7B model on a single GPU such as an NVIDIA RTX 4090, all you need to do is specify --optimizer=galore_adamw8bit_per_layer, which enables GaLoreAdamW8bit with per-layer weight updates.
With activation checkpointing, you can maintain a batch size of 16 (tested on an NVIDIA RTX 4090).
```bash
# Online Subspace Descent
# LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
# bsz=16, 22.8G
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_7b.json \
    --lr 0.005 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 500 \
    --batch_size 16 \
    --total_batch_size 512 \
    --activation_checkpointing \
    --num_training_steps 150000 \
    --warmup_steps 15000 \
    --weight_decay 0 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --single_gpu \
    --proj_type continuous_adam8bit \
    --optimizer galore_adamw8bit_per_layer
```
```bash
# GaLore
# LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
# bsz=16, 22.8G
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_7b.json \
    --lr 0.005 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 500 \
    --batch_size 16 \
    --total_batch_size 512 \
    --activation_checkpointing \
    --num_training_steps 150000 \
    --warmup_steps 15000 \
    --weight_decay 0 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --single_gpu \
    --optimizer galore_adamw8bit_per_layer
```
Citation
```bibtex
@misc{liang2024online,
      title={Memory-Efficient LLM Training with Online Subspace Descent},
      author={Kaizhao Liang and Bo Liu and Lizhang Chen and Qiang Liu},
      year={2024},
      eprint={2408.12857},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```
Owner
- Name: Kyle Liang
- Login: kyleliang919
- Kind: user
- Location: Palo Alto
- Website: https://kaizhaoliang.github.io/Portfolio/
- Twitter: KyleLiang5
- Repositories: 26
- Profile: https://github.com/kyleliang919
Citation (CITATION.cff)
cff-version: 1.2.0
title: "Memory-Efficient LLM Training with Online Subspace Descent"
version: 1.0.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Liang"
    given-names: "Kaizhao"
year: 2024
repository-code: ""
GitHub Events
Total
- Watch event: 10
- Push event: 1
- Fork event: 1
Last Year
- Watch event: 10
- Push event: 1
- Fork event: 1
Dependencies
- bitsandbytes *
- torch *
- transformers *