Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file (found)
- ✓ codemeta.json file (found)
- ✓ .zenodo.json file (found)
- ○ DOI references
- ✓ Academic publication links (links to arxiv.org)
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.3%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: sustcsonglin
- License: mit
- Language: Python
- Default Branch: main
- Size: 2.09 MB
Statistics
- Stars: 7
- Watchers: 2
- Forks: 1
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
This repo aims to provide a collection of efficient Triton-based implementations of state-of-the-art linear attention models. Pull requests are welcome!
News
- $\texttt{[2024-12]}$: :tada: Add Gated DeltaNet implementation to `fla` (paper).
- $\texttt{[2024-12]}$: :rocket: `fla` now officially supports kernels with variable-length inputs.
- $\texttt{[2024-11]}$: The inputs are now switched from head-first to seq-first format.
- $\texttt{[2024-11]}$: :boom: `fla` now provides a flexible way for training hybrid models.
- $\texttt{[2024-10]}$: :fire: Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the details here.
- $\texttt{[2024-09]}$: `fla` now includes a fused linear and cross-entropy layer, significantly reducing memory usage during training.
- $\texttt{[2024-09]}$: :tada: Add GSA implementation to `fla` (paper).
- $\texttt{[2024-05]}$: :tada: Add DeltaNet implementation to `fla` (paper).
- $\texttt{[2024-05]}$: :boom: `fla` v0.1: a variety of subquadratic kernels/layers/models integrated (RetNet/GLA/Mamba/HGRN/HGRN2/RWKV6, etc.; see Models).
- $\texttt{[2023-12]}$: :boom: Launched `fla`, offering a collection of implementations for state-of-the-art linear attention models.
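One of the 2024-11 items above, the switch from head-first to seq-first inputs, amounts to transposing the two middle axes of the input tensor. A small NumPy illustration (the shapes are illustrative, not prescribed by `fla`):

```python
import numpy as np

# Head-first layout: (batch, num_heads, seq_len, head_dim).
x_head_first = np.zeros((2, 4, 128, 64))
# Seq-first layout, as fla now expects: (batch, seq_len, num_heads, head_dim).
x_seq_first = x_head_first.transpose(0, 2, 1, 3)
print(x_seq_first.shape)  # (2, 128, 4, 64)
```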
Models
Roughly sorted by the date each model was supported in `fla`:
|Year | Venue | Model | Title | Paper | Code | fla impl |
|:--- |:----------------------- | :------------- | :-------------------------------------------------------------------------------------------------------- | :----------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------: |
|2023| | RetNet | Retentive network: a successor to transformer for large language models | link | official | code |
|2024| ICML | GLA | Gated Linear Attention Transformers with Hardware-Efficient Training | link | official | code |
|2024| ICML |Based | Simple linear attention language models balance the recall-throughput tradeoff | link | official | code |
| 2024| ACL| Rebased | Linear Transformers with Learnable Kernel Functions are Better In-Context Models | link | official | code |
|2024| NeurIPS | DeltaNet | Parallelizing Linear Transformers with Delta Rule over Sequence Length | link | official | code |
|2022| ACL | ABC | Attention with Bounded-memory Control | link | | code |
|2024| NeurIPS | HGRN | Hierarchically Gated Recurrent Neural Network for Sequence Modeling | link | official | code |
|2024| COLM | HGRN2 | HGRN2: Gated Linear RNNs with State Expansion | link | official | code |
|2024| COLM | RWKV6 | Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence | link | official | code |
|2024| | Samba | Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling | link | official | code |
|2024 | ICML | Mamba2 | Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality | link | official | code |
|2024 | NeurIPS |GSA | Gated Slot Attention for Efficient Linear-Time Sequence Modeling | link | official | code |
|2024 | | Gated DeltaNet | Gated Delta Networks: Improving Mamba2 with Delta Rule | link | official | code |
Installation
The following requirements should be satisfied:
- PyTorch >= 2.0
- Triton >= 2.2
- einops
As `fla` is under active development, no packaged releases are provided at this time.
If you do need to use `fla` ops/modules and contemplate further exploration, you can instead install the package from source:
```sh
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention
```
or manage `fla` with submodules:
```sh
git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla
```
> [!CAUTION]
> If you're not working with Triton v2.2 or its nightly release, be aware of potential issues with the `FusedChunk` implementation, detailed in this issue. You can run the test `python tests/test_fused_chunk.py` to check whether your version is affected by similar compiler problems. While we offer some fixes for Triton <= 2.1, be aware that these may result in reduced performance.
>
> For both Triton 2.2 and earlier versions (up to 2.1), you can reliably use the `Chunk` version (with hidden states materialized into HBM). After careful optimization, this version generally delivers high performance in most scenarios.
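To make the chunkwise strategy concrete, here is a minimal NumPy sketch of the recurrence that chunk-style kernels implement (this is not the library's Triton code; the function names are ours): attention is computed exactly within each chunk, while a small state matrix carries the sum of key-value outer products from all earlier chunks.

```python
import numpy as np

def linear_attn_naive(q, k, v):
    # Causal linear attention (no softmax), O(T^2): o_t = sum_{s<=t} (q_t . k_s) v_s
    T = q.shape[0]
    mask = np.tril(np.ones((T, T)))
    return (q @ k.T * mask) @ v

def linear_attn_chunked(q, k, v, chunk_size=4):
    # Chunkwise form: a (d_k x d_v) state S accumulates k v^T outer products
    # from earlier chunks (the hidden state "materialized into HBM" above),
    # while attention inside each chunk is computed directly.
    T = q.shape[0]
    S = np.zeros((q.shape[1], v.shape[1]))
    out = np.empty_like(v)
    for i in range(0, T, chunk_size):
        qc, kc, vc = q[i:i + chunk_size], k[i:i + chunk_size], v[i:i + chunk_size]
        n = len(qc)
        intra = (qc @ kc.T * np.tril(np.ones((n, n)))) @ vc  # within-chunk, causal
        out[i:i + n] = qc @ S + intra                        # inter-chunk + intra-chunk
        S += kc.T @ vc                                       # roll state forward
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
o_naive = linear_attn_naive(q, k, v)
o_chunk = linear_attn_chunked(q, k, v)
print(np.allclose(o_naive, o_chunk))  # True
```

The chunked form trades the O(T^2) score matrix for O(T/C) small matmuls plus a constant-size state, which is what makes these kernels subquadratic in sequence length.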
Usage
Token Mixing
We provide "token mixing" linear attention layers in `fla.layers` for you to use.
You can replace the standard multihead attention layer in your model with other linear attention layers.
Example usage is as follows:
```py
>>> import torch
>>> from fla.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> retnet
MultiScaleRetention(
  (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
  (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
  (v_proj): Linear(in_features=1024, out_features=2048, bias=False)
  (g_proj): Linear(in_features=1024, out_features=2048, bias=False)
  (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
  (g_norm_swish_gate): FusedRMSNormSwishGate(512, eps=1e-05)
  (rotary): RotaryEmbedding()
)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])
```
We also provide implementations of models that are compatible with the 🤗 Transformers library.
Here's an example of how to initialize a GLA model from the default configs in `fla`:
```py
>>> from fla.models import GLAConfig
>>> from transformers import AutoModelForCausalLM
>>> config = GLAConfig()
>>> config
GLAConfig {
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "conv_size": 4,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "feature_map": null,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "max_position_embeddings": 2048,
  "model_type": "gla",
  "norm_eps": 1e-06,
  "num_heads": 4,
  "num_hidden_layers": 24,
  "num_kv_heads": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.45.0",
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_output_gate": true,
  "use_short_conv": false,
  "vocab_size": 32000
}

>>> AutoModelForCausalLM.from_config(config)
GLAForCausalLM(
  (model): GLAModel(
    (embeddings): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-23): 24 x GLABlock(
        (attn_norm): RMSNorm(2048, eps=1e-06)
        (attn): GatedLinearAttention(
          (q_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (g_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (gk_proj): Sequential(
            (0): Linear(in_features=2048, out_features=16, bias=False)
            (1): Linear(in_features=16, out_features=1024, bias=True)
          )
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (g_norm_swish_gate): FusedRMSNormSwishGate(512, eps=1e-06)
        )
        (mlp_norm): RMSNorm(2048, eps=1e-06)
        (mlp): GLAMLP(
          (gate_proj): Linear(in_features=2048, out_features=11264, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
      )
    )
    (norm): RMSNorm(2048, eps=1e-06)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
```
Fused Modules
We offer a collection of fused modules in fla.modules to facilitate faster training:
- Rotary Embedding: rotary positional embeddings as adopted by the Llama architecture, a.k.a. Transformer++.
- Norm Layers: `RMSNorm`, `LayerNorm` and `GroupNorm`.
- Norm Layers Fused with Linear Layers: `RMSNormLinear`, `LayerNormLinear` and `GroupNormLinear`, which reduce memory usage of intermediate tensors for improved memory efficiency.
- Norm Layers with Gating: norm layers combined with element-wise gating, as used by RetNet/GLA.
- Cross Entropy: faster Triton implementation of cross-entropy loss.
- Linear Cross Entropy: fused linear layer and cross-entropy loss, avoiding the materialization of large logits tensors. Also refer to the implementations by mgmalek and Liger-Kernel.
- Linear KL Divergence: fused linear layer and KL-divergence loss, in a similar vein to the CE loss.
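To see why fusing the output projection with the loss matters, a quick back-of-envelope for the logits tensor that the fused kernel avoids materializing (the sizes here are illustrative, not taken from any `fla` config):

```python
# Memory of the full [batch, seq_len, vocab] logits tensor that a fused
# linear + cross-entropy kernel never writes to HBM.
batch, seq_len, vocab = 8, 2048, 32000  # illustrative sizes
bytes_per_elem = 2                       # bfloat16
logits_bytes = batch * seq_len * vocab * bytes_per_elem
print(f"{logits_bytes / 2**30:.2f} GiB")  # 0.98 GiB for the logits alone
```

The same tensor is needed again for the backward pass, so skipping it roughly halves the activation-memory spike at the language-model head.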
Generation
Once a model is pretrained, it can generate text through the 🤗 text generation APIs. A generation example:
```py
import fla
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'fla-hub/gla-1.3B-100B'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda()
input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_length=64)
tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
```
We also provide a simple script here for benchmarking the generation speed. Simply run it by:
```sh
$ python -m benchmarks.benchmark_generation \
    --path 'fla-hub/gla-1.3B-100B' \
    --repetition_penalty 2. \
    --prompt="Hello everyone, I'm Songlin Yang"

Prompt: Hello everyone, I'm Songlin Yang
Generated: Hello everyone, I'm Songlin Yang. I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have

Prompt length: 10, generation length: 64
Total prompt processing + decoding time: 4593ms
```
All of the pretrained models currently available can be found in fla-hub.
```py
from huggingface_hub import list_models

for model in list_models(author='fla-hub'):
    print(model.id)
```
Hybrid Models
`fla` provides a flexible method for incorporating standard attention layers into existing linear attention models.
This is done by specifying the `attn` argument in the model configuration.
For example, to create a 2-layer Samba model with interleaved Mamba and local attention layers, using a sliding window size of 2048:
```py
>>> from fla.models import SambaConfig
>>> from transformers import AutoModelForCausalLM
>>> config = SambaConfig(num_hidden_layers=2)
>>> config.attn = {
...     'layers': [1],
...     'num_heads': 18,
...     'num_kv_heads': 18,
...     'window_size': 2048
... }
>>> config
SambaConfig {
  "attn": {
    "layers": [1],
    "num_heads": 18,
    "num_kv_heads": 18,
    "window_size": 2048
  },
  "bos_token_id": 1,
  "conv_kernel": 4,
  "eos_token_id": 2,
  "expand": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "silu",
  "hidden_ratio": 4,
  "hidden_size": 2304,
  "initializer_range": 0.02,
  "intermediate_size": 4608,
  "max_position_embeddings": 2048,
  "model_type": "samba",
  "norm_eps": 1e-05,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": false,
  "state_size": 16,
  "tie_word_embeddings": false,
  "time_step_floor": 0.0001,
  "time_step_init_scheme": "random",
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 144,
  "time_step_scale": 1.0,
  "transformers_version": "4.45.0",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "vocab_size": 32000
}

>>> AutoModelForCausalLM.from_config(config)
SambaForCausalLM(
  (backbone): SambaModel(
    (embeddings): Embedding(32000, 2304)
    (layers): ModuleList(
      (0): SambaBlock(
        (mixer_norm): RMSNorm(2304, eps=1e-05)
        (mixer): MambaMixer(
          (conv1d): Conv1d(4608, 4608, kernel_size=(4,), stride=(1,), padding=(3,), groups=4608)
          (act): SiLU()
          (in_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (x_proj): Linear(in_features=4608, out_features=176, bias=False)
          (dt_proj): Linear(in_features=144, out_features=4608, bias=True)
          (out_proj): Linear(in_features=4608, out_features=2304, bias=False)
        )
        (mlp_norm): RMSNorm(2304, eps=1e-05)
        (mlp): SambaMLP(
          (gate_proj): Linear(in_features=2304, out_features=12288, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2304, bias=False)
          (act_fn): SiLU()
        )
      )
      (1): SambaBlock(
        (mixer_norm): RMSNorm(2304, eps=1e-05)
        (mixer): Attention(
          (q_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (k_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (v_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (o_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (rotary): RotaryEmbedding()
        )
        (mlp_norm): RMSNorm(2304, eps=1e-05)
        (mlp): SambaMLP(
          (gate_proj): Linear(in_features=2304, out_features=12288, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2304, bias=False)
          (act_fn): SiLU()
        )
      )
    )
    (norm_f): RMSNorm(2304, eps=1e-05)
  )
  (lm_head): Linear(in_features=2304, out_features=32000, bias=False)
)
```
During inference, you DO NOT need to revise anything for generation! The model will produce output as-is, without any need for additional configurations or modifications.
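The `attn` entry only lists the indices of the layers that should use standard attention; every other layer keeps its linear/SSM mixer. The selection logic can be sketched in plain Python (a hypothetical helper for illustration, not the library's code):

```python
def mixer_types(num_hidden_layers, attn):
    """Map each layer index to its mixer type, mirroring how an `attn`
    config entry like {'layers': [1], ...} marks which layers use
    standard (softmax) attention. Hypothetical helper, for illustration."""
    attn_layers = set(attn['layers']) if attn else set()
    return ['attention' if i in attn_layers else 'linear/ssm'
            for i in range(num_hidden_layers)]

# The 2-layer Samba example above: layer 1 becomes sliding-window attention.
print(mixer_types(2, {'layers': [1], 'num_heads': 18, 'window_size': 2048}))
# ['linear/ssm', 'attention']
```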
Evaluations
The lm-evaluation-harness library allows you to easily perform (zero-shot) model evaluations. Follow the steps below to use this library:
- Install `lm_eval` following their instructions.
- Run evaluation with:
```sh
$ PATH='fla-hub/gla-1.3B-100B'
$ python -m evals.harness --model hf \
    --model_args pretrained=$PATH,dtype=bfloat16 \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config
```
We've made `fla` compatible with hf-style evaluations; you can call `evals.harness` to run them.
Running the command above will produce the task results reported in the GLA paper.
> [!TIP]
> If you are using `lm-evaluation-harness` as an external library and can't find (almost) any tasks available, then before calling `lm_eval.evaluate()` or `lm_eval.simple_evaluate()`, simply run the following to load the library's stock tasks!
> ```py
> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()
> ```
Benchmarks
We compared our Triton-based RetNet implementation with CUDA-based FlashAttention2, using a batch size of 8, 32 heads, and a head dimension of 128, across different sequence lengths. These tests were conducted on a single A100 80GB GPU, as shown below.
```sh
# you might have to first install fla to enable its import via `pip install -e .`
$ python benchmark_retention.py
Performance:
   seq_len  fused_chunk_fwd  chunk_fwd  parallel_fwd  fused_chunk_fwdbwd  chunk_fwdbwd  parallel_fwdbwd  flash_fwd  flash_fwdbwd
0    128.0         0.093184   0.185344      0.067584            1.009664      1.591296         1.044480   0.041984      0.282624
1    256.0         0.165888   0.219136      0.126976            1.024000      1.596928         1.073152   0.074752      0.413696
2    512.0         0.308224   0.397312      0.265216            1.550336      1.603584         1.301504   0.156672      0.883712
3   1024.0         0.603136   0.747520      0.706560            3.044864      3.089408         3.529728   0.467968      2.342912
4   2048.0         1.191424   1.403904      2.141184            6.010880      6.059008        11.009024   1.612800      7.135232
5   4096.0         2.377728   2.755072      7.392256           11.932672     11.938816        37.792770   5.997568     24.435200
6   8192.0         4.750336   5.491712     26.402817           23.759359     23.952385       141.014023  22.682114     90.619904
7  16384.0         9.591296  10.870784    101.262337           47.666176     48.745472       539.853821  91.346947    346.318848
```
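Reading the table: FlashAttention2 wins at short sequences, but its quadratic cost catches up quickly, while the fused chunk kernel scales roughly linearly. The crossover is visible from the forward-pass numbers alone:

```python
# Forward-pass times (ms) taken from the benchmark table above.
fused_chunk_fwd = {128: 0.093184, 16384: 9.591296}
flash_fwd = {128: 0.041984, 16384: 91.346947}

# Speedup of fused_chunk over FlashAttention2 at each length.
speedup = {t: flash_fwd[t] / fused_chunk_fwd[t] for t in fused_chunk_fwd}
print(speedup)  # ~2.2x slower at 128 tokens, ~9.5x faster at 16384 tokens
```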
Citation
If you find this repo useful, please consider citing our works:
```bib
@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}

@misc{yang2024gated,
  title         = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
  author        = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
  year          = {2024},
  eprint        = {2412.06464},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL}
}

@inproceedings{yang2024parallelizing,
  title     = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length},
  author    = {Yang, Songlin and Wang, Bailin and Zhang, Yu and Shen, Yikang and Kim, Yoon},
  booktitle = {Proceedings of NeurIPS},
  year      = {2024}
}

@inproceedings{zhang2024gsa,
  title     = {Gated Slot Attention for Efficient Linear-Time Sequence Modeling},
  author    = {Zhang, Yu and Yang, Songlin and Zhu, Ruijie and Zhang, Yue and Cui, Leyang and Wang, Yiqiao and Wang, Bolun and Shi, Freda and Wang, Bailin and Bi, Wei and Zhou, Peng and Fu, Guohong},
  booktitle = {Proceedings of NeurIPS},
  year      = {2024}
}

@inproceedings{yang2024gla,
  title     = {Gated Linear Attention Transformers with Hardware-Efficient Training},
  author    = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
  booktitle = {Proceedings of ICML},
  year      = {2024}
}
```
Owner
- Name: Songlin Yang
- Login: sustcsonglin
- Kind: user
- Location: Cambridge
- Company: MIT
- Website: https://sustcsonglin.github.io/
- Twitter: SonglinYang4
- Repositories: 63
- Profile: https://github.com/sustcsonglin
ML & NLP Research. PhD student @ MIT CSAIL
Citation (CITATION.cff)
```yaml
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Yang"
    given-names: "Songlin"
    orcid: "https://orcid.org/0000-0002-5944-0110"
  - family-names: "Zhang"
    given-names: "Yu"
    orcid: "https://orcid.org/0000-0002-8345-3835"
title: "FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism"
version: 0.0.1
date-released: 2024-01-18
url: "https://github.com/sustcsonglin/flash-linear-attention"
```
GitHub Events
Total
- Watch event: 9
- Push event: 6
- Fork event: 1
- Create event: 2
Last Year
- Watch event: 9
- Push event: 6
- Fork event: 1
- Create event: 2
Dependencies
- actions/stale v9.0.0 composite
- actions/checkout v2 composite
- actions/setup-python v2 composite
- actions/checkout v2 composite
- actions/setup-python v2 composite
- datasets *
- einops *
- ninja *
- transformers *
- triton >=2.2