tabsplanation
Experiments on counterfactual explanations for neural networks, based on the [latent shift method](https://arxiv.org/abs/2102.09475)
Installation
To get started, create the environment with
```console
conda env create -f environment.yml
```
or
```console
mamba create -f environment.yml
```
Then you'll need to install pytorch-lightning manually (don't ask me why!).
```console
mamba activate tabsplanation
pip3 install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
pip3 install lightning
pip3 install -e .
```
If the `pip install -e` doesn't work automatically, enter the environment and
run it manually. This is to ensure that `tabsplanation/src` is in `sys.path`
when `pytask` is run.
You might also need a working LaTeX installation if you want LaTeX to be used in the plots.
Usage
In the root of the project, in the environment, run
```console
python3 -m src.experiments.run <name-of-experiment>
```
to run a given experiment. The name is given by the name of the directory in src/experiments.
By default, any output is captured by pytask. Hence, for visualizing the training of
a model, it is recommended to use tensorboard.
In the root of the project, run
```console
tensorboard --logdir=bld/models/lightning_logs
```
and follow the instructions.
Rationale
This repository contains various experiments exploring a technique called Latent Shift, coined by Cohen et al. in their paper "Gifsplanation".
The point of `pytask` is to offer an elegant way to make data-related
workflows cacheable: when each individual part of a workflow
produces outputs, these outputs are saved so that they can be re-used
later and that part of the workflow can be skipped.
Previously the same codebase was written with hydra in mind: hydra
offers a somewhat intuitive system to parametrize an experiment using
a configuration file, usually written in YAML.
In this case an experiment workflow would be written as follows:
```python
import hydra


# cfg contains all configuration information about the workflow
@hydra.main(config_path=".", config_name="myconfig")
def myexperiment(cfg):
    # Step 1
    data = create_data(cfg)

    # Step 2 (cfg.load_models and cfg.models_path are illustrative config keys)
    if cfg.load_models:
        models = load_models(cfg.models_path)
    else:
        models = train_models(cfg, data)

    # Step 3
    plot_data = create_plot_data(models, data)

    # Step 4
    plot = create_plot(plot_data)

    # Step 5
    show_plot(plot)
```
You can see that the "get models" part is cacheable: there is an
option to ask the system to load models from a specific path.
Indeed, of all the steps in the workflow, this step is by far
the most time-consuming. However, the other steps are still
re-run, every time.
What is more, when loading models, the only option is currently to
pass a directory path; so when doing this, one must be certain
that the data that was loaded with `create_data` is exactly the
same as the data used to train the loaded models; the book-keeping
has to be done manually, which is error-prone and frustrating.
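One way to automate the book-keeping described above is to store a fingerprint of the data configuration next to the saved models, and to check it before loading them. The following is only a minimal sketch under stated assumptions: the helper names (`config_fingerprint`, `write_fingerprint`, `matches_fingerprint`), the file name `data_fingerprint.txt` and the config keys are all hypothetical and not part of this repository.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def config_fingerprint(cfg: dict) -> str:
    """Return a stable hash of a JSON-serializable config dictionary."""
    # sort_keys makes the hash independent of key order.
    canonical = json.dumps(cfg, sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def write_fingerprint(models_dir: Path, data_cfg: dict) -> None:
    """Record which data config the models in `models_dir` were trained on."""
    models_dir.mkdir(parents=True, exist_ok=True)
    (models_dir / "data_fingerprint.txt").write_text(config_fingerprint(data_cfg))


def matches_fingerprint(models_dir: Path, data_cfg: dict) -> bool:
    """Check that `data_cfg` is the config the saved models were trained on."""
    stored = (models_dir / "data_fingerprint.txt").read_text()
    return stored == config_fingerprint(data_cfg)


# Demo with made-up config values (hypothetical keys).
with tempfile.TemporaryDirectory() as tmp:
    models_dir = Path(tmp) / "models"
    write_fingerprint(models_dir, {"nb_points": 1000, "seed": 42})

    # Same config, different key order: the fingerprint still matches.
    same_cfg_ok = matches_fingerprint(models_dir, {"seed": 42, "nb_points": 1000})
    # Different seed: the data would differ, so the models must not be reused.
    other_cfg_ok = matches_fingerprint(models_dir, {"nb_points": 1000, "seed": 0})
```

This removes the manual check, but every step that has to be cached would still need its own bespoke code.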
Instead, it would be saner to divide up the workflow into each
step, and let pytask handle the caching:
```python
def task_create_dataset(depends_on, produces):
    # Load config
    cfg = depends_on["config"]
    # Create the dataset according to config
    data = create_dataset(cfg)
    # Cache results
    save(data, produces["data"])
```
and similarly for all the other steps.
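To make the shape of such a step concrete, here is a self-contained sketch that can be run without `pytask`. It assumes that the task receives dictionaries mapping names to file paths, and it reads the config from JSON rather than YAML to stay within the standard library. The `create_dataset` body and the config keys `seed` and `nb_points` are hypothetical stand-ins, not the code used in this repository.

```python
import json
import pickle
import random
import tempfile
from pathlib import Path


def create_dataset(cfg: dict) -> list[list[float]]:
    """Hypothetical stand-in: draw `nb_points` random 2D points."""
    rng = random.Random(cfg["seed"])
    return [[rng.random(), rng.random()] for _ in range(cfg["nb_points"])]


def task_create_dataset(depends_on: dict, produces: dict) -> None:
    # Load config from the file this task depends on
    cfg = json.loads(Path(depends_on["config"]).read_text())
    # Create the dataset according to config
    data = create_dataset(cfg)
    # Cache results in the file this task produces
    with open(produces["data"], "wb") as f:
        pickle.dump(data, f)


# Call the task by hand, roughly the way a workflow runner would.
with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "config.json"
    data_path = Path(tmp) / "data.pkl"
    config_path.write_text(json.dumps({"seed": 0, "nb_points": 5}))

    task_create_dataset({"config": config_path}, {"data": data_path})

    with open(data_path, "rb") as f:
        loaded = pickle.load(f)
```

Because the step only communicates through files, a runner can decide whether to execute it purely by looking at its inputs and outputs.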
Now, instead of getting `hydra` to run `myexperiment`,
you could just ask `pytask` to run the `show_plot` task
according to a configuration file, and let _it_ figure
out what needs to be done to make that happen.
In particular, if you ask it to run `show_plot` two
times with the same config, the second run should
be very quick because everything is cached; hence
you can afford to run the whole workflow even if
it's just to tweak the plot visuals.
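The idea of asking for a final target and letting the runner work out what is stale can be illustrated with a toy resolver. This is only a sketch of the principle, not how `pytask` is implemented: the `TASKS` table, the `build` function and the file names are all hypothetical, and staleness is detected here by hashing the contents of each task's inputs.

```python
import hashlib
import tempfile
from pathlib import Path

# Hypothetical workflow: output file -> (input files, function of input texts).
TASKS = {
    "data.txt": (["config.txt"], lambda cfg: f"data({cfg})"),
    "models.txt": (["data.txt"], lambda data: f"models({data})"),
    "plot.txt": (["models.txt", "data.txt"], lambda m, d: f"plot({m}, {d})"),
}


def build(target: str, root: Path, executed: list[str]) -> None:
    """Bring `target` up to date, re-running only the steps that are stale."""
    if target not in TASKS:
        return  # A source file such as the config: nothing to build.
    inputs, func = TASKS[target]
    for name in inputs:
        build(name, root, executed)
    texts = [(root / name).read_text() for name in inputs]
    digest = hashlib.sha256("\0".join(texts).encode("utf-8")).hexdigest()
    out, stamp = root / target, root / (target + ".sha256")
    if out.exists() and stamp.exists() and stamp.read_text() == digest:
        return  # Inputs unchanged since the last run: reuse the cached output.
    out.write_text(func(*texts))
    stamp.write_text(digest)
    executed.append(target)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "config.txt").write_text("seed=0")

    first_run, second_run, third_run = [], [], []
    build("plot.txt", root, first_run)   # Everything has to be computed.
    build("plot.txt", root, second_run)  # Same config: nothing is re-run.
    (root / "config.txt").write_text("seed=1")
    build("plot.txt", root, third_run)   # Config changed: steps run again.
```

In this toy version, changing only the plotting function would not be noticed, since only input contents are hashed; a real workflow tool also has to track changes to the task code itself.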
Credits
This project was created with cookiecutter and the cookiecutter-pytask-project template.
TODO
- The paths to files that are output by a task should be printed to
the console when a task is run. The annoying thing is that
`pytask` captures everything by default, and there is not yet enough granularity to surface a particular `print`.
- There is no facility to allow multiple config files at the same
time; the only option is to change the appropriate
`yaml` file every time.
Owner
- Name: Auguste Baum
- Login: augustebaum
- Kind: user
- Repositories: 4
- Profile: https://github.com/augustebaum
Citation (CITATION)
@Unpublished{tabsplanation2023,
Title = {Latent shift applied to tabular data},
Author = {Auguste Baum},
Year = {2023},
Url = {https://github.com/augustebaum/tabsplanation}
}
Dependencies
- lightning *
- ipdb ^0.13.9 develop
- pytest ^7.1.3 develop
- ipykernel ^6.16.0
- matplotlib 3.5
- numpy ^1.23.3
- omegaconf ^2.2.3
- pandas ^1.5.0
- pytask ^0.2.6
- python >=3.10,<3.12
- pytorch-lightning ^1.7.7
- scikit-learn ^1.2.0
- tomli ^2.0.1
- ipywidgets ^8.0.2 vscode
- Jinja2 ==3.1.2
- Pillow ==9.3.0
- PyQt5 ==5.15.7
- PyQt5-sip ==12.11.0
- anyio ==3.6.2
- appdirs ==1.4.4
- arrow ==1.2.3
- beautifulsoup4 ==4.11.1
- blessed ==1.19.1
- blinker ==1.4
- brotlipy ==0.7.0
- croniter ==1.3.8
- dateutils ==0.6.12
- deepdiff ==6.2.3
- dnspython ==2.3.0
- email-validator ==1.3.1
- fastapi ==0.88.0
- fonttools ==4.25.0
- fsspec ==2023.1.0
- h11 ==0.14.0
- httpcore ==0.16.3
- httptools ==0.5.0
- httpx ==0.23.3
- inquirer ==3.1.2
- itsdangerous ==2.1.2
- lightning ==1.9.0
- lightning-cloud ==0.5.19
- lightning-utilities ==0.6.0.post0
- munkres ==1.1.4
- mypy-extensions ==1.0.0
- ordered-set ==4.1.0
- orjson ==3.8.5
- pandas ==1.5.3
- ply ==3.11
- protobuf ==4.21.12
- pyasn1-modules ==0.2.8
- pydantic ==1.10.4
- pyre-extensions ==0.0.30
- pytest ==7.2.1
- python-dotenv ==0.21.1
- python-editor ==1.0.4
- python-multipart ==0.0.5
- readchar ==4.0.3
- requests-oauthlib ==1.3.0
- rfc3986 ==1.5.0
- scipy ==1.10.0
- sniffio ==1.3.0
- soupsieve ==2.3.2.post1
- starlette ==0.22.0
- starsessions ==1.3.0
- torch ==1.13.1
- torcheval ==0.0.6
- torchmetrics ==0.11.0
- torchtnt ==0.0.6
- tqdm ==4.64.1
- typing-inspect ==0.8.0
- ujson ==5.7.0
- uvicorn ==0.20.0
- uvloop ==0.17.0
- watchfiles ==0.18.1
- websocket-client ==1.5.0
- websockets ==10.4