ReLax

ReLax: Efficient and Scalable Recourse Explanation Benchmarking using JAX - Published in JOSS (2024)

https://github.com/birkhoffg/jax-relax

Science Score: 95.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 4 DOI reference(s) in README and JOSS metadata
  • Academic publication links
    Links to: arxiv.org, joss.theoj.org
  • Committers with academic emails
    1 of 7 committers (14.3%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software
Last synced: 6 months ago

Repository

Recourse Explanation Library in JAX

Basic Info
Statistics
  • Stars: 6
  • Watchers: 1
  • Forks: 1
  • Open Issues: 6
  • Releases: 10
Created over 2 years ago · Last pushed about 1 year ago
Metadata Files
Readme Contributing License

README.md

ReLax



Overview

ReLax (Recourse Explanation Library in Jax) is an efficient and scalable benchmarking library for recourse and counterfactual explanations, built on top of JAX. By leveraging JAX primitives such as vectorization, parallelization, and just-in-time compilation, ReLax offers massive speed improvements in generating individual (or local) explanations for predictions made by machine learning algorithms.

Some of the key features are as follows:

  • 🏃 Fast and scalable recourse generation.

  • 🚀 Accelerated on CPU, GPU, and TPU.

  • 🪓 Comprehensive set of recourse methods implemented for benchmarking.

  • 👐 Customizable API to enable the building of entire modeling and interpretation pipelines for new recourse algorithms.

Installation

```bash
pip install jax-relax
```

Or install the latest version of jax-relax:

```bash
pip install git+https://github.com/BirkhoffG/jax-relax.git
```

To further unleash the power of accelerators (i.e., GPU/TPU), we suggest first installing this library via pip install jax-relax, and then following the steps in the official JAX installation guide to install the right JAX version for your GPU or TPU.
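After installing an accelerator-enabled JAX build, a quick check like the sketch below can confirm which backend JAX actually picked up. It uses only core JAX calls (jax.devices and jax.default_backend) and is not part of ReLax itself.

```python
import jax

# Devices visible to JAX; with a GPU/TPU-enabled jaxlib, accelerator
# devices should be listed here rather than only the CPU.
devices = jax.devices()

# Name of the default backend, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()

print(f"Default backend: {backend}")
for d in devices:
    print(d.platform, d)
```

If only "cpu" is reported on a machine with an accelerator, the installed jaxlib is most likely the CPU-only build.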

Dive into ReLax

ReLax is a recourse explanation library for explaining (any) JAX-based ML models. We believe it is important to give users the flexibility to choose how to use ReLax. You can:

  • use only the methods implemented in ReLax (as a recourse method library); or
  • build a pipeline with ReLax to define a data module, train ML models, and generate counterfactual (CF) explanations (to construct a recourse benchmarking pipeline).

ReLax as a Recourse Explanation Library

Below are basic use cases for generating recourse explanations with the methods in ReLax. For more advanced usage, see the tutorials.

```python
from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
import jax
```

Let’s first generate synthetic data:

```python
xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
```

Next, we fit an MLP model to this data. Note that this model can be any model implemented in JAX; here we use the MLModule in ReLax as an example.

```python
model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
```
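Since any JAX model can serve as the prediction function, the sketch below trains a plain logistic-regression classifier with jax.grad and exposes a pred_fn that returns class-1 probabilities. The names sgd_step and pred_fn are hypothetical, the data are drawn with jax.random (not the scikit-learn data above) to keep the block self-contained, and the exact pred_fn signature ReLax expects should be checked against its documentation.

```python
import jax
import jax.numpy as jnp

# Synthetic, linearly separable data (assumption: stands in for real features).
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (1000, 10))
true_w = jax.random.normal(key_w, (10,))
ys = (xs @ true_w > 0).astype(jnp.float32)

def logits(params, x):
    # Linear model w.x + b; works for a single row or a batch.
    w, b = params
    return x @ w + b

def loss_fn(params, x, y):
    # Binary cross-entropy from logits: log(1 + e^z) - y * z.
    z = logits(params, x)
    return jnp.mean(jnp.logaddexp(0.0, z) - y * z)

LR = 0.5  # step size for full-batch gradient descent

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

params = (jnp.zeros(xs.shape[1]), jnp.array(0.0))
for _ in range(500):
    params = sgd_step(params, xs, ys)

def pred_fn(x):
    # Hypothetical prediction function: probability of class 1.
    return jax.nn.sigmoid(logits(params, x))

acc = float(jnp.mean((pred_fn(xs) > 0.5) == (ys == 1)))
print(f"Training accuracy: {acc:.3f}")
```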

Generating recourse explanations is straightforward. We can simply call the generate_cf method of an implemented recourse method to generate one recourse explanation:

```python
vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
```
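To make the idea behind VanillaCF concrete, here is a minimal gradient-based counterfactual search in plain JAX, in the spirit of Wachter et al. [1]. This is not ReLax's implementation: the fixed logistic model, the loss weights, and the function names are illustrative assumptions (n_steps and lr merely mirror the config keys above). The objective trades off a validity term (cross-entropy toward the flipped class) against an L1 proximity term.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.pred_fn: a fixed logistic model.
W = jnp.array([1.5, -2.0, 0.5])
B = -0.25

def pred_fn(x):
    # Probability of class 1 for a single feature vector.
    return jax.nn.sigmoid(jnp.dot(x, W) + B)

def generate_cf(x, pred_fn, n_steps=1000, lr=0.05, lam=0.01):
    # Target the class opposite to the model's current prediction.
    target = 1.0 - jnp.round(pred_fn(x))

    def loss(cf):
        p = jnp.clip(pred_fn(cf), 1e-7, 1 - 1e-7)
        # Validity: cross-entropy toward the target class.
        validity = -(target * jnp.log(p) + (1 - target) * jnp.log(1 - p))
        # Proximity: L1 distance to the original input.
        proximity = jnp.sum(jnp.abs(cf - x))
        return validity + lam * proximity

    grad_fn = jax.grad(loss)

    def step(_, cf):
        # One plain gradient-descent update on the counterfactual.
        return cf - lr * grad_fn(cf)

    # Start the search at the original input.
    return jax.lax.fori_loop(0, n_steps, step, x)

x = jnp.array([-1.0, 1.0, 0.0])
cf = generate_cf(x, pred_fn)
print(float(pred_fn(x)), float(pred_fn(cf)))
```

Here lam controls the trade-off: a larger value keeps the counterfactual closer to x, at the risk of never crossing the decision boundary.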

To generate a batch of recourse explanations, vectorize generate_cf with jax.vmap:

```python
generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
```
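The jax.vmap pattern works because a function written for a single example can be vectorized over a batch axis without rewriting it. The sketch below uses plain JAX, with a hypothetical one-step update cf_step standing in for a full recourse method, and checks that the vectorized, JIT-compiled version matches an explicit Python loop.

```python
import functools as ft
import jax
import jax.numpy as jnp

W = jnp.array([1.5, -2.0, 0.5])

def pred_fn(x):
    # Hypothetical model: probability of class 1 for one feature vector.
    return jax.nn.sigmoid(jnp.dot(x, W))

def cf_step(x, pred_fn, lr=0.1):
    # One gradient step that increases the class-1 probability of a single x.
    grad = jax.grad(lambda z: -jnp.log(pred_fn(z)))(x)
    return x - lr * grad

# Bind the model, vectorize over the leading (batch) axis, then compile.
step_fn = ft.partial(cf_step, pred_fn=pred_fn)
batched = jax.jit(jax.vmap(step_fn))

xs = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
cfs = batched(xs)

# The same computation, one row at a time.
looped = jnp.stack([step_fn(x) for x in xs])
```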

ReLax for Building Recourse Explanation Pipelines

The above example illustrates the usage of the decoupled relax.methods to generate recourse explanations. However, users are required to write boilerplate code for tasks such as data preprocessing, model training, and generating recourse explanations with feature constraints.

ReLax additionally offers a one-liner framework that streamlines this process and helps users build a standardized pipeline for generating recourse explanations. With three lines of code, you can benchmark recourse explanations:

```python
data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
```
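benchmark_cfs evaluates the generated explanations. Two metrics commonly used in the recourse literature are validity (the fraction of counterfactuals that flip the model's prediction) and proximity (the distance between an input and its counterfactual). The sketch below computes both in plain JAX on a toy example; the helper names are hypothetical, and ReLax's own metric definitions may differ.

```python
import jax
import jax.numpy as jnp

def validity(pred_fn, xs, cfs):
    # Fraction of counterfactuals whose predicted label differs from the input's.
    y_orig = pred_fn(xs) > 0.5
    y_cf = pred_fn(cfs) > 0.5
    return jnp.mean(y_orig != y_cf)

def proximity(xs, cfs):
    # Mean L1 distance between each input and its counterfactual.
    return jnp.mean(jnp.sum(jnp.abs(cfs - xs), axis=-1))

W = jnp.array([1.0, -1.0])

def pred_fn(x):
    # Hypothetical linear classifier (no bias).
    return jax.nn.sigmoid(x @ W)

xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])    # predicted classes: 1 and 0
cfs = jnp.array([[-0.5, 0.0], [0.0, 1.5]])  # first flips, second does not
print(float(validity(pred_fn, xs, cfs)))  # 0.5
print(float(proximity(xs, cfs)))          # (1.5 + 0.5) / 2 = 1.0
```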

See Getting Started with ReLax for an end-to-end example of using ReLax.

Supported Recourse Methods

ReLax currently provides implementations of 9 recourse explanation methods.

| Method | Type | Paper Title | Ref |
|---|---|---|---|
| VanillaCF | Non-Parametric | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. | [1] |
| DiverseCF | Non-Parametric | Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. | [2] |
| ProtoCF | Semi-Parametric | Interpretable Counterfactual Explanations Guided by Prototypes. | [3] |
| CounterNet | Parametric | CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations. | [4] |
| GrowingSphere | Non-Parametric | Inverse Classification for Comparison-based Interpretability in Machine Learning. | [5] |
| CCHVAE | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [6] |
| VAECF | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [7] |
| CLUE | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [8] |
| L2C | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations. | [9] |
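Roughly, non-parametric methods search or optimize separately for each input, while parametric ones train a model that generates recourse. As an illustration of a non-parametric search, here is a simplified sketch of the Growing Spheres idea [5] in plain JAX: sample points uniformly in spherical layers of increasing radius around x and return the closest one whose predicted class differs. It omits the original algorithm's initial radius-shrinking phase and its feature-sparsification step, it is not ReLax's GrowingSphere implementation, and the function names and defaults are hypothetical.

```python
import jax
import jax.numpy as jnp

def sample_shell(key, x, r_low, r_high, n):
    # Uniform samples in the layer r_low <= ||z - x|| <= r_high.
    d = x.shape[-1]
    k_dir, k_rad = jax.random.split(key)
    direction = jax.random.normal(k_dir, (n, d))
    direction = direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)
    u = jax.random.uniform(k_rad, (n, 1))
    # Inverse-CDF radius so samples are uniform in volume, not in radius.
    radius = (r_low**d + u * (r_high**d - r_low**d)) ** (1.0 / d)
    return x + radius * direction

def growing_sphere_cf(key, x, pred_fn, step=0.1, n_samples=1000, max_iter=100):
    y_orig = pred_fn(x) > 0.5
    r_low, r_high = 0.0, step
    for _ in range(max_iter):
        key, sub = jax.random.split(key)
        cands = sample_shell(sub, x, r_low, r_high, n_samples)
        # "Enemies": candidates the model assigns to the other class.
        enemies = cands[(pred_fn(cands) > 0.5) != y_orig]
        if enemies.shape[0] > 0:
            # Return the enemy closest to x in L2 distance.
            dists = jnp.linalg.norm(enemies - x, axis=-1)
            return enemies[jnp.argmin(dists)]
        r_low, r_high = r_high, r_high + step
    return None  # no counterfactual found within max_iter layers

W = jnp.array([1.0, 1.0])

def pred_fn(x):
    # Hypothetical linear classifier; works on one row or a batch.
    return jax.nn.sigmoid(x @ W)

x = jnp.array([-1.0, -1.0])  # about 1.414 from the boundary x1 + x2 = 0
cf = growing_sphere_cf(jax.random.PRNGKey(0), x, pred_fn)
print(cf, float(pred_fn(cf)))
```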

Citing ReLax

To cite this repository:

```latex
@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/jax-relax},
  version = {0.2.0},
  year = {2023},
}
```

Owner

  • Name: Hangzhi Guo
  • Login: BirkhoffG
  • Kind: user
  • Company: Penn State University

Ph.D. Student at Penn State University

JOSS Publication

ReLax: Efficient and Scalable Recourse Explanation Benchmarking using JAX
Published
November 12, 2024
Volume 9, Issue 103, Page 6567
Authors
Hangzhi Guo
Penn State University, University Park, PA, USA
Xinchang Xiong
Duke University, Durham, NC, USA
Wenbo Zhang
Penn State University, University Park, PA, USA
Amulya Yadav
Penn State University, University Park, PA, USA
Editor
Fei Tao
Tags
JAX, machine learning, interpretability, counterfactual explanation, recourse

GitHub Events

Total
  • Create event: 3
  • Release event: 1
  • Issues event: 1
  • Issue comment event: 2
  • Push event: 14
  • Pull request event: 3
Last Year
  • Create event: 3
  • Release event: 1
  • Issues event: 1
  • Issue comment event: 2
  • Push event: 14
  • Pull request event: 3

Committers

Last synced: 7 months ago

All Time
  • Total Commits: 456
  • Total Committers: 7
  • Avg Commits per committer: 65.143
  • Development Distribution Score (DDS): 0.147
Past Year
  • Commits: 18
  • Committers: 1
  • Avg Commits per committer: 18.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
BirkhoffG 2****G 389
Xinchang Xiong 6****e 29
Xinchang Xiong c****k@P****l 29
Firdaus Choudhury f****6@p****u 5
Praneyg p****3@g****m 2
Xinchang Xiong 6****B 1
root r****t@i****l 1

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 25
  • Total pull requests: 25
  • Average time to close issues: 20 days
  • Average time to close pull requests: 12 days
  • Total issue authors: 2
  • Total pull request authors: 3
  • Average comments per issue: 0.4
  • Average comments per pull request: 1.2
  • Merged pull requests: 22
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 1
  • Pull requests: 4
  • Average time to close issues: N/A
  • Average time to close pull requests: about 2 hours
  • Issue authors: 1
  • Pull request authors: 1
  • Average comments per issue: 0.0
  • Average comments per pull request: 1.5
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • BirkhoffG (24)
  • smachar (1)
Pull Request Authors
  • BirkhoffG (27)
  • Praneyg (3)
  • FirdausChoudhury (3)
Top Labels
Issue Labels
enhancement (13) bug (4) documentation (2) method (1)
Pull Request Labels
enhancement (11) bug (10) documentation (3) breaking changes (2)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 79 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 18
  • Total maintainers: 1
pypi.org: jax-relax

JAX-based Recourse Explanation Library

  • Versions: 18
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 79 Last month
Rankings
Dependent packages count: 6.6%
Downloads: 17.9%
Average: 22.1%
Forks count: 23.2%
Dependent repos count: 30.6%
Stargazers count: 32.3%
Maintainers (1)
Last synced: 6 months ago