ReLax
ReLax: Efficient and Scalable Recourse Explanation Benchmarking using JAX - Published in JOSS (2024)
Science Score: 95.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ✓ DOI references: found 4 DOI reference(s) in README and JOSS metadata
- ✓ Academic publication links: links to arxiv.org, joss.theoj.org
- ✓ Committers with academic emails: 1 of 7 committers (14.3%) from academic institutions
- ○ Institutional organization owner
- ✓ JOSS paper metadata: published in Journal of Open Source Software
Repository
Recourse Explanation Library in JAX
Basic Info
- Host: GitHub
- Owner: BirkhoffG
- License: apache-2.0
- Language: Jupyter Notebook
- Default Branch: master
- Homepage: https://birkhoffg.github.io/jax-relax/
- Size: 17 MB
Statistics
- Stars: 6
- Watchers: 1
- Forks: 1
- Open Issues: 6
- Releases: 10
Metadata Files
README.md
ReLax
Overview
ReLax (Recourse Explanation Library in Jax) is an efficient and scalable benchmarking library for recourse and counterfactual explanations, built on top of JAX. By leveraging JAX language primitives such as vectorization, parallelization, and just-in-time compilation, ReLax substantially speeds up the generation of individual (or local) explanations for predictions made by machine learning algorithms.
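To make two of these primitives concrete, here is a small, library-independent sketch (it does not use ReLax): a per-example function is written once, vectorized over a batch with `jax.vmap`, and the batched version is compiled with `jax.jit`. The function `l1_distance` is a hypothetical example chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def l1_distance(x, cf):
    # Per-example computation: both inputs have shape (n_features,).
    return jnp.sum(jnp.abs(x - cf))

# Vectorize over the leading (batch) axis of both arguments, then JIT-compile.
batched_l1 = jax.jit(jax.vmap(l1_distance))

key_x, key_cf = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (100, 10))
cfs = jax.random.normal(key_cf, (100, 10))

dists = batched_l1(xs, cfs)  # shape (100,), computed in one compiled call

# Reference result: call the single-example function row by row in Python.
loop_dists = jnp.stack([l1_distance(x, cf) for x, cf in zip(xs, cfs)])
```

Both versions give the same values; the batched one avoids one Python-level dispatch per example.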
Some of the key features are as follows:
- 🏃 Fast and scalable recourse generation.
- 🚀 Accelerated over `cpu`, `gpu`, `tpu`.
- 🪓 Comprehensive set of recourse methods implemented for benchmarking.
- 👐 Customizable API to enable the building of entire modeling and interpretation pipelines for new recourse algorithms.
Installation
```bash
pip install jax-relax
# Or install the latest version of jax-relax
pip install git+https://github.com/BirkhoffG/jax-relax.git
```
To further unleash the power of accelerators (i.e., GPU/TPU), we suggest first installing this library via `pip install jax-relax`, and then following the steps in the official JAX installation guidelines to install the right version of JAX for GPU or TPU.
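After installing an accelerator-enabled JAX build, one quick way to confirm that JAX actually sees the GPU or TPU is to query its default backend and device list. This uses only standard JAX calls and is independent of ReLax:

```python
import jax

# Backend JAX selected by default (e.g. "cpu", "gpu", or "tpu").
backend = jax.default_backend()

# Devices visible to JAX; after a successful GPU/TPU install these should be
# accelerator devices rather than CPU devices.
devices = jax.devices()
print(backend, devices)
```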
Dive into ReLax
ReLax is a recourse explanation library for explaining (any) JAX-based ML models. We believe it is important to give users the flexibility to choose how to use ReLax. You can:
- use only the methods implemented in ReLax (as a recourse methods library); or
- build a pipeline with ReLax to define a data module, train ML models, and generate CF explanations (for constructing a recourse benchmarking pipeline).
ReLax as a Recourse Explanation Library
We introduce basic use cases of the methods in ReLax for generating recourse explanations. For more advanced usage of the methods in ReLax, see the tutorials.
```python
from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
import jax
```
Let’s first generate synthetic data:
```python
xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
```
Next, we fit an MLP model to this data. Note that this model can be any model implemented in JAX; we use the `MLModule` in ReLax as an example.
```python
model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
```
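Since the model can be any JAX model, a plain JAX function that maps features to a predicted probability can play the same role as `model.pred_fn`. The sketch below builds a tiny, untrained two-layer MLP by hand; `init_mlp` and `mlp_pred_fn` are hypothetical names, and the exact input/output shapes ReLax expects from `pred_fn` are an assumption here (features in, probability of the positive class out).

```python
import jax
import jax.numpy as jnp

def init_mlp(key, n_features, n_hidden=16):
    # Randomly initialize a 2-layer MLP (untrained; for illustration only).
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (n_features, n_hidden)) * 0.1,
        "b1": jnp.zeros(n_hidden),
        "w2": jax.random.normal(k2, (n_hidden,)) * 0.1,
        "b2": jnp.zeros(()),
    }

def mlp_pred_fn(params, x):
    # ReLU hidden layer, sigmoid output = P(y = 1 | x).
    # Works for one example (n_features,) or a batch (n, n_features).
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return jax.nn.sigmoid(h @ params["w2"] + params["b2"])

params = init_mlp(jax.random.PRNGKey(0), n_features=10)
pred_fn = lambda x: mlp_pred_fn(params, x)  # close over params so pred_fn takes only x
p = pred_fn(jnp.ones(10))
```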
Generating recourse explanations is straightforward. We can simply call the `generate_cf` method of an implemented recourse method to generate one recourse explanation:
```python
vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
```
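`VanillaCF` corresponds to the method of [1] in the table of supported methods. The following is a minimal, hypothetical sketch of that style of gradient-based search in plain JAX, not ReLax's actual implementation: starting from the input, it runs `n_steps` of gradient descent with step size `lr` on a validity loss (push the prediction toward the opposite class) plus a proximity penalty. The squared-L2 proximity term, the weight `lambda_`, and the names `vanilla_cf_sketch` and `toy_pred_fn` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def toy_pred_fn(x):
    # Hypothetical logistic-regression model returning P(y = 1 | x) for one example.
    w = jnp.linspace(-1.0, 1.0, x.shape[-1])
    return jax.nn.sigmoid(x @ w + 0.1)

def vanilla_cf_sketch(x, pred_fn, n_steps=1000, lr=0.05, lambda_=0.1):
    # Target the opposite of the model's current (thresholded) prediction.
    y_target = 1.0 - jnp.round(pred_fn(x))

    def loss_fn(cf):
        validity = (pred_fn(cf) - y_target) ** 2  # push prediction toward target
        proximity = jnp.sum((cf - x) ** 2)        # stay close to the original input
        return validity + lambda_ * proximity

    grad_fn = jax.grad(loss_fn)

    def step(cf, _):
        # One plain gradient-descent update on the counterfactual.
        return cf - lr * grad_fn(cf), None

    cf, _ = jax.lax.scan(step, x, None, length=n_steps)
    return cf

x = jnp.zeros(4)
cf = vanilla_cf_sketch(x, toy_pred_fn)
# The CF's predicted probability should land on the other side of 0.5.
print(toy_pred_fn(x), toy_pred_fn(cf))
```

Using `jax.lax.scan` instead of a Python loop keeps the whole search traceable, so the function can itself be vectorized with `jax.vmap` or compiled with `jax.jit`.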
Or generate a batch of recourse explanations with `jax.vmap`:
```python
generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
```
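The `ft.partial` + `jax.vmap` pattern works for any function written for a single example. The library-independent sketch below illustrates it with a hypothetical one-step explainer (`one_step_cf`) and a toy model, additionally wraps the batched function in `jax.jit`, and checks the result against a plain Python loop. Nothing here is ReLax API.

```python
import functools as ft
import jax
import jax.numpy as jnp

def toy_pred_fn(x):
    # Hypothetical logistic-regression model for one example of shape (n_features,).
    w = jnp.linspace(-1.0, 1.0, x.shape[-1])
    return jax.nn.sigmoid(x @ w)

def one_step_cf(x, pred_fn, step_size=0.5):
    # Hypothetical per-example explainer: a single gradient step that moves
    # pred_fn(x) toward the opposite class. Written for ONE example.
    y_target = 1.0 - jnp.round(pred_fn(x))
    grad = jax.grad(lambda cf: (pred_fn(cf) - y_target) ** 2)(x)
    return x - step_size * grad

xs = jax.random.normal(jax.random.PRNGKey(0), (8, 10))

# Bind every argument except x, then map over the batch axis and compile.
generate_fn = ft.partial(one_step_cf, pred_fn=toy_pred_fn)
cfs = jax.jit(jax.vmap(generate_fn))(xs)

# Same result as calling the single-example function row by row.
loop_cfs = jnp.stack([generate_fn(x) for x in xs])
```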
ReLax for Building Recourse Explanation Pipelines
The above example illustrates how to use the decoupled `relax.methods` to generate recourse explanations. However, users are then required to write boilerplate code for tasks such as data preprocessing, model training, and generating recourse explanations with feature constraints.
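As an illustration of the kind of feature-constraint handling such boilerplate involves, the sketch below projects a candidate counterfactual back onto simple constraints: every feature is clipped to a valid range, and immutable features are reset to their original values. The function name `apply_feature_constraints`, the boolean mask, and the box bounds are hypothetical; this is not how ReLax represents constraints.

```python
import jax.numpy as jnp

def apply_feature_constraints(x, cf, immutable_mask, lower, upper):
    # Clip every feature of the counterfactual into its valid [lower, upper] range.
    cf = jnp.clip(cf, lower, upper)
    # Features marked immutable (e.g. a demographic attribute) keep their original values.
    return jnp.where(immutable_mask, x, cf)

x = jnp.array([0.2, 1.0, 0.4])
cf = jnp.array([5.0, -3.0, 0.5])
mask = jnp.array([True, False, False])
projected = apply_feature_constraints(x, cf, mask, lower=0.0, upper=1.0)
```

Applying such a projection after every gradient step, rather than once at the end, is one common way to keep the search inside the feasible set.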
ReLax additionally offers a concise, high-level API that streamlines this process and helps users build a standardized pipeline for generating recourse explanations. With three lines of code, you can benchmark recourse explanations:
```python
data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
```
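Validity and proximity are two metrics commonly reported when benchmarking recourse methods. The exact metrics and definitions used by `benchmark_cfs` are not described here, so the sketch below is a hypothetical, library-independent illustration: `validity` is the fraction of counterfactuals whose predicted class differs from the input's, and `proximity` is the mean L1 distance between inputs and their counterfactuals.

```python
import jax
import jax.numpy as jnp

def validity(pred_fn, xs, cfs):
    # Fraction of counterfactuals whose predicted class differs from the input's.
    y_orig = jnp.round(jax.vmap(pred_fn)(xs))
    y_cf = jnp.round(jax.vmap(pred_fn)(cfs))
    return jnp.mean(y_orig != y_cf)

def proximity(xs, cfs):
    # Average L1 distance between each input and its counterfactual.
    return jnp.mean(jnp.sum(jnp.abs(xs - cfs), axis=-1))

pred_fn = lambda x: jax.nn.sigmoid(jnp.sum(x))  # toy model for a single example
xs = jnp.array([[-1.0, -1.0], [-1.0, -1.0]])
cfs = jnp.array([[1.0, 1.0], [-1.0, -0.5]])
print(validity(pred_fn, xs, cfs), proximity(xs, cfs))  # → 0.5 and 2.25
```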
See Getting Started with ReLax for an end-to-end example of using ReLax.
Supported Recourse Methods
ReLax currently provides implementations of 9 recourse explanation
methods.
| Method | Type | Paper Title | Ref |
|--------------------------------------------------------------------------------------------|-----------------|------------------------------------------------------------------------------------------------|-------------------------------------------|
| VanillaCF | Non-Parametric | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. | [1] |
| DiverseCF | Non-Parametric | Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. | [2] |
| ProtoCF | Semi-Parametric | Interpretable Counterfactual Explanations Guided by Prototypes. | [3] |
| CounterNet | Parametric | CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations. | [4] |
| GrowingSphere | Non-Parametric | Inverse Classification for Comparison-based Interpretability in Machine Learning. | [5] |
| CCHVAE | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [6] |
| VAECF | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [7] |
| CLUE | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [8] |
| L2C | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations | [9] |
Citing ReLax
To cite this repository:
```latex
@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/jax-relax},
  version = {0.2.0},
  year = {2023},
}
```
Owner
- Name: Hangzhi Guo
- Login: BirkhoffG
- Kind: user
- Company: Penn State University
- Website: https://birkhoffg.github.io
- Twitter: BirkhoffGuo
- Repositories: 4
- Profile: https://github.com/BirkhoffG
Ph.D. Student at Penn State University
JOSS Publication
ReLax: Efficient and Scalable Recourse Explanation Benchmarking using JAX
Authors
Duke University, Durham, NC, USA
Penn State University, University Park, PA, USA
Tags
JAX, machine learning, interpretability, counterfactual explanation, recourse
GitHub Events
Total
- Create event: 3
- Release event: 1
- Issues event: 1
- Issue comment event: 2
- Push event: 14
- Pull request event: 3
Last Year
- Create event: 3
- Release event: 1
- Issues event: 1
- Issue comment event: 2
- Push event: 14
- Pull request event: 3
Committers
Last synced: 7 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| BirkhoffG | 2****G | 389 |
| Xinchang Xiong | 6****e | 29 |
| Xinchang Xiong | c****k@P****l | 29 |
| Firdaus Choudhury | f****6@p****u | 5 |
| Praneyg | p****3@g****m | 2 |
| Xinchang Xiong | 6****B | 1 |
| root | r****t@i****l | 1 |
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 25
- Total pull requests: 25
- Average time to close issues: 20 days
- Average time to close pull requests: 12 days
- Total issue authors: 2
- Total pull request authors: 3
- Average comments per issue: 0.4
- Average comments per pull request: 1.2
- Merged pull requests: 22
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 1
- Pull requests: 4
- Average time to close issues: N/A
- Average time to close pull requests: about 2 hours
- Issue authors: 1
- Pull request authors: 1
- Average comments per issue: 0.0
- Average comments per pull request: 1.5
- Merged pull requests: 4
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- BirkhoffG (24)
- smachar (1)
Pull Request Authors
- BirkhoffG (27)
- Praneyg (3)
- FirdausChoudhury (3)
Packages
- Total packages: 1
- Total downloads: pypi 79 last month
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 18
- Total maintainers: 1
pypi.org: jax-relax
JAX-based Recourse Explanation Library
- Homepage: https://github.com/birkhoffg/jax-relax
- Documentation: https://jax-relax.readthedocs.io/
- License: Apache Software License 2.0
- Latest release: 0.2.9 (published about 1 year ago)
