imodelsx

Interpret text data using LLMs (scikit-learn compatible).

https://github.com/csinva/imodelsx

Science Score: 64.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org, nature.com
  • Committers with academic emails
    1 of 5 committers (20.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.2%) to scientific vocabulary

Keywords

ai deep-learning explainability huggingface interpretability language-model machine-learning ml natural-language-processing natural-language-understanding neural-network pytorch scikit-learn text text-classification transformer-models xai
Last synced: 6 months ago

Repository

Interpret text data using LLMs (scikit-learn compatible).

Basic Info
Statistics
  • Stars: 170
  • Watchers: 6
  • Forks: 28
  • Open Issues: 5
  • Releases: 0
Topics
ai deep-learning explainability huggingface interpretability language-model machine-learning ml natural-language-processing natural-language-understanding neural-network pytorch scikit-learn text text-classification transformer-models xai
Created over 3 years ago · Last pushed 6 months ago
Metadata Files
Readme · License · Citation

readme.md

Scikit-learn-friendly library to explain, predict, and steer text models/data.
Also includes a bunch of utilities for getting started with text data.

📖 demo notebooks

Explainable modeling/steering

| Model | Reference | Output | Description |
| :---- | :-------- | :----- | :---------- |
| Tree-Prompt | 🗂️, 🔗, 📄, 📖 | Explanation + Steering | Generates a tree of prompts to steer an LLM (Official) |
| iPrompt | 🗂️, 🔗, 📄, 📖 | Explanation + Steering | Generates a prompt that explains patterns in data (Official) |
| AutoPrompt | 🗂️, 🔗, 📄 | Explanation + Steering | Find a natural-language prompt using input-gradients |
| D3 | 🗂️, 🔗, 📄, 📖 | Explanation | Explain the difference between two distributions |
| SASC | 🗂️, 🔗, 📄 | Explanation | Explain a black-box text module using an LLM (Official) |
| Aug-Linear | 🗂️, 🔗, 📄, 📖 | Linear model | Fit better linear model using an LLM to extract embeddings (Official) |
| Aug-Tree | 🗂️, 🔗, 📄, 📖 | Decision tree | Fit better decision tree using an LLM to expand features (Official) |
| QAEmb | 🗂️, 🔗, 📄, 📖 | Explainable embedding | Generate interpretable embeddings by asking LLMs questions (Official) |
| KAN | 🗂️, 🔗, 📄, 📖 | Small network | Fit 2-layer Kolmogorov-Arnold network |

📖 Demo notebooks   🗂️ Doc   🔗 Reference code   📄 Research paper
⌛ We plan to support other interpretable algorithms like RLPrompt, CBMs, and NBDT. If you want to contribute an algorithm, feel free to open a PR 😄

General utilities

| Model | Description |
| :---- | :---------- |
| 🗂️ LLM wrapper | Easily call different LLMs |
| 🗂️ Dataset wrapper | Download minimally processed huggingface datasets |
| 🗂️ Bag of Ngrams | Learn a linear model of ngrams |
| 🗂️ Linear Finetune | Finetune a single linear layer on top of LLM embeddings |

Quickstart

Installation: `pip install imodelsx` (or, for more control, clone and install from source)

Demos: see the demo notebooks

Natural-language explanations

Tree-prompt

```python
from imodelsx import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # e.g. 'cache_prompt_features_dir/gpt2'
)
m.fit(dset_train["text"], dset_train["label"])

# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
```

iPrompt

```python
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B',  # which language model to use
    num_learned_tokens=3,  # how long of a prompt to learn
    n_shots=3,  # shots per example
    n_epochs=15,  # how many epochs to search
    verbose=0,  # how much to print
    llm_float16=True,  # whether to load the model in float16
)

# prompts is a list of found natural-language prompt strings
```

D3 (DescribeDistributionalDifferences)

```python
from imodelsx import explain_dataset_d3

hypotheses, hypothesis_scores = explain_dataset_d3(
    pos=positive_samples,  # List[str] of positive examples
    neg=negative_samples,  # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
```

SASC

Here, we explain a module rather than a dataset.

```python
from imodelsx import explain_module_sasc
import numpy as np

# a toy module that responds to the length of a string
mod = lambda str_list: np.array([len(s) for s in str_list])

# a toy dataset where the longest strings are animals
text_str_list = ["red", "blue", "x", "1", "2",
                 "hippopotamus", "elephant", "rhinoceros"]
explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
)
```

Aug-imodels

Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test-time they are extremely fast and completely transparent.

```python
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
```

KAN

```python
import imodelsx
import numpy as np
from sklearn.datasets import make_classification, make_regression
from sklearn.metrics import accuracy_score

# classification
X, y = make_classification(n_samples=5000, n_features=5, n_informative=3)
model = imodelsx.KANClassifier(hidden_layer_size=64, device='cpu',
                               regularize_activation=1.0, regularize_entropy=1.0)
model.fit(X, y)
y_pred = model.predict(X)
print('Test acc', accuracy_score(y, y_pred))

# now try regression
X, y = make_regression(n_samples=5000, n_features=5, n_informative=3)
model = imodelsx.kan.KANRegressor(hidden_layer_size=64, device='cpu',
                                  regularize_activation=1.0, regularize_entropy=1.0)
model.fit(X, y)
y_pred = model.predict(X)
print('Test correlation', np.corrcoef(y, y_pred.flatten())[0, 1])
```

General utilities

Easy baselines

Easy-to-fit baselines that follow the sklearn API.

```python
from imodelsx import LinearFinetuneClassifier, LinearNgramClassifier

# fit a simple one-layer finetune on top of LLM embeddings
m = LinearFinetuneClassifier(
    checkpoint='distilbert-base-uncased',
)
m.fit(dset['text'], dset['label'])
preds = m.predict(dset_val['text'])
acc = (preds == dset_val['label']).mean()
print('validation acc', acc)
```
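
`LinearNgramClassifier` is imported above but not demonstrated; the snippet below is a minimal sketch, assuming it follows the same sklearn-style fit/predict interface and leaving its constructor arguments at their defaults.

```python
# minimal sketch (assumption: LinearNgramClassifier follows the same
# sklearn-style fit/predict interface; constructor arguments left at defaults)
m_ngram = LinearNgramClassifier()
m_ngram.fit(dset['text'], dset['label'])
preds_ngram = m_ngram.predict(dset_val['text'])
print('ngram validation acc', (preds_ngram == dset_val['label']).mean())
```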

LLM wrapper

Easy API for calling different language models with caching (much more lightweight than langchain).

```python
import imodelsx.llm

# supports any huggingface checkpoint or openai checkpoint (including chat models)
llm = imodelsx.llm.get_llm(
    checkpoint="gpt2-xl",  # text-davinci-003, gpt-3.5-turbo, ...
    CACHE_DIR=".cache",
)
out = llm("May the Force be")
llm("May the Force be")  # when computing the same string again, uses the cache
```
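
Per the comment above, OpenAI chat checkpoints go through the same wrapper; the sketch below assumes the call interface is unchanged for a chat checkpoint and that an OpenAI API key is configured in the environment.

```python
# sketch (assumption: chat checkpoints use the same get_llm / call interface)
llm_chat = imodelsx.llm.get_llm(checkpoint="gpt-3.5-turbo", CACHE_DIR=".cache")
print(llm_chat("Complete this sentence: May the Force be"))
```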

Data wrapper

API for loading huggingface datasets with basic preprocessing.

```python
import imodelsx.data

dset, dataset_key_text = imodelsx.data.load_huggingface_dataset('ag_news')
# Ensures that dset has a split named 'train' and 'validation',
# and that the input data is contained for each split
# in a column given by {dataset_key_text}
```
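
As a small usage sketch, assuming (per the comments above) that `dset` behaves like a huggingface dataset dictionary and `dataset_key_text` names its text column:

```python
# minimal usage sketch (assumption: dset is a huggingface DatasetDict with
# 'train'/'validation' splits and dataset_key_text names the text column)
print(dset['train'][dataset_key_text][0])  # first training example
print(dset['train'].column_names)          # available columns
```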

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Rethinking Interpretability in the Era of Large Language Models (arXiv 2024 pdf) - overview of using LLMs to interpret datasets and yield natural-language explanations
  • Experiments in clinical rule development: github
  • Experiments in automatically generating brain explanations: github
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better

Owner

  • Name: Chandan Singh
  • Login: csinva
  • Kind: user
  • Location: Microsoft research
  • Company: Senior researcher

Senior researcher @Microsoft interpreting ML models in science and medicine. PhD from UC Berkeley.

Citation (citation.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Singh"
  given-names: "Chandan"
- family-names: "Gao"
  given-names: "Jianfeng"
title: "Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models"
journal: "arXiv preprint arXiv:2209.11799"
date-released: 2022-09-26
url: "https://arxiv.org/abs/2209.11799"

GitHub Events

Total
  • Release event: 1
  • Issues event: 2
  • Watch event: 12
  • Issue comment event: 1
  • Push event: 24
  • Pull request event: 1
  • Fork event: 2
Last Year
  • Release event: 1
  • Issues event: 2
  • Watch event: 12
  • Issue comment event: 1
  • Push event: 24
  • Pull request event: 1
  • Fork event: 2

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 330
  • Total Committers: 5
  • Avg Commits per committer: 66.0
  • Development Distribution Score (DDS): 0.052
Past Year
  • Commits: 20
  • Committers: 2
  • Avg Commits per committer: 10.0
  • Development Distribution Score (DDS): 0.05
Top Committers
Name Email Commits
Chandan Singh c****h@b****u 313
Jack Morris j****2@g****m 10
divyanshuaggarwal d****l@g****m 3
arminaskari a****i@g****m 3
RAFAEL RODRIGUES r****d@g****m 1
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 8
  • Total pull requests: 7
  • Average time to close issues: 4 months
  • Average time to close pull requests: about 5 hours
  • Total issue authors: 7
  • Total pull request authors: 4
  • Average comments per issue: 1.0
  • Average comments per pull request: 0.43
  • Merged pull requests: 7
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 3
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 3
  • Pull request authors: 0
  • Average comments per issue: 0.33
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • valeman (2)
  • PapadopoulosDimitrios (1)
  • edmondja (1)
  • Cxrxnnx (1)
  • clclclaiggg (1)
  • SalvatoreRa (1)
  • mvdoc (1)
Pull Request Authors
  • csinva (2)
  • divyanshuaggarwal (2)
  • rafa-rod (2)
  • arminaskari (2)
Top Labels
Issue Labels
Pull Request Labels

Packages

  • Total packages: 3
  • Total downloads:
    • pypi 2,365 last-month
  • Total dependent packages: 0
    (may contain duplicates)
  • Total dependent repositories: 1
    (may contain duplicates)
  • Total versions: 36
  • Total maintainers: 1
proxy.golang.org: github.com/csinva/imodelsx
  • Versions: 2
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 6.5%
Average: 6.7%
Dependent repos count: 7.0%
Last synced: 6 months ago
proxy.golang.org: github.com/csinva/imodelsX
  • Versions: 2
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 6.5%
Average: 6.7%
Dependent repos count: 7.0%
Last synced: 6 months ago
pypi.org: imodelsx

Library to explain a dataset in natural language.

  • Versions: 32
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 2,365 Last month
Rankings
Dependent packages count: 7.4%
Stargazers count: 8.8%
Forks count: 10.6%
Average: 14.6%
Dependent repos count: 22.2%
Downloads: 24.1%
Maintainers (1)
Last synced: 6 months ago