tabsplanation

Experiments on counterfactual explanations for neural networks, based on the [latent shift method](https://arxiv.org/abs/2102.09475)

https://github.com/augustebaum/tabsplanation

Science Score: 18.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (16.4%) to scientific vocabulary

Keywords

explainable-ai gradient pytask
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: augustebaum
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 698 KB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Topics
explainable-ai gradient pytask
Created about 3 years ago · Last pushed almost 3 years ago
Metadata Files
Readme License Citation

README.md

tabsplanation

pre-commit.ci status image

Installation

To get started, create the environment with

```shell
conda env create -f environment.yml
```

or

```shell
mamba env create -f environment.yml
```

Then you'll need to install pytorch-lightning manually (don't ask me why!):

```shell
mamba activate tabsplanation
pip3 install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
pip3 install lightning
pip3 install -e .
```

If the `pip3 install -e .` doesn't work automatically, enter the environment and run it manually. This is to ensure that `tabsplanation/src` is in `sys.path` when pytask is run.

You might also need a working LaTeX installation if you want it to be used in the plots.

Usage

In the root of the project, with the environment activated, run

```shell
python3 -m src.experiments.run <name-of-experiment>
```

to run a given experiment. The name is given by the name of the directory in `src/experiments`.
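A name-per-directory convention like this is usually backed by a small runner module. The sketch below is hypothetical (the repository's actual `src.experiments.run` may differ); it only illustrates resolving an experiment name to a sub-directory of `src/experiments` and importing its entry point:

```python
# Hypothetical sketch of a name-based experiment runner; the real
# implementation in this repository may differ.
import importlib
from pathlib import Path


def available_experiments(root="src/experiments"):
    """List experiment names: one sub-directory per experiment."""
    return sorted(
        p.name
        for p in Path(root).iterdir()
        if p.is_dir() and not p.name.startswith("_")
    )


def run_experiment(name, root="src/experiments"):
    """Resolve a name to its experiment package and run it."""
    if name not in available_experiments(root):
        raise SystemExit(
            f"Unknown experiment {name!r}; choose from {available_experiments(root)}"
        )
    # Assumed module layout: each experiment package exposes a main() function.
    importlib.import_module(f"src.experiments.{name}").main()
```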

By default, any output is captured by pytask. Hence, to visualize the training of a model, it is recommended to use TensorBoard. In the root of the project, run

```shell
tensorboard --logdir=bld/models/lightning_logs
```

and follow the instructions.

Rationale

This repository contains various experiments exploring a technique called Latent Shift, coined by Cohen et al. in their paper "Gifsplanation".
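At its core, Latent Shift perturbs an input's latent representation against the gradient of a classifier's output and decodes the shifted latent into a counterfactual. Here is a toy sketch of that update; the 1-D linear encoder/decoder, the weights, and all function names are illustrative stand-ins, not the repository's code or the paper's exact setup:

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


# Toy 1-D model: decode(z) = W * z, classifier f(x) = sigmoid(w * x).
W = 2.0   # decoder weight (illustrative)
w = 1.5   # classifier weight (illustrative)


def decode(z):
    return W * z


def f(x):
    return sigmoid(w * x)


def grad_f_wrt_z(z):
    # d/dz f(decode(z)) for the toy model, by the chain rule:
    # sigmoid'(w * W * z) * w * W
    p = f(decode(z))
    return p * (1.0 - p) * w * W


def latent_shift_counterfactual(z, lam):
    """One latent-shift step: z' = z - lam * d f(decode(z)) / dz, then decode z'."""
    z_shifted = z - lam * grad_f_wrt_z(z)
    return decode(z_shifted)


x = decode(1.0)                               # original input with latent z = 1.0
x_cf = latent_shift_counterfactual(1.0, lam=5.0)
# Moving the latent against the gradient lowers the classifier's output:
# f(x_cf) < f(x), so x_cf is a counterfactual pushed towards the other class.
```

Varying `lam` (as in the paper's image sequences) traces out progressively stronger counterfactuals.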

The point of pytask is to offer an elegant way to make data-related workflows cacheable, i.e. when each individual part of a workflow produces outputs, these outputs are saved so they can be re-used later and the workflow can be skipped.

Previously the same codebase was written with hydra in mind: hydra offers a somewhat intuitive system to parametrize an experiment using a configuration file, usually written in YAML. In this case an experiment workflow would be written as follows:

```python
# cfg contains all configuration information about the workflow
@hydra.main(config="myconfig.yaml")
def my_experiment(cfg):
    # Step 1
    data = create_data(cfg)

    # Step 2
    if load_models:
        models = load_models(path)
    else:
        models = train_models(cfg, data)

    # Step 3
    plot_data = create_plot_data(models, data)

    # Step 4
    plot = create_plot(plot_data)

    # Step 5
    show_plot(plot)
```

You can see that the "get models" part is cacheable: there is an option to ask the system to load models from a specific path. Indeed, of all the steps in the workflow, this step is by far the most time-consuming. However, the other steps are still re-run, every time. What is more, when loading models, the option is currently to pass a directory path; so when doing this, one must be certain that the data that was loaded with `create_data` is exactly the same as the one used to train the loaded models; the book-keeping has to be done manually, which is error-prone and frustrating.

Instead, it would be saner to divide the workflow up into its individual steps, and let pytask handle the caching:

```python
def task_create_dataset(depends_on, produces):
    # Load config
    cfg = depends_on["config"]

    # Create the dataset according to config
    data = create_dataset(cfg)

    # Cache results
    save(data, produces["data"])
```

and similarly for all the other steps. Now, instead of getting hydra to run `my_experiment`, you could just ask pytask to run the `show_plot` task according to a configuration file, and let _it_ figure out what needs to be done to make that happen. In particular, if you ask it to run `show_plot` two times with the same config, the second run should be very quick because everything is cached; hence you can afford to run the whole workflow even if it's just to tweak the plot visuals.
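The skip-if-unchanged behaviour that pytask provides can be illustrated with a toy, self-contained sketch. This is only an illustration of the idea, not pytask's actual mechanism: a task is re-run when its product is missing or the content of a dependency has changed since the last run, and skipped otherwise.

```python
# Toy sketch of task-level caching: re-run a build step only when needed.
# This illustrates the idea behind pytask's skipping, not its real internals.
import hashlib
import json
from pathlib import Path


def file_hash(path):
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()


def run_if_stale(deps, product, build):
    """Run `build` only if `product` is missing or any dependency changed.

    Returns True if the build ran, False if the cached product was reused.
    """
    stamp = Path(str(product) + ".stamp")        # records dependency hashes
    state = {str(d): file_hash(d) for d in deps}
    if (
        Path(product).exists()
        and stamp.exists()
        and json.loads(stamp.read_text()) == state
    ):
        return False  # everything up to date: skip the build
    build()
    stamp.write_text(json.dumps(state))
    return True
```

With each step wrapped like this, asking for the final plot twice in a row makes the second request nearly free, exactly as described above.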

Credits

This project was created with cookiecutter and the cookiecutter-pytask-project template.

TODO

  • The paths to files that are output by a task should be printed to the console when a task is run. The annoying thing is that pytask captures everything by default, and there is not yet enough granularity to surface a particular print.
  • There is no facility to allow multiple config files at the same time; the only option is to change the appropriate YAML file every time.

Owner

  • Name: Auguste Baum
  • Login: augustebaum
  • Kind: user

Citation (CITATION)

@Unpublished{tabsplanation2023,
    Title  = {Latent shift applied to tabular data},
    Author = {Auguste Baum},
    Year   = {2023},
    Url    = {https://github.com/augustebaum/tabsplanation}
}

GitHub Events

Total
Last Year

Dependencies

environment.yml pypi
  • lightning *
pyproject.toml pypi
  • ipdb ^0.13.9 develop
  • pytest ^7.1.3 develop
  • ipykernel ^6.16.0
  • matplotlib 3.5
  • numpy ^1.23.3
  • omegaconf ^2.2.3
  • pandas ^1.5.0
  • pytask ^0.2.6
  • python >=3.10,<3.12
  • pytorch-lightning ^1.7.7
  • scikit-learn ^1.2.0
  • tomli ^2.0.1
  • ipywidgets ^8.0.2 vscode
requirements.txt pypi
  • Jinja2 ==3.1.2
  • Pillow ==9.3.0
  • PyQt5 ==5.15.7
  • PyQt5-sip ==12.11.0
  • anyio ==3.6.2
  • appdirs ==1.4.4
  • arrow ==1.2.3
  • beautifulsoup4 ==4.11.1
  • blessed ==1.19.1
  • blinker ==1.4
  • brotlipy ==0.7.0
  • croniter ==1.3.8
  • dateutils ==0.6.12
  • deepdiff ==6.2.3
  • dnspython ==2.3.0
  • email-validator ==1.3.1
  • fastapi ==0.88.0
  • fonttools ==4.25.0
  • fsspec ==2023.1.0
  • h11 ==0.14.0
  • httpcore ==0.16.3
  • httptools ==0.5.0
  • httpx ==0.23.3
  • inquirer ==3.1.2
  • itsdangerous ==2.1.2
  • lightning ==1.9.0
  • lightning-cloud ==0.5.19
  • lightning-utilities ==0.6.0.post0
  • munkres ==1.1.4
  • mypy-extensions ==1.0.0
  • ordered-set ==4.1.0
  • orjson ==3.8.5
  • pandas ==1.5.3
  • ply ==3.11
  • protobuf ==4.21.12
  • pyasn1-modules ==0.2.8
  • pydantic ==1.10.4
  • pyre-extensions ==0.0.30
  • pytest ==7.2.1
  • python-dotenv ==0.21.1
  • python-editor ==1.0.4
  • python-multipart ==0.0.5
  • readchar ==4.0.3
  • requests-oauthlib ==1.3.0
  • rfc3986 ==1.5.0
  • scipy ==1.10.0
  • sniffio ==1.3.0
  • soupsieve ==2.3.2.post1
  • starlette ==0.22.0
  • starsessions ==1.3.0
  • torch ==1.13.1
  • torcheval ==0.0.6
  • torchmetrics ==0.11.0
  • torchtnt ==0.0.6
  • tqdm ==4.64.1
  • typing-inspect ==0.8.0
  • ujson ==5.7.0
  • uvicorn ==0.20.0
  • uvloop ==0.17.0
  • watchfiles ==0.18.1
  • websocket-client ==1.5.0
  • websockets ==10.4