arrakis-mi

Arrakis is a library to conduct, track and visualize mechanistic interpretability experiments.

https://github.com/yash-srivastava19/arrakis

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.2%) to scientific vocabulary

Keywords

anthropic explainable-ai garcon mechanistic-interpretability transformer transformerlens
Last synced: 6 months ago

Repository

Arrakis is a library to conduct, track and visualize mechanistic interpretability experiments.

Basic Info
Statistics
  • Stars: 31
  • Watchers: 1
  • Forks: 3
  • Open Issues: 6
  • Releases: 0
Topics
anthropic explainable-ai garcon mechanistic-interpretability transformer transformerlens
Created over 1 year ago · Last pushed 10 months ago
Metadata Files
Readme Changelog Contributing Citation

README.md

Arrakis - A Mechanistic Interpretability Tool

Interpretability is a relatively new field in which something new happens every day. Mechanistic Interpretability is one approach to reverse-engineering neural networks and understanding what happens inside these black-box models.

Mechanistic Interpretability is a really exciting subfield of alignment, and recently, a lot has been happening in this field - especially at Anthropic. To look at the goal of MI at Anthropic, read this post. The core operation involved in MI is loading a model, looking at its weights and activations, doing some operations on them, and producing results.

I made Arrakis to deeply understand Transformer-based models (in the future I may try to make it model-agnostic). The first question that should come to mind is: why not use Transformer Lens? Neel Nanda has already made significant progress there. I made Arrakis because I wanted a library that can do more than just get the activations - a more complete library where researchers can run experiments and track their progress. Think of Arrakis as a complete suite for conducting MI experiments, one that tries to get the best of both Transformer Lens and Garcon. More features will be added as I learn how to make this library more useful for the community, and I need feedback for that.

Tools and Decomposability

Regardless of what research project you are working on, if you do not keep track of things, it gets messy really quickly. In a field like MI, where you are constantly looking at many different weights and biases, with a lot of moving parts, it gets overwhelming fairly easily. I've experienced this personally, and being someone obsessed with reducing experimentation time and getting results quickly, I wanted a complete suite that makes my workload easier.

Arrakis is made so that this doesn't happen. The core principle behind Arrakis is decomposability: do all experiments with plug-and-play tools (this will become much clearer in the walkthrough). This makes experimentation really flexible, and at the same time, Arrakis keeps track of different versions of the experiments by default. Everything in Arrakis is built in this plug-and-play fashion. I have even incorporated a graphing library (on top of several popular libraries) to make graphing a lot easier.

I really want feedback and contributions on this project so that this can be adapted by the community at large.

Arrakis Walkthrough

Let's walk through how to conduct a small experiment in Arrakis. It is simple, reproducible, and easy to implement.

Step 1: Install the package

All the dependencies of the project are managed through poetry. Install the package with pip:

```
pip install arrakis-mi
```

Step 2: Create HookedAutoModel

HookedAutoModel offers a convenient way to import models from Huggingface directly (with hooks). Everything just works out of the box. First, create a HookedAutoConfig for the model you want to support with the required parameters. Then, create a HookedAutoModel from the config. As of now, these models are supported:

python [ "gpt2", "gpt-neo", "gpt-neox", "llama", "gemma", "phi3", "qwen2", "mistral", "stable-lm", ]

As mentioned, the core idea behind Arrakis is decomposability, so a HookedAutoModel is a wrapper around the Huggingface `PreTrainedModel` class, with a single plug-and-play decorator for the forward pass. All the model probing happens behind the scenes, and is pre-configured.

```python
from arrakis.src.core_arrakis.activation_cache import *

config = HookedAutoConfig(name="llama", vocab_size=50256, hidden_size=8, intermediate_size=2,
                          num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=4)

model = HookedAutoModel(config)
```
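The "hooked" part of the model boils down to capturing intermediate outputs as the forward pass runs. Here is a minimal, framework-free sketch of that idea; the names `hooked`, `activation_cache`, and `layer0_forward` are illustrative stand-ins, not Arrakis's actual API:

```python
# Minimal sketch of hooking a forward pass: a decorator that records
# intermediate outputs into an activation cache. Illustrative names only.
activation_cache = {}

def hooked(layer_name):
    """Decorator: stash the layer's output in the cache, then pass it through unchanged."""
    def wrapper(forward_fn):
        def inner(*args, **kwargs):
            out = forward_fn(*args, **kwargs)
            activation_cache[layer_name] = out  # record the activation
            return out                          # forward pass is unaffected
        return inner
    return wrapper

@hooked("layer.0")
def layer0_forward(x):
    return [v * 2 for v in x]  # stand-in for a real layer's computation

layer0_forward([1, 2, 3])
print(activation_cache["layer.0"])  # [2, 4, 6]
```

A real hooked model applies this same pattern to every submodule of the network, which is why the activations are available "for free" after a single forward call.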

Step 3: Set up Interpretability Bench

At its core, the whole purpose of Arrakis is to conduct MI experiments. After installing, derive from the BaseInterpretabilityBench and instantiate an object (exp in this case). This object provides a lot of functionality out of the box based on the "tool" you want to use for the experiment, giving you access to the functions that the tool provides. You can also create your own tool (read about that here).

```python
from arrakis.src.core_arrakis.base_bench import BaseInterpretabilityBench

class MIExperiment(BaseInterpretabilityBench):
    def __init__(self, model, save_dir="experiments"):
        super().__init__(model, save_dir)
        self.tools.update({"custom": CustomFunction(model)})

exp = MIExperiment(model)
```

Apart from access to MI tools, the object also provides a convenient way to log your experiments. To log an experiment, just decorate the function you are working with using @exp.log_experiment, and that is pretty much it. The decorator creates local version control over the contents of the function and stores the versions locally. You can run many things in parallel, and the version control helps you keep track of them.

```python
# Step 1: Create a function where you can do operations on the model.

@exp.log_experiment   # This is pretty much it. This will log the experiment.
def attention_experiment():
    print("This is a placeholder for the experiment. Use as is.")
    return 4

# Step 2: Then, run the function to get results. This starts the experiment.
attention_experiment()

# Step 3: Then, look at some of the things that the logs keep track of.
l = exp.list_versions("attention_experiment")  # This gives the hash of the content of the experiment.
print("This is the version hash of the experiment: ", l)

# Step 4: You can also get the content of the experiment from the saved JSON.
print(exp.get_version("attention_experiment", l[0])['source'])  # This gives the content of the experiment.
```

Apart from these tools, there are also `@exp.profile_model` (to profile how many resources the model is using) and `@exp.test_hypothesis` (to test hypotheses). Support for more tools will be added as I get more feedback from the community.

Step 4: Create your experiments

By default, Arrakis provides a lot of Anthropic's interpretability experiments (Monosemanticity, Residual Decomposition, Read-Write Analysis and a lot more). These are provided as tools, so you can plug them into your own experiments. Here's an example of how you can do that:

```python
# Making functions for Arrakis to use is pretty easy. Let's look at it in action.

# Step 1: Create a function where you can do operations on the model. Think of all the tools you might need for it.
# Step 2: Use the @exp.use_tools decorator on it, with the tool as an additional argument.
# Step 3: The extra argument gives you access to the tool. Done.

@exp.use_tools("write_read")  # use the exp.use_tools() decorator.
def read_write_analysis(read_layer_idx, write_layer_idx, src_idx, write_read=None):  # pass an additional argument.
    # Multi-hop attention (write-read)

    # use the extra argument as a tool.
    write_heads = write_read.identify_write_heads(read_layer_idx)
    read_heads = write_read.identify_read_heads(write_layer_idx, dim_idx=src_idx)

    return {
        "write_heads": write_heads,
        "read_heads": read_heads
    }

print(read_write_analysis(0, 1, 0))  # Perfecto!
```

Step 5: Visualize the Results

Generating plots in Arrakis is also plug-and-play: just add the decorator, and plots are generated by default. Read more in the graphing docs here.

```python
from arrakis.src.graph.base_graph import *

# Step 1: Create a function where you want to draw a plot.
# Step 2: Use the @exp.plot_results decorator on it (set the plotting lib), with the plot spec as an additional argument. Pass input_ids here as well (have to think on this).
# Step 3: The extra argument gives you access to the fig. Done.

exp.set_plotting_lib(MatplotlibWrapper)  # Set the plotting library.

@exp.plot_results(PlotSpec(plot_type="attention", data_keys="h.1.attn.c_attn"), input_ids=input_ids)  # use the exp.plot_results() decorator.
def attention_heatmap(fig=None):  # pass an additional argument.
    return fig

attention_heatmap()  # Done.
plt.show()
```

There are three top-level classes in Arrakis: the `InterpretabilityBench`, where you conduct experiments; `core_arrakis`, where I've implemented some common tests for Transformer-based models; and `Graphing`.

List of Tools

There is a lot happening inside core_arrakis. There are many tools we can use, which we'll cover one by one: what they do, and how to use Arrakis to test them. These tools are supported as of now (please contribute more!).

Go to their respective pages and read about what they mean and how to use Arrakis to conduct experiments.

Extending Arrakis

Apart from all of these tools, it is easy to develop your own tools to use in your experiments. These are the steps to do so:

- Step 1: Make a class which inherits from BaseInterpretabilityTool.

```python
from arrakis.src.core_arrakis.base_interpret import BaseInterpretabilityTool

class CustomTool(BaseInterpretabilityTool):
    def __init__(self, model):
        super().__init__(model)
        self.model = model

    def custom_function(self, *args, **kwargs):
        # do some computations
        pass

    def another_custom_function(self, *args, **kwargs):
        # do other calculations
        pass
```

The attribute `model` is a wrapper around the Huggingface `PreTrainedModel` class with many additional features that make experimentation easier. The reference for the model is given here. Write your functions so that they utilize the ActivationCache to get the intermediate activations.

  • Step 2: In the class derived from BaseInterpretabilityBench, add your custom tool in the following manner.

```python
from src.bench.base_bench import BaseInterpretabilityBench

# Import the custom tool here.

class ExperimentBench(BaseInterpretabilityBench):
    def __init__(self, model, save_dir="experiments"):
        super().__init__(model, save_dir)
        self.tools.update({"custom": CustomTool(model)})

exp = ExperimentBench(model)  # where model is an instance of HookedAutoModel
```

And that is pretty much it. Now, to use it in a function, just do the following:

```python
@exp.use_tools("custom")
def test_custom_function(args, kwargs, custom=None):  # the final argument should have the same name as the tool key.
    custom.custom_function()
    custom.another_custom_function()

test_custom_function(args, kwargs)
```

Adding your own tool is really easy in Arrakis. Read the API reference guide to see how to implement your own functions. Open a PR for tools that are not implemented and I can add them quickly.
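The plug-and-play mechanism above can be pictured as a small registry plus a decorator that injects the named tool into the wrapped function. A minimal sketch, with illustrative `Bench` and `CustomTool` names (not Arrakis's actual classes):

```python
# Minimal sketch of the tool-registry pattern: a bench holds named tools, and
# use_tools injects the requested tool as a keyword argument. Illustrative names.
class Bench:
    def __init__(self):
        self.tools = {}

    def use_tools(self, name):
        def wrapper(fn):
            def inner(*args, **kwargs):
                kwargs[name] = self.tools[name]  # inject the tool under its registry key
                return fn(*args, **kwargs)
            return inner
        return wrapper

class CustomTool:
    def custom_function(self):
        return "custom result"

exp = Bench()
exp.tools["custom"] = CustomTool()

@exp.use_tools("custom")
def test_custom_function(custom=None):
    return custom.custom_function()

print(test_custom_function())  # custom result
```

This explains the naming rule above: the decorated function's keyword argument must match the registry key, because that key is what the decorator uses to look up and inject the tool.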

How to Start?

To get started, go through demo.ipynb for an overview of the library, then test_graphs.py and test_new_model.py to test the graphs and the model (run them from the command line).

Owner

  • Name: Yash Srivastava
  • Login: yash-srivastava19
  • Kind: user
  • Location: Bhopal, Madhya Pradesh
  • Company: NIT Warangal

A Genius, Shy and Broke bloke.

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use Arrakis in your work, please cite it as below."
authors:
- family-names: "Yashovardhan"
  given-names: "Srivastava"
title: "arrakis"
version: 1.0.0
date-released: 2024-07-26
url: "https://github.com/yash-srivastava19/arrakis"

GitHub Events

Total
  • Watch event: 12
  • Delete event: 1
  • Issue comment event: 1
  • Push event: 1
  • Pull request event: 5
  • Fork event: 1
  • Create event: 3
Last Year
  • Watch event: 12
  • Delete event: 1
  • Issue comment event: 1
  • Push event: 1
  • Pull request event: 5
  • Fork event: 1
  • Create event: 3

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 0
  • Total pull requests: 4
  • Average time to close issues: N/A
  • Average time to close pull requests: about 1 month
  • Total issue authors: 0
  • Total pull request authors: 2
  • Average comments per issue: 0
  • Average comments per pull request: 0.5
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 3
Past Year
  • Issues: 0
  • Pull requests: 4
  • Average time to close issues: N/A
  • Average time to close pull requests: about 1 month
  • Issue authors: 0
  • Pull request authors: 2
  • Average comments per issue: 0
  • Average comments per pull request: 0.5
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 3
Top Authors
Issue Authors
  • yash-srivastava19 (2)
Pull Request Authors
  • dependabot[bot] (4)
  • rolyatmax (1)
  • yash-srivastava19 (1)
Top Labels
Issue Labels
enhancement (2) help wanted (2)
Pull Request Labels
dependencies (4) python (1) documentation (1) help wanted (1)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 20 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 2
  • Total maintainers: 1
pypi.org: arrakis-mi

A mechanistic interpretability library for nerds.

  • Versions: 2
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 20 Last month
Rankings
Dependent packages count: 10.6%
Average: 35.2%
Dependent repos count: 59.9%
Maintainers (1)
Last synced: 7 months ago

Dependencies

docs/requirements.txt pypi
  • myst-parser *
  • sphinx-book-theme *
poetry.lock pypi
  • certifi 2024.6.2
  • charset-normalizer 3.3.2
  • colorama 0.4.6
  • contourpy 1.2.1
  • cycler 0.12.1
  • filelock 3.15.4
  • fonttools 4.53.0
  • fsspec 2024.6.0
  • huggingface-hub 0.23.4
  • idna 3.7
  • jinja2 3.1.4
  • kiwisolver 1.4.5
  • markupsafe 2.1.5
  • matplotlib 3.9.1
  • mpmath 1.3.0
  • networkx 3.3
  • numpy 1.26.4
  • nvidia-cublas-cu12 12.1.3.1
  • nvidia-cuda-cupti-cu12 12.1.105
  • nvidia-cuda-nvrtc-cu12 12.1.105
  • nvidia-cuda-runtime-cu12 12.1.105
  • nvidia-cudnn-cu12 8.9.2.26
  • nvidia-cufft-cu12 11.0.2.54
  • nvidia-curand-cu12 10.3.2.106
  • nvidia-cusolver-cu12 11.4.5.107
  • nvidia-cusparse-cu12 12.1.0.106
  • nvidia-nccl-cu12 2.19.3
  • nvidia-nvjitlink-cu12 12.5.40
  • nvidia-nvtx-cu12 12.1.105
  • packaging 24.1
  • pandas 2.2.2
  • pillow 10.4.0
  • pyparsing 3.1.2
  • python-dateutil 2.9.0.post0
  • pytz 2024.1
  • pyyaml 6.0.1
  • regex 2024.5.15
  • requests 2.32.3
  • safetensors 0.4.3
  • seaborn 0.13.2
  • six 1.16.0
  • sympy 1.12.1
  • tokenizers 0.19.1
  • torch 2.2.1
  • tqdm 4.66.4
  • transformers 4.41.2
  • triton 2.2.0
  • typing-extensions 4.12.2
  • tzdata 2024.1
  • urllib3 2.2.2
pyproject.toml pypi
  • numpy <2
  • python ^3.11
  • seaborn ^0.13.2
  • torch 2.2.1
  • transformers ^4.41.2