learned-flash-attention
An exploratory project on learning the softmax statistics used in FlashAttention from training data.
Science Score: 54.0%
This score indicates how likely this project is to be science-related, based on the following indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (8.8%) to scientific vocabulary
Repository
An exploratory project on learning the softmax statistics used in FlashAttention from training data.
Basic Info
- Host: GitHub
- Owner: lnairGT
- License: mit
- Language: Python
- Default Branch: main
- Size: 449 KB
Statistics
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Can FlashAttention Statistics be Learned?
This is a small project that tries to learn the statistics ($m(x)$ and $l(x)$) computed in FlashAttention. The basic idea is as follows:

$softmax([A_1, A_2, \ldots, A_N]) = [\alpha_1 \cdot softmax(A_1), \alpha_2 \cdot softmax(A_2), \ldots, \alpha_N \cdot softmax(A_N)]$

Here, $\alpha_i \in \mathbb{R}^{S}$, where $S$ is the sequence length. This allows the softmax operation to be performed in blocks, enabling sequence parallelism -- i.e., the operation $softmax(QK^T)$ can be computed over independent blocks (and hence processed in parallel). The original FlashAttention paper computes the statistics over each iteration of a for loop, as each block of Q and K is loaded. In this project, I try to learn these statistics (or scales) $\alpha_1 \ldots \alpha_N$ instead.
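The decomposition above can be checked numerically. Below is a minimal NumPy sketch (not code from this repo) that computes the per-block statistics $m_i$ and $l_i$, derives the correction scales $\alpha_i$, and recombines the block softmaxes into the exact full softmax:

```python
import numpy as np

def blockwise_softmax(blocks):
    """Recombine per-block softmaxes into the full softmax using
    FlashAttention-style statistics: m_i (block max) and l_i (block exp-sum)."""
    ms = [b.max() for b in blocks]                          # m_i per block
    ls = [np.exp(b - m).sum() for b, m in zip(blocks, ms)]  # l_i per block
    m = max(ms)                                             # global max
    # Global normalizer: sum of block exp-sums, rescaled to the global max
    l = sum(l_i * np.exp(m_i - m) for l_i, m_i in zip(ls, ms))
    out = []
    for b, m_i, l_i in zip(blocks, ms, ls):
        alpha_i = np.exp(m_i - m) * l_i / l   # correction scale for block i
        out.append(alpha_i * (np.exp(b - m_i) / l_i))  # alpha_i * softmax(b)
    return np.concatenate(out)

rng = np.random.default_rng(0)
x = rng.standard_normal(8)
full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()  # reference softmax
print(np.allclose(blockwise_softmax(np.split(x, 2)), full))
```

Here each $\alpha_i = e^{m_i - m}\, l_i / l$ is a scalar per block (applied row-wise to each attention row in the full attention case); the project learns these scales from data rather than computing them online.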
NOTE: During training, all layers of the model are frozen and only the scales are learned.
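A minimal PyTorch sketch of that training setup (illustrative only; the names and model here are hypothetical, not the repo's actual code): freeze every pretrained parameter and hand only the scale parameters to the optimizer.

```python
import torch
import torch.nn as nn

# Illustrative sketch: freeze a (stand-in) pretrained model and learn only
# the per-block softmax scales alpha_1 ... alpha_N.
model = nn.TransformerEncoderLayer(d_model=16, nhead=2)
for p in model.parameters():
    p.requires_grad = False  # all model layers are frozen

num_softmax_blocks = 2
seq_len = 8  # hypothetical sequence length for this sketch
# One learnable scale vector per softmax block, initialized to 1 (identity)
scales = nn.Parameter(torch.ones(num_softmax_blocks, seq_len))

# Only the scales receive gradients, mirroring --scales_lr 1e-2 below
optimizer = torch.optim.Adam([scales], lr=1e-2)
```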
Running the code
Requirements
This repo requires PyTorch and HuggingFace Transformers version 4.27.4. The code has not been tested on more recent versions of HF Transformers.
The scripts are adapted from HuggingFace. A sample script for running OPT-350M, dividing the softmax into two blocks, is shown in run_opt_train.sh.
python run_clm_no_trainer.py \
--model_name_or_path lnair/opt-350m-wikitext2 \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--output_dir opt-350m-softmax-scales \
--checkpointing_steps epoch \
--num_train_epochs 30 \
--scales_lr 1e-2 \
--num_softmax_blocks 2
Results on OPT-Models
The following PPL values are obtained on Wikitext-2 with the learned scales, using these arguments:
--num_softmax_blocks 2
--scales_lr 1e-2
--num_train_epochs 50
| Model | PPL (FP32) | PPL (Ours) |
| -------- | ---------- | ---------- |
| OPT-125M | 26 | 22.36 |
| OPT-350M | 16.84 | 17.81 |
| OPT-1.3B | 11.78 | 12.63 |
| OPT-2.7B | 11 | 12.56 |
Visualizations
Comparisons of generated attention maps and spikiness (for a reference on what "spikiness" means, please see this paper on softmax mimicry). The plots titled "no trained statistics" refer to cases where the softmax is broken into blocks without applying any correction scales. When the learned scales are applied, the softmax outputs match the ground-truth baseline.

Owner
- Name: Lakshmi Nair
- Login: lnairGT
- Kind: user
- Company: Georgia Institute of Technology
- Website: https://scholar.google.com/citations?user=eTGOo_cAAAAJ&hl=en
- Repositories: 2
- Profile: https://github.com/lnairGT
PhD Robotics student
Citation (citation.cff)
cff-version: 1.2.0
message: "If you find this work helpful, please consider citing it as below."
authors:
- family-names: Nair
given-names: Lakshmi
title: "Can FlashAttention Statistics be Learned?"
version: 2.0.4
date-released: 2024-04-04
url: "https://github.com/lnairGT/Learned-Flash-Attention"