arrakis-mi
Arrakis is a library to conduct, track and visualize mechanistic interpretability experiments.
Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.2%) to scientific vocabulary
Keywords
Repository
Arrakis is a library to conduct, track and visualize mechanistic interpretability experiments.
Basic Info
- Host: GitHub
- Owner: yash-srivastava19
- Language: Jupyter Notebook
- Default Branch: main
- Homepage: https://arrakis-mi.readthedocs.io/en/latest/README.html
- Size: 3.53 MB
Statistics
- Stars: 31
- Watchers: 1
- Forks: 3
- Open Issues: 6
- Releases: 0
Topics
Metadata Files
README.md
Arrakis - A Mechanistic Interpretability Tool
Interpretability is a relatively new field where something new happens every day. Mechanistic Interpretability (MI) is one approach to reverse engineering neural networks and understanding what happens inside these black-box models.
Mechanistic Interpretability is a really exciting subfield of alignment, and recently a lot has been happening in this field, especially at Anthropic. To learn about the goals of MI at Anthropic, read this post. The core operations involved in MI are loading a model, looking at its weights and activations, doing some operations on them, and producing results.
I made Arrakis to deeply understand Transformer-based models (maybe in the future I will try to be model-agnostic). The first question that should come to mind is: why not use Transformer Lens? Neel Nanda has already made significant progress there. I made Arrakis because I wanted a library that can do more than just get the activations: a more complete library where researchers can run experiments and track their progress. Think of Arrakis as a complete suite for conducting MI experiments, where I try to get the best of both Transformer Lens and Garcon. More features will be added as I understand how to make this library more useful for the community, and I need feedback for that.
Tools and Decomposability
Regardless of what research project you are working on, if you are not keeping track of things, it gets messy really quickly. In a field like MI, where you are constantly looking at different weights and biases and there are a lot of moving parts, it gets overwhelming fairly easily. I've experienced this personally, and being someone who is obsessed with reducing experimentation time and getting results quickly, I wanted a complete suite that makes my workload easier.
Arrakis is made so that this doesn't happen. The core principle behind Arrakis is decomposability: do all experiments with plug-and-play tools (this will be much clearer in the walkthrough). This makes experimentation really flexible, and at the same time Arrakis keeps track of different versions of the experiments by default. Everything in Arrakis is made in this plug-and-play fashion. I have even incorporated a graphing library (on top of several popular libraries) to make graphing a lot easier.
I really want feedback and contributions on this project so that it can be adopted by the community at large.
Arrakis Walkthrough
Let's understand how to conduct a small experiment in Arrakis. It is simple, reproducible, and easy to implement.
Step 1: Install the package
All the dependencies of the project are maintained through poetry.
```bash
pip install arrakis-mi
```
Step 2: Create HookedAutoModel
HookedAutoModel offers a convenient way to import models from Huggingface directly (with hooks). Everything just works out of the box. First, create a HookedAutoConfig for the model you want to support, with the required parameters. Then, create a HookedAutoModel from the config. As of now, these models are supported:
```python
[
    "gpt2",
    "gpt-neo",
    "gpt-neox",
    "llama",
    "gemma",
    "phi3",
    "qwen2",
    "mistral",
    "stable-lm",
]
```
As mentioned, the core idea behind Arrakis is decomposability, so a HookedAutoModel is a wrapper around the Huggingface PreTrainedModel class, with a single plug-and-play decorator for the forward pass. All the model probing happens behind the scenes and is pre-configured.
```python
from arrakis.src.core_arrakis.activation_cache import *

config = HookedAutoConfig(
    name="llama",
    vocab_size=50256,
    hidden_size=8,
    intermediate_size=2,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = HookedAutoModel(config)
```
Step 3: Set up Interpretability Bench
At its core, the whole purpose of Arrakis is to conduct MI experiments. After installing, derive from BaseInterpretabilityBench and instantiate an object (exp in this case). This object provides a lot of functionality out of the box based on the "tool" you want to use for the experiment, and gives you access to the functions that the tool provides. You can also create your own tool (read about that here).
```python
from arrakis.src.core_arrakis.base_bench import BaseInterpretabilityBench

class MIExperiment(BaseInterpretabilityBench):
    def __init__(self, model, save_dir="experiments"):
        super().__init__(model, save_dir)
        self.tools.update({"custom": CustomFunction(model)})

exp = MIExperiment(model)
```
Apart from access to MI tools, the object also provides a convenient way to log your experiments. To log an experiment, just decorate the function you are working on with @exp.log_experiment, and that is pretty much it. The decorator creates local version control over the contents of the function and stores it locally. You can run many things in parallel, and the version control helps you keep track of them.
```python
# Step 1: Create a function where you can do operations on the model.

@exp.log_experiment  # This is pretty much it. This will log the experiment.
def attention_experiment():
    print("This is a placeholder for the experiment. Use as is.")
    return 4

# Step 2: Run the function to get results. This starts the experiment.
attention_experiment()

# Step 3: Look at some of the things the logs keep track of.
l = exp.list_versions("attention_experiment")  # This gives the hashes of the experiment's content.
print("This is the version hash of the experiment: ", l)

# Step 4: You can also get the content of the experiment from the saved JSON.
print(exp.get_version("attention_experiment", l[0])['source'])  # This gives the content of the experiment.
```
Apart from these tools, there are also `@exp.profile_model` (to profile how many resources the model is using) and `@exp.test_hypothesis` (to test hypotheses). Support for more tools will be added as I get more feedback from the community.
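To make the profiling idea concrete, here is a minimal sketch of what a `profile_model`-style decorator could look like, built only on the standard library's `time` and `tracemalloc`. The names and output format are assumptions for illustration, not the library's actual API:

```python
import time
import tracemalloc
from functools import wraps

def profile(fn):
    """Report wall-clock time and peak memory for one call (conceptual sketch)."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        tracemalloc.start()
        start = time.perf_counter()
        try:
            result = fn(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            _, peak = tracemalloc.get_traced_memory()
            tracemalloc.stop()
            print(f"{fn.__name__}: {elapsed:.4f}s, peak {peak} bytes")
        return result
    return wrapper

@profile
def dummy_forward():
    # Stand-in for a model forward pass.
    return sum(i * i for i in range(10_000))

dummy_forward()
```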
Step 4: Create your experiments
By default, Arrakis provides a lot of Anthropic's interpretability experiments (Monosemanticity, Residual Decomposition, Read-Write Analysis and a lot more). These are provided as tools, so in your experiments you can plug and play with them. Here's an example of how you can do that.
```python
# Making functions for Arrakis to use is pretty easy. Let's look at it in action.
# Step 1: Create a function where you can do operations on the model. Think of all the tools you might need for it.
# Step 2: Use the @exp.use_tools decorator on it, with an additional arg naming the tool.
# Step 3: The extra argument gives you access to the tool. Done.

@exp.use_tools("write_read")  # use the exp.use_tools() decorator.
def read_write_analysis(read_layer_idx, write_layer_idx, src_idx, write_read=None):  # pass an additional argument.
    # Multi-hop attention (write-read)

    # use the extra argument as a tool.
    write_heads = write_read.identify_write_heads(read_layer_idx)
    read_heads = write_read.identify_read_heads(write_layer_idx, dim_idx=src_idx)

    return {
        "write_heads": write_heads,
        "read_heads": read_heads,
    }

print(read_write_analysis(0, 1, 0))  # Perfecto!
```
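The tool-injection mechanism behind this pattern can be approximated with a small registry. The class and names below are hypothetical, purely to illustrate how a decorator can pass a registered tool into an experiment function as a keyword argument:

```python
class ToyBench:
    """Minimal stand-in for an interpretability bench with a tool registry."""
    def __init__(self):
        self.tools = {}

    def use_tools(self, name):
        # Decorator factory: inject the registered tool as a keyword argument
        # whose name matches the tool key, mirroring the pattern shown above.
        def decorator(fn):
            def wrapper(*args, **kwargs):
                kwargs[name] = self.tools[name]
                return fn(*args, **kwargs)
            return wrapper
        return decorator

bench = ToyBench()
# A dict stands in for a real analysis tool here.
bench.tools["write_read"] = {"write_heads": [0, 3], "read_heads": [1]}

@bench.use_tools("write_read")
def read_write_analysis(write_read=None):
    return write_read["write_heads"]

print(read_write_analysis())  # the injected tool's write heads
```

The design choice this illustrates is why the experiment function's final parameter must share its name with the tool key: the decorator fills it in by keyword.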
Step 5: Visualize the Results
Generating plots in Arrakis is also plug-and-play: just add the decorator, and plots are generated by default. Read more about graphing in the docs here.
```python
from arrakis.src.graph.base_graph import *

# Step 1: Create a function where you want to draw a plot.
# Step 2: Use the @exp.plot_results decorator on it (set the plotting lib), with an additional arg for the plot spec. Pass input_ids here as well (have to think on this).
# Step 3: The extra argument gives you access to the fig. Done.

exp.set_plotting_lib(MatplotlibWrapper)  # Set the plotting library.

@exp.plot_results(PlotSpec(plot_type="attention", data_keys="h.1.attn.c_attn"), input_ids=input_ids)  # use the exp.plot_results() decorator.
def attention_heatmap(fig=None):  # pass an additional argument.
    return fig

attention_heatmap()  # Done.
plt.show()
```
These are the three top-level classes in Arrakis. One is the `InterpretabilityBench`, where you conduct experiments; the second is the `core_arrakis`, where I've implemented some common tests for Transformer-based models; and the third is the `Graphing`.
List of Tools
A lot is happening inside `core_arrakis`. There are a lot of tools we can use, which we'll deal with one by one: what they do, and how to use Arrakis to test them. These tools are supported as of now (please contribute more!):
- Attention Head Composition
- Attention Tools
- Causal Tracing Intervention
- Knowledge Graph Extractor
- Knowledge Prober
- Logit Attribution
- Logit Lens
- Read Write Heads
- Residual Decomposition
- Residual Tools
- Sparsity Analyzer
- Superposition Disentangler
Go to their respective pages and read about what they mean and how to use Arrakis to conduct experiments.
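As one concrete illustration of what such a tool computes, the Logit Lens idea (projecting an intermediate residual-stream vector through the unembedding matrix to see which token the model favors at that layer) can be shown with toy numbers. This is a conceptual sketch with made-up values, unrelated to Arrakis's actual Logit Lens implementation:

```python
def logit_lens(hidden, unembed, vocab):
    """Project one hidden state through the unembedding and pick the top token."""
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in unembed]
    best = max(range(len(logits)), key=logits.__getitem__)
    return vocab[best], logits

vocab = ["cat", "dog", "fish"]
unembed = [          # one row per vocabulary token (toy 2-d residual stream)
    [1.0, 0.0],      # "cat"
    [0.0, 1.0],      # "dog"
    [0.5, 0.5],      # "fish"
]
hidden_layer2 = [0.2, 0.9]  # residual stream after an early layer

token, logits = logit_lens(hidden_layer2, unembed, vocab)
print(token)  # "dog" already leads at this intermediate layer
```

Running this at every layer shows how the model's "belief" about the next token sharpens as computation proceeds, which is the core observation the Logit Lens tool automates.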
Extending Arrakis
Apart from all of these tools, it is easy to develop tools of your own which you can use for your experiments. These are the steps to do so:
- Step 1: Make a class which inherits from BaseInterpretabilityTool.
```python
from arrakis.src.core_arrakis.base_interpret import BaseInterpretabilityTool

class CustomTool(BaseInterpretabilityTool):
    def __init__(self, model):
        super().__init__(model)
        self.model = model

    def custom_function(self, *args, **kwargs):
        # do some computations
        pass

    def another_custom_function(self, *args, **kwargs):
        # do other calculations
        pass
```
The attribute `model` is a wrapper around Huggingface's `PreTrainedModel` with many additional features that make experimentation easier. The reference for the model is given here. Write your functions so that they utilize the ActivationCache and get the intermediate activations.
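The hook-and-cache pattern this refers to (capturing intermediate outputs as data flows through layers) can be sketched without any framework. The class and method names below are hypothetical stand-ins, not the real ActivationCache API:

```python
class ToyActivationCache:
    """Record named intermediate outputs, in the spirit of an activation cache."""
    def __init__(self):
        self.store = {}

    def hook(self, name, fn):
        # Wrap a layer-like callable so its output is saved under `name`
        # every time it runs, while leaving its behavior unchanged.
        def hooked(*args, **kwargs):
            out = fn(*args, **kwargs)
            self.store[name] = out
            return out
        return hooked

cache = ToyActivationCache()
layer0 = cache.hook("layer0", lambda x: [v * 2 for v in x])
layer1 = cache.hook("layer1", lambda x: [v + 1 for v in x])

out = layer1(layer0([1, 2, 3]))
print(cache.store["layer0"])  # intermediate activation captured by the hook
print(out)                    # final output, unaffected by caching
```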
- Step 2: In the class derived from BaseInterpretabilityBench, add your custom tool in the following manner.
```python
from src.bench.base_bench import BaseInterpretabilityBench
# Import the custom tool here.

class ExperimentBench(BaseInterpretabilityBench):
    def __init__(self, model, save_dir="experiments"):
        super().__init__(model, save_dir)
        self.tools.update({"custom": CustomTool(model)})

exp = ExperimentBench(model)  # where model is an instance of HookedAutoModel
```
And that is pretty much it. Now, in order to use it in a function, just do the following:
```python
@exp.use_tools("custom")
def test_custom_function(*args, custom=None, **kwargs):  # the tool argument's name should match the tool key.
    custom.custom_function()
    custom.another_custom_function()

test_custom_function()
```
Adding your own tool is really easy in Arrakis. Read the API reference guide to see how to implement your own functions. Open a PR for tools that are not implemented and I can add them quickly.
How to Start?
For just starting out, consider going through demo.ipynb to get an overview of the library, and test_graphs.py and test_new_model.py to test the model and the graphs (run them from the command line).
Owner
- Name: Yash Srivastava
- Login: yash-srivastava19
- Kind: user
- Location: Bhopal, Madhya Pradesh
- Company: NIT Warangal
- Website: https://yash-sri.xyz
- Twitter: Yaaaaaashhh
- Repositories: 55
- Profile: https://github.com/yash-srivastava19
A Genius, Shy and Broke bloke.
Citation (CITATION.cff)
```yaml
cff-version: 1.2.0
message: "If you use Arrakis in your work, please cite it as below."
authors:
  - family-names: "Yashovardhan"
    given-names: "Srivastava"
title: "arrakis"
version: 1.0.0
date-released: 2024-07-26
url: "https://github.com/yash-srivastava19/arrakis"
```
GitHub Events
Total
- Watch event: 12
- Delete event: 1
- Issue comment event: 1
- Push event: 1
- Pull request event: 5
- Fork event: 1
- Create event: 3
Last Year
- Watch event: 12
- Delete event: 1
- Issue comment event: 1
- Push event: 1
- Pull request event: 5
- Fork event: 1
- Create event: 3
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 0
- Total pull requests: 4
- Average time to close issues: N/A
- Average time to close pull requests: about 1 month
- Total issue authors: 0
- Total pull request authors: 2
- Average comments per issue: 0
- Average comments per pull request: 0.5
- Merged pull requests: 1
- Bot issues: 0
- Bot pull requests: 3
Past Year
- Issues: 0
- Pull requests: 4
- Average time to close issues: N/A
- Average time to close pull requests: about 1 month
- Issue authors: 0
- Pull request authors: 2
- Average comments per issue: 0
- Average comments per pull request: 0.5
- Merged pull requests: 1
- Bot issues: 0
- Bot pull requests: 3
Top Authors
Issue Authors
- yash-srivastava19 (2)
Pull Request Authors
- dependabot[bot] (4)
- rolyatmax (1)
- yash-srivastava19 (1)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 1
- Total downloads: 20 last-month (pypi)
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 2
- Total maintainers: 1
pypi.org: arrakis-mi
A mechanistic interpretability library for nerds.
- Homepage: https://github.com/yash-srivastava19/arrakis
- Documentation: https://arrakis-mi.readthedocs.io/en/latest/README.html
- License: MIT
- Latest release: 0.1.1 (published over 1 year ago)
Rankings
Maintainers (1)
Dependencies
- myst-parser *
- sphinx-book-theme *
- certifi 2024.6.2
- charset-normalizer 3.3.2
- colorama 0.4.6
- contourpy 1.2.1
- cycler 0.12.1
- filelock 3.15.4
- fonttools 4.53.0
- fsspec 2024.6.0
- huggingface-hub 0.23.4
- idna 3.7
- jinja2 3.1.4
- kiwisolver 1.4.5
- markupsafe 2.1.5
- matplotlib 3.9.1
- mpmath 1.3.0
- networkx 3.3
- numpy 1.26.4
- nvidia-cublas-cu12 12.1.3.1
- nvidia-cuda-cupti-cu12 12.1.105
- nvidia-cuda-nvrtc-cu12 12.1.105
- nvidia-cuda-runtime-cu12 12.1.105
- nvidia-cudnn-cu12 8.9.2.26
- nvidia-cufft-cu12 11.0.2.54
- nvidia-curand-cu12 10.3.2.106
- nvidia-cusolver-cu12 11.4.5.107
- nvidia-cusparse-cu12 12.1.0.106
- nvidia-nccl-cu12 2.19.3
- nvidia-nvjitlink-cu12 12.5.40
- nvidia-nvtx-cu12 12.1.105
- packaging 24.1
- pandas 2.2.2
- pillow 10.4.0
- pyparsing 3.1.2
- python-dateutil 2.9.0.post0
- pytz 2024.1
- pyyaml 6.0.1
- regex 2024.5.15
- requests 2.32.3
- safetensors 0.4.3
- seaborn 0.13.2
- six 1.16.0
- sympy 1.12.1
- tokenizers 0.19.1
- torch 2.2.1
- tqdm 4.66.4
- transformers 4.41.2
- triton 2.2.0
- typing-extensions 4.12.2
- tzdata 2024.1
- urllib3 2.2.2
- numpy <2
- python ^3.11
- seaborn ^0.13.2
- torch 2.2.1
- transformers ^4.41.2