mlsae

Multi-Layer Sparse Autoencoders (ICLR 2025)

https://github.com/tim-lawson/mlsae

Science Score: 41.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (7.9%) to scientific vocabulary

Keywords

mechanistic-interpretability sae sparse-autoencoder transformer
Last synced: 6 months ago

Repository

Multi-Layer Sparse Autoencoders (ICLR 2025)

Basic Info
Statistics
  • Stars: 17
  • Watchers: 2
  • Forks: 0
  • Open Issues: 6
  • Releases: 0
Topics
mechanistic-interpretability sae sparse-autoencoder transformer
Created over 1 year ago · Last pushed about 1 year ago
Metadata Files
Readme License Citation

README.md

Multi-Layer Sparse Autoencoders (MLSAE)

[!NOTE] This repository accompanies the preprint Residual Stream Analysis with Multi-Layer SAEs (https://arxiv.org/abs/2409.04185). See References for related work.

Pretrained MLSAEs

We define two types of model: plain PyTorch MLSAE modules, which are relatively small; and PyTorch Lightning MLSAETransformer modules, which include the underlying transformer. HuggingFace collections for both are here:

We assume that pretrained MLSAEs have repo_ids with this naming convention:

  • tim-lawson/mlsae-pythia-70m-deduped-x{expansion_factor}-k{k}
  • tim-lawson/mlsae-pythia-70m-deduped-x{expansion_factor}-k{k}-tfm
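For scripting, the naming convention above can be captured in a small helper (a sketch; the function name is ours, not part of the repo):

```python
def mlsae_repo_id(expansion_factor: int, k: int, with_transformer: bool = False) -> str:
    """Build a HuggingFace repo_id following the naming convention above."""
    base = f"tim-lawson/mlsae-pythia-70m-deduped-x{expansion_factor}-k{k}"
    # The "-tfm" suffix denotes the Lightning MLSAETransformer variant.
    return f"{base}-tfm" if with_transformer else base
```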

The Weights & Biases project for the paper is here.

Installation

Install Python dependencies with Poetry:

```bash
poetry env use 3.12
poetry install
```

Install Python dependencies with pip:

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

Install Node.js dependencies:

```bash
cd app
npm install
```

Training

Train a single MLSAE:

```bash
python train.py --help
python train.py --model_name EleutherAI/pythia-70m-deduped --expansion_factor 64 -k 32
```

Analysis

Test a single pretrained MLSAE:

[!WARNING] We assume that the test split of monology/pile-uncopyrighted is already downloaded and stored in data/test.jsonl.zst.
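One way to fetch the file is via `huggingface_hub` (a sketch; we assume the dataset repo stores the split as `test.jsonl.zst` at its root, and the helper name is ours):

```python
from pathlib import Path

def ensure_pile_test_split(data_dir: str = "data") -> Path:
    """Download the Pile test split into data_dir unless it already exists."""
    target = Path(data_dir) / "test.jsonl.zst"
    if not target.exists():
        from huggingface_hub import hf_hub_download
        hf_hub_download(
            repo_id="monology/pile-uncopyrighted",
            filename="test.jsonl.zst",
            repo_type="dataset",
            local_dir=data_dir,
        )
    return target
```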

```bash
python test.py --help
python test.py --model_name EleutherAI/pythia-70m-deduped --expansion_factor 64 -k 32
```

Compute the distributions of latent activations over layers for a single pretrained MLSAE (HuggingFace datasets):

```bash
python -m mlsae.analysis.dists --help
python -m mlsae.analysis.dists --repo_id tim-lawson/mlsae-pythia-70m-deduped-x64-k32-tfm --max_tokens 100_000_000
```
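Conceptually, the per-layer distribution of a latent's activations can be sketched like this (our illustration, not the repo's implementation):

```python
import torch

def latent_layer_dist(acts: torch.Tensor) -> torch.Tensor:
    """acts: [n_layers, n_tokens, n_latents] latent activations.
    Returns [n_layers, n_latents]: each latent's share of activation
    mass at each layer (columns sum to ~1)."""
    mass = acts.clamp(min=0).sum(dim=1)                    # mass per layer
    total = mass.sum(dim=0, keepdim=True).clamp(min=1e-9)  # avoid divide-by-zero
    return mass / total
```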

Compute the maximally activating examples for each combination of latent and layer for a single pretrained MLSAE (HuggingFace datasets):

```bash
python -m mlsae.analysis.examples --help
python -m mlsae.analysis.examples --repo_id tim-lawson/mlsae-pythia-70m-deduped-x64-k32-tfm --max_tokens 1_000_000
```
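The core selection step amounts to a top-k over the token dimension, independently per (layer, latent) pair (a sketch under our own tensor layout, not the repo's code):

```python
import torch

def max_activating_tokens(acts: torch.Tensor, n: int = 4) -> torch.Tensor:
    """acts: [n_tokens, n_layers, n_latents]. Returns the indices of the n
    tokens that maximally activate each (layer, latent) combination."""
    return acts.topk(n, dim=0).indices  # [n, n_layers, n_latents]
```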

Interactive visualizations

Run the interactive web application for a single pretrained MLSAE:

```bash
python -m mlsae.api --help
python -m mlsae.api --repo_id tim-lawson/mlsae-pythia-70m-deduped-x64-k32-tfm
```

```bash
cd app
npm run dev
```

Navigate to http://localhost:3000, enter a prompt, and click 'Submit'.

Alternatively, navigate to http://localhost:3000/prompt/foobar.

Figures

Compute the mean cosine similarities between residual stream activation vectors at adjacent layers of a single pretrained transformer:

```bash
python figures/resid_cos_sim.py --help
python figures/resid_cos_sim.py --model_name EleutherAI/pythia-70m-deduped
```
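The quantity being computed can be sketched in a few lines (our illustration of the metric, not the script's implementation):

```python
import torch

def adjacent_layer_cos_sim(resid: torch.Tensor) -> torch.Tensor:
    """resid: [n_layers, n_tokens, d_model] residual-stream activations.
    Returns [n_layers - 1]: mean cosine similarity between each pair of
    adjacent layers, averaged over tokens."""
    sims = torch.nn.functional.cosine_similarity(resid[:-1], resid[1:], dim=-1)
    return sims.mean(dim=-1)
```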

Save heatmaps of the distributions of latent activations over layers for multiple pretrained MLSAEs:

```bash
python figures/dists_heatmaps.py --help
python figures/dists_heatmaps.py --expansion_factor 32 64 128 -k 16 32 64
```

Save a CSV of the mean standard deviations of the distributions of latent activations over layers for multiple pretrained MLSAEs:

```bash
python figures/dists_layer_std.py --help
python figures/dists_layer_std.py --expansion_factor 32 64 128 -k 16 32 64
```

Save heatmaps of the maximum latent activations for a given prompt and multiple pretrained MLSAEs:

```bash
python figures/prompt_heatmaps.py --help
python figures/prompt_heatmaps.py --expansion_factor 32 64 128 -k 16 32 64
```

Save a CSV of the Mean Max Cosine Similarity (MMCS) for multiple pretrained MLSAEs:

```bash
python figures/mmcs.py --help
python figures/mmcs.py --expansion_factor 32 64 128 -k 16 32 64
```
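MMCS has a standard definition in the SAE literature: for each decoder direction in one dictionary, take its maximum cosine similarity with any direction in the other dictionary, then average. A minimal sketch (not necessarily the script's exact implementation):

```python
import torch

def mmcs(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """a, b: [n_latents, d_model] decoder weight matrices.
    For each row of a, take its max cosine similarity with any row of b,
    then average over rows of a."""
    a = torch.nn.functional.normalize(a, dim=-1)
    b = torch.nn.functional.normalize(b, dim=-1)
    return (a @ b.T).max(dim=1).values.mean()
```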

References

Code

Papers

Owner

  • Name: Tim Lawson
  • Login: tim-lawson
  • Kind: user
  • Location: Bristol, UK
  • Company: University of Bristol

AI PhD student at the University of Bristol. Previously Physics at Cambridge and software at Graphcore. Language, cognition, etc.

Citation (citation.bib)

@misc{lawson_residual_2024,
  title         = {Residual {Stream Analysis} with {Multi-Layer SAEs}},
  author        = {Lawson, Tim and Farnik, Lucy and Houghton, Conor and Aitchison, Laurence},
  year          = {2024},
  month         = oct,
  number        = {arXiv:2409.04185},
  eprint        = {2409.04185},
  primaryclass  = {cs},
  publisher     = {arXiv},
  doi           = {10.48550/arXiv.2409.04185},
  urldate       = {2024-10-08},
  archiveprefix = {arXiv}
}

GitHub Events

Total
  • Issues event: 6
  • Watch event: 14
  • Delete event: 1
  • Push event: 20
  • Pull request event: 4
  • Create event: 1
Last Year
  • Issues event: 6
  • Watch event: 14
  • Delete event: 1
  • Push event: 20
  • Pull request event: 4
  • Create event: 1

Dependencies

app/package-lock.json npm
  • 448 dependencies
app/package.json npm
  • @types/node ^22.5.1 development
  • eslint-import-resolver-typescript ^3.6.3 development
  • eslint-plugin-import ^2.29.1 development
  • knip ^5.29.1 development
  • prettier ^3.3.3 development
  • typescript ^5.5.4 development
  • @hookform/resolvers ^3.9.0
  • @radix-ui/react-label ^2.1.0
  • @radix-ui/react-slot ^1.1.0
  • @radix-ui/react-tabs ^1.1.0
  • @types/react ^18.3.3
  • @types/react-dom ^18.3.0
  • autoprefixer ^10.4.19
  • class-variance-authority ^0.7.0
  • clsx ^2.1.1
  • d3-scale ^4.0.2
  • eslint ^8.57.0
  • eslint-config-next ^14.2.3
  • eslint-config-prettier ^9.1.0
  • next ^14.2.3
  • postcss ^8.4.38
  • react ^18.3.1
  • react-dom ^18.3.1
  • react-hook-form ^7.52.1
  • recharts ^2.12.7
  • swr ^2.2.5
  • tailwind-merge ^2.3.0
  • tailwindcss ^3.4.4
  • usehooks-ts ^3.1.0
  • zod ^3.23.8
poetry.lock pypi
  • 122 dependencies
pyproject.toml pypi
  • ruff ^0.5.2 develop
  • datasets ^2.20.0
  • einops ^0.8.0
  • fastapi ^0.111.1
  • huggingface-hub ^0.23.5
  • jaxtyping ^0.2.33
  • lightning ^2.3.3
  • loguru ^0.7.2
  • matplotlib ^3.9.1
  • orjson ^3.10.6
  • pydantic ^2.8.2
  • pytest ^8.3.2
  • python ^3.12
  • simple-parsing ^0.1.5
  • torch ^2.4.0
  • transformers ^4.42.4
  • triton ^3.0.0
  • uvicorn ^0.30.3
  • wandb ^0.17.4
  • zstandard ^0.23.0
requirements.txt pypi
  • datasets >=2.19.0
  • einops *
  • fastapi *
  • huggingface-hub *
  • jaxtyping *
  • lightning *
  • loguru *
  • matplotlib *
  • orjson *
  • pydantic *
  • pytest *
  • simple-parsing *
  • torch *
  • transformers *
  • triton *
  • uvicorn *
  • wandb *
  • zstandard *