crm
Compositional Relational Machines (CRMs): Constructing deep neural networks that are logically explainable by design
Science Score: 49.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ✓ DOI references: Found 5 DOI reference(s) in README
- ✓ Academic publication links: Links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (11.3%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: tirtharajdash
- License: mit
- Language: Perl
- Default Branch: main
- Homepage: https://doi.org/10.1007/s10994-023-06399-6
- Size: 76.6 MB
Statistics
- Stars: 0
- Watchers: 1
- Forks: 1
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Compositional Relational Machine (CRM)
Code and data repository for our paper "Composition of Relational Features with an Application to Explaining Black-Box Predictors".
The paper is under review (status: Revision submitted) at the Machine Learning Journal (MLJ). We will update the link to the paper once the paper is accepted and published officially.
Our earlier CRM repository might no longer be maintained; this repository is intended to stay up to date with our newer results and further research extensions. Keep watching!
[Nov 14, 2023] This paper was also presented at IJCLR 2023 (PPT).
[Oct 26, 2023] The paper is now officially published in the Machine Learning Journal.
What is a CRM?
CRMs are explainable deep neural networks: a neurosymbolic architecture in which each node in the network has an associated relational feature. Besides being explainable in their own right, CRMs can also be used to generate structured proxy explanations for a black-box predictor.
A CRM consists of CRM-nodes (or CRM-neurons), where each node is associated with a valuation of a relational feature. An illustrative figure is shown below:
Every neuron has an arithmetic part $\Sigma$ and a logical part $f_i(a)$. The arithmetic part acts as it would in any standard neural network: an activation $g_i$ applied to the weighted sum $w_{ji}h_j(a_j) + w_{ki}h_k(a_k)$ of its ancestors' outputs. The logical part is a deterministic function of the data instance, namely the value of the relational feature associated with the neuron. Hence, for a neuron $n_i$ with two ancestor neurons $n_j$ and $n_k$, the final output is
$$ h_i(a_i) = g_i\big(w_{ji}\,h_j(a_j) + w_{ki}\,h_k(a_k)\big) \times f_i(a) $$
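To illustrate the output formula for $h_i$, here is a minimal Python sketch that computes the output of a single CRM neuron with two ancestors. It assumes a logistic sigmoid for $g_i$ and a Boolean (0/1) value for $f_i$. The function names and numbers are hypothetical and are not taken from this repository's code.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic activation, used here as an assumed stand-in for g_i."""
    return 1.0 / (1.0 + math.exp(-z))

def crm_neuron_output(w_ji: float, h_j: float, w_ki: float, h_k: float,
                      f_i: int, g=sigmoid) -> float:
    """Output h_i of a CRM neuron n_i with ancestors n_j and n_k.

    Arithmetic part: g applied to the weighted sum of the ancestors' outputs.
    Logical part: f_i in {0, 1}, the value of n_i's relational feature on the instance.
    """
    arithmetic = g(w_ji * h_j + w_ki * h_k)
    return arithmetic * f_i

# Same ancestor activations, with the neuron's relational feature true vs. false.
print(crm_neuron_output(0.5, 1.0, -0.25, 0.8, f_i=1))  # sigmoid(0.3), about 0.574
print(crm_neuron_output(0.5, 1.0, -0.25, 0.8, f_i=0))  # 0.0
```

Because $f_i$ multiplies the arithmetic part, a neuron whose relational feature is false for an instance outputs 0 regardless of its weights.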
Let us consider an example problem from the paper to illustrate the workings of a CRM.
The Train Problem
In the "East-West Challenge" (Michie et al., 1994), participants had to distinguish eastbound trains from westbound ones using properties of the carriages and their loads (the engine's properties are not used), based on pictorial descriptions like these (T1 is eastbound and T2 is westbound):
Our CRM aims to take these features as input, predict the direction of the train, and generate explanations for the prediction.
The first layer of the CRM is composed of primitive features, like
```math
\begin{align*}
C_1 &: p(X) \leftarrow (has\_car(X,Y), short(Y)) \\
C_2 &: p(X) \leftarrow (has\_car(X,Y), closed(Y)) \\
C_3 &: p(X) \leftarrow (has\_car(X,Y), short(Y), closed(Y)) \\
C_4 &: p(X) \leftarrow (has\_car(X,Y), has\_car(X,Z), short(Y), closed(Z))
\end{align*}
```
Here, we assume that predicates like $has\_car/2$, $short/1$, $closed/1$, $long/1$, $has\_load/3$, etc., are defined as part of the background $B$ and capture the situation shown diagrammatically. That is,
```math
\begin{align*}
B = \{~ & has\_car(t1,c1\_1), has\_car(t1,c1\_2), \ldots \\
& long(c1\_1), closed(c1\_1), has\_load(c1\_1,square,3), \ldots \\
& has\_car(t2,c2\_1), has\_car(t2,c2\_2), \ldots ~\}.
\end{align*}
```
Then the corresponding feature-function values are:
$$
\begin{aligned}
f_{C_1,B}(t1) = f_1(t1) = 1; &~ f_1(t2) = 1; \\
f_{C_2,B}(t1) = f_2(t1) = 1; &~ f_2(t2) = 1; \\
f_{C_3,B}(t1) = f_3(t1) = 1; &~ f_3(t2) = 0; \\
f_{C_4,B}(t1) = f_4(t1) = 1; &~ f_4(t2) = 1.
\end{aligned}
$$
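Assuming $f_{C,B}(x) = 1$ exactly when the body of clause $C$ can be satisfied in $B$ with $X = x$ (and 0 otherwise), the sketch below evaluates $C_1$ to $C_4$ over a small hypothetical background. The trains `ta`, `tb` and their cars are made up, not the full T1/T2 descriptions; they were chosen to show the same pattern of values as T1 and T2.

```python
from itertools import product
from typing import List

# Hypothetical toy background B (NOT the full T1/T2 descriptions): ground facts
# stored as tuples, grouped by predicate name.
B = {
    "has_car": {("ta", "ca1"), ("ta", "ca2"), ("tb", "cb1"), ("tb", "cb2")},
    "short": {("ca1",), ("cb1",)},
    "closed": {("ca1",), ("cb2",)},
}

def holds(pred: str, *args: str) -> bool:
    """True if the ground atom pred(args...) is a fact in B."""
    return tuple(args) in B.get(pred, set())

def cars(train: str) -> List[str]:
    """All Y such that has_car(train, Y) is in B."""
    return sorted(car for (t, car) in B["has_car"] if t == train)

# Each feature function returns 1 if the clause body is satisfiable with X = x, else 0.
def f1(x: str) -> int:  # C_1: p(X) <- has_car(X,Y), short(Y)
    return int(any(holds("short", y) for y in cars(x)))

def f2(x: str) -> int:  # C_2: p(X) <- has_car(X,Y), closed(Y)
    return int(any(holds("closed", y) for y in cars(x)))

def f3(x: str) -> int:  # C_3: p(X) <- has_car(X,Y), short(Y), closed(Y)
    return int(any(holds("short", y) and holds("closed", y) for y in cars(x)))

def f4(x: str) -> int:  # C_4: p(X) <- has_car(X,Y), has_car(X,Z), short(Y), closed(Z)
    return int(any(holds("short", y) and holds("closed", z)
                   for y, z in product(cars(x), repeat=2)))

for train in ("ta", "tb"):
    print(train, [f(train) for f in (f1, f2, f3, f4)])
# ta [1, 1, 1, 1]
# tb [1, 1, 0, 1]
```

The difference between $f_3$ and $f_4$ shows up for `tb`. No single car is both short and closed, but one car is short and another is closed. So $C_4$, which uses separate variables $Y$ and $Z$, is satisfied while $C_3$ is not.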
Given a way to generate the predicate (logical) part of the CRM, the arithmetic part is learned as in a standard neural network, using backpropagation, with each neuron $n_i$ effectively dropped out for an instance based on its predicate function $f_i$.
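To see why multiplying by $f_i$ behaves like instance-dependent dropout, consider one gradient step for a single neuron $h_i = g(\sum_j w_j x_j) \times f_i$. The sketch below assumes a sigmoid $g$, a squared-error loss and plain gradient descent. These are illustrative choices, not necessarily what the repository uses. When $f_i = 0$, the gradient with respect to the neuron's incoming weights vanishes, so those weights are left unchanged for that instance.

```python
import math
from typing import List

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def sgd_step(weights: List[float], inputs: List[float], f_i: int,
             target: float, lr: float = 0.1) -> List[float]:
    """One gradient-descent step for a single CRM neuron h_i = sigmoid(w . x) * f_i,
    using the (assumed) squared-error loss L = (h_i - target)**2 / 2."""
    z = sum(w * x for w, x in zip(weights, inputs))
    s = sigmoid(z)
    h_i = s * f_i
    # Chain rule: dL/dw_j = (h_i - target) * f_i * s * (1 - s) * x_j
    delta = (h_i - target) * f_i * s * (1.0 - s)
    return [w - lr * delta * x for w, x in zip(weights, inputs)]

w, x = [0.5, -0.25], [1.0, 0.8]
print(sgd_step(w, x, f_i=1, target=1.0))  # both weights increase
print(sgd_step(w, x, f_i=0, target=1.0))  # [0.5, -0.25]: no update, neuron "dropped"
```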
Explanations Generated by CRM
Given a relational data instance representing a train with some cars and their properties (shown below), a CRM generates structured explanations such as the one shown below.
We use Layer-wise Relevance Propagation (LRP) (Bach et al., 2015; Binder et al., 2016) to find the most relevant neuron(s) in the penultimate layer of the CRM, just before the final two-neuron classification layer. We then extract the ancestral graph of that neuron to explain the prediction.
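The two steps in the paragraph above can be sketched as follows, under several assumptions. Relevance is propagated with the epsilon rule of LRP through the final linear layer only (bias omitted), starting from the predicted class. The CRM's structure is given as a hypothetical map from each node to its parent nodes. Node names, activations and weights are made up; this is not the repository's implementation.

```python
from typing import Dict, List, Set

def lrp_epsilon(activations: List[float], weights: List[List[float]],
                relevance_out: List[float], eps: float = 1e-6) -> List[float]:
    """Epsilon-LRP through one linear layer (bias omitted for brevity).

    activations[j]: output of penultimate neuron j
    weights[j][k]: weight from penultimate neuron j to output neuron k
    relevance_out[k]: relevance assigned to output neuron k
    Returns R_j = sum_k a_j * w_jk / (z_k + eps * sign(z_k)) * R_k.
    """
    n_in, n_out = len(activations), len(relevance_out)
    z = [sum(activations[j] * weights[j][k] for j in range(n_in)) for k in range(n_out)]
    stab = [zk + (eps if zk >= 0 else -eps) for zk in z]
    return [sum(activations[j] * weights[j][k] / stab[k] * relevance_out[k]
                for k in range(n_out))
            for j in range(n_in)]

def ancestors(node: str, parents: Dict[str, List[str]]) -> Set[str]:
    """Node set of the ancestral graph: every node from which `node` is reachable."""
    seen: Set[str] = set()
    stack = list(parents.get(node, []))
    while stack:
        p = stack.pop()
        if p not in seen:
            seen.add(p)
            stack.extend(parents.get(p, []))
    return seen

penultimate = ["n5", "n6", "n7"]                 # hypothetical node names
a = [0.0, 0.9, 0.3]                              # n5's relational feature is false here
W = [[1.0, -1.0], [2.0, -0.5], [0.5, 0.5]]       # penultimate -> 2 output neurons
R = lrp_epsilon(a, W, relevance_out=[1.0, 0.0])  # relevance starts at predicted class 0
best = penultimate[max(range(len(R)), key=R.__getitem__)]
parents = {"n6": ["n3", "n4"], "n3": ["n1", "n2"], "n4": ["n2"], "n5": ["n1"], "n7": ["n4"]}
print(best, sorted(ancestors(best, parents)))    # n6 ['n1', 'n2', 'n3', 'n4']
```

Under this rule, a neuron whose relational feature is false for the instance has activation 0 and therefore receives zero relevance.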
Explanation: ancestral graph
Authors
- CRM's core codebase is developed jointly with Devanshu Shah @Devanshu24.
How to run our code
Environment setup: requirements.txt
```console
$ python3 main.py -h
usage: main.py [-h] -f FILE -o OUTPUT -n NUM_EPOCHS [-s SAVED_MODEL] [-e] [-v] [-g]

CRM; Example: python3 main.py -f inp.file -o out.file -n 20 -s saved_model.pt -e -v

optional arguments:
  -h, --help            show this help message and exit
  -f FILE, --file FILE  input file
  -o OUTPUT, --output OUTPUT
                        output file
  -n NUM_EPOCHS, --num-epochs NUM_EPOCHS
                        number of epochs
  -s SAVED_MODEL, --saved-model SAVED_MODEL
                        location of saved model
  -e, --explain         get explanations for predictions
  -v, --verbose         get verbose outputs
  -g, --gpu             run model on gpu
```
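For readers who want to script runs, the following is a hypothetical re-creation of the command-line interface above, built with Python's standard `argparse` module using only the `-h` output. It is not the repository's `main.py`. In particular, the integer type for `--num-epochs` and the `store_true` behavior of `-e`, `-v` and `-g` are assumptions.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring main.py's documented flags and help strings."""
    p = argparse.ArgumentParser(
        prog="main.py",
        description="CRM; Example: python3 main.py -f inp.file -o out.file "
                    "-n 20 -s saved_model.pt -e -v",
    )
    p.add_argument("-f", "--file", required=True, help="input file")
    p.add_argument("-o", "--output", required=True, help="output file")
    p.add_argument("-n", "--num-epochs", required=True, type=int,  # int is assumed
                   help="number of epochs")
    p.add_argument("-s", "--saved-model", help="location of saved model")
    p.add_argument("-e", "--explain", action="store_true",
                   help="get explanations for predictions")
    p.add_argument("-v", "--verbose", action="store_true", help="get verbose outputs")
    p.add_argument("-g", "--gpu", action="store_true", help="run model on gpu")
    return p

args = build_parser().parse_args(
    ["-f", "inp.file", "-o", "out.file", "-n", "20", "-s", "saved_model.pt", "-e", "-v"]
)
print(args.num_epochs, args.saved_model, args.explain, args.gpu)  # 20 saved_model.pt True False
```

The parsed arguments correspond to the example command in the help text.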
For reference, some example commands are in command.hist.
Caution! Some result or model files might get overwritten, so check the directory structure before running. Honestly, we did not pay proper attention to this aspect; we will fix it over time. Also, results are saved in the ./data directory, which is not ideal. We will work on these :-)
How to cite our paper
BibTeX:
```bibtex
@article{srinivasan2023composition,
  title={Composition of relational features with an application to explaining black-box predictors},
  author={Srinivasan, Ashwin and Baskar, A and Dash, Tirtharaj and Shah, Devanshu},
  journal={Machine Learning},
  pages={1--42},
  year={2023},
  doi={10.1007/s10994-023-06399-6},
  publisher={Springer},
  url={https://doi.org/10.1007/s10994-023-06399-6}
}
```
Owner
- Name: Tirtharaj Dash
- Login: tirtharajdash
- Kind: user
- Location: California, USA
- Company: University of California, San Diego
- Website: https://tirtharajdash.github.io/
- Repositories: 4
- Profile: https://github.com/tirtharajdash
I work on Deep Learning, Neuro-Symbolic Models, Graph Representation Learning, and Machine Learning for Science.
Dependencies
- absl-py =1.1.0=pypi_0
- aiohttp =3.8.1=pypi_0
- aiosignal =1.2.0=pypi_0
- alembic =1.7.7=pypi_0
- antlr4-python3-runtime =4.9.3=pypi_0
- anyio =2.2.0=py38h06a4308_1
- argon2-cffi =20.1.0=py38h27cfd23_1
- async-timeout =4.0.2=pypi_0
- async_generator =1.10=pyhd3eb1b0_0
- attrs =21.2.0=pyhd3eb1b0_0
- autopage =0.5.0=pypi_0
- autopep8 =1.6.0=pypi_0
- babel =2.9.1=pyhd3eb1b0_0
- backcall =0.2.0=pyhd3eb1b0_0
- backports-entry-points-selectable =1.1.1=pypi_0
- bleach =4.0.0=pyhd3eb1b0_0
- brotlipy =0.7.0=py38h27cfd23_1003
- ca-certificates =2021.10.26=h06a4308_2
- cachetools =5.0.0=pypi_0
- captum =0.4.1=pypi_0
- certifi =2021.10.8=py38h06a4308_0
- cffi =1.14.6=py38h400218f_0
- cfgv =3.3.1=pypi_0
- charset-normalizer =2.0.4=pyhd3eb1b0_0
- click =8.0.4=pypi_0
- cliff =3.10.1=pypi_0
- cloudpickle =2.1.0=pypi_0
- cmaes =0.8.2=pypi_0
- cmd2 =2.4.0=pypi_0
- colorlog =6.6.0=pypi_0
- cryptography =35.0.0=py38hd23ed53_0
- cycler =0.11.0=pypi_0
- dataclasses =0.6=pypi_0
- debugpy =1.5.1=py38h295c915_0
- decorator =5.1.0=pyhd3eb1b0_0
- defusedxml =0.7.1=pyhd3eb1b0_0
- deprecated =1.2.13=pypi_0
- dill =0.3.4=pypi_0
- distlib =0.3.3=pypi_0
- entrypoints =0.3=py38_0
- filelock =3.4.0=pypi_0
- frozenlist =1.3.0=pypi_0
- fsspec =2022.5.0=pypi_0
- future =0.18.2=pypi_0
- google-auth =2.8.0=pypi_0
- google-auth-oauthlib =0.4.6=pypi_0
- greenlet =1.1.2=pypi_0
- grpcio =1.43.0=pypi_0
- hydra-core =1.2.0=pypi_0
- identify =2.4.0=pypi_0
- idna =3.2=pyhd3eb1b0_0
- importlib-metadata =4.8.1=py38h06a4308_0
- importlib-resources =5.4.0=pypi_0
- importlib_metadata =4.8.1=hd3eb1b0_0
- iniconfig =1.1.1=pypi_0
- ipykernel =6.4.1=py38h06a4308_1
- ipython =7.29.0=py38hb070fc8_0
- ipython_genutils =0.2.0=pyhd3eb1b0_1
- ipywidgets =7.6.5=pypi_0
- jedi =0.18.0=py38h06a4308_1
- jinja2 =3.0.2=pyhd3eb1b0_0
- joblib =1.1.0=pypi_0
- json5 =0.9.6=pyhd3eb1b0_0
- jsonschema =3.2.0=pyhd3eb1b0_2
- jupyter-http-over-ws =0.0.8=pypi_0
- jupyter_client =7.0.6=pyhd3eb1b0_0
- jupyter_core =4.9.1=py38h06a4308_0
- jupyter_server =1.4.1=py38h06a4308_0
- jupyterlab =3.2.1=pyhd8ed1ab_0
- jupyterlab-widgets =1.0.2=pypi_0
- jupyterlab_pygments =0.1.2=py_0
- jupyterlab_server =2.8.2=pyhd3eb1b0_0
- kiwisolver =1.3.2=pypi_0
- ld_impl_linux-64 =2.35.1=h7274673_9
- libffi =3.3=he6710b0_2
- libgcc-ng =9.3.0=h5101ec6_17
- libgomp =9.3.0=h5101ec6_17
- libsodium =1.0.18=h7b6447c_0
- libstdcxx-ng =9.3.0=hd4cf53a_17
- mako =1.2.0=pypi_0
- markdown =3.3.7=pypi_0
- markupsafe =2.0.1=py38h27cfd23_0
- matplotlib =3.4.3=pypi_0
- matplotlib-inline =0.1.2=pyhd3eb1b0_2
- mistune =0.8.4=py38h7b6447c_1000
- msgpack =1.0.3=pypi_0
- multidict =6.0.2=pypi_0
- nbclassic =0.2.6=pyhd3eb1b0_0
- nbclient =0.5.3=pyhd3eb1b0_0
- nbconvert =6.1.0=py38h06a4308_0
- nbformat =5.1.3=pyhd3eb1b0_0
- ncurses =6.3=h7f8727e_2
- nest-asyncio =1.5.1=pyhd3eb1b0_0
- networkx =2.6.3=pypi_0
- nodeenv =1.6.0=pypi_0
- notebook =6.4.6=py38h06a4308_0
- numpy =1.21.4=pypi_0
- nvidia-ml-py =11.450.51=pypi_0
- nvitop =0.5.3=pypi_0
- oauthlib =3.2.0=pypi_0
- omegaconf =2.2.2=pypi_0
- openssl =1.1.1l=h7f8727e_0
- optuna =2.10.0=pypi_0
- packaging =21.2=pypi_0
- pandas =1.3.4=pypi_0
- pandocfilters =1.4.3=py38h06a4308_1
- parso =0.8.2=pyhd3eb1b0_0
- pbr =5.8.1=pypi_0
- pexpect =4.8.0=pyhd3eb1b0_3
- pickleshare =0.7.5=pyhd3eb1b0_1003
- pillow =8.4.0=pypi_0
- pip =22.1.2=pypi_0
- platformdirs =2.4.0=pypi_0
- pluggy =1.0.0=pypi_0
- pre-commit =2.15.0=pypi_0
- prettytable =3.2.0=pypi_0
- prometheus_client =0.12.0=pyhd3eb1b0_0
- prompt-toolkit =3.0.20=pyhd3eb1b0_0
- protobuf =3.19.4=pypi_0
- psutil =5.9.0=pypi_0
- ptyprocess =0.7.0=pyhd3eb1b0_2
- py =1.10.0=pypi_0
- pyasn1 =0.4.8=pypi_0
- pyasn1-modules =0.2.8=pypi_0
- pycodestyle =2.8.0=pypi_0
- pycparser =2.21=pyhd3eb1b0_0
- pydeprecate =0.3.0=pypi_0
- pygments =2.10.0=pyhd3eb1b0_0
- pyopenssl =21.0.0=pyhd3eb1b0_1
- pyparsing =2.4.7=pypi_0
- pyperclip =1.8.2=pypi_0
- pyrsistent =0.18.0=py38heee7806_0
- pysocks =1.7.1=py38h06a4308_0
- pytest =6.2.5=pypi_0
- python =3.8.12=h12debd9_0
- python-dateutil =2.8.2=pyhd3eb1b0_0
- pytorch-lightning =1.6.4=pypi_0
- pytz =2021.3=pyhd3eb1b0_0
- pyyaml =5.4.1=pypi_0
- pyzmq =22.2.1=py38h295c915_1
- ray =1.11.0=pypi_0
- readline =8.1=h27cfd23_0
- redis =4.1.4=pypi_0
- requests =2.26.0=pyhd3eb1b0_0
- requests-oauthlib =1.3.1=pypi_0
- rsa =4.8=pypi_0
- scikit-learn =1.0.1=pypi_0
- scipy =1.7.2=pypi_0
- seaborn =0.11.2=pypi_0
- send2trash =1.8.0=pyhd3eb1b0_1
- setuptools =58.0.4=py38h06a4308_0
- six =1.16.0=pyhd3eb1b0_0
- sniffio =1.2.0=py38h06a4308_1
- sqlalchemy =1.4.32=pypi_0
- sqlite =3.36.0=hc218d9a_0
- stevedore =3.5.0=pypi_0
- tabulate =0.8.9=pypi_0
- tensorboard =2.9.1=pypi_0
- tensorboard-data-server =0.6.1=pypi_0
- tensorboard-plugin-wit =1.8.1=pypi_0
- tensorboardx =2.5=pypi_0
- termcolor =1.1.0=pypi_0
- terminado =0.9.4=py38h06a4308_0
- testpath =0.5.0=pyhd3eb1b0_0
- threadpoolctl =3.0.0=pypi_0
- tk =8.6.11=h1ccaba5_0
- toml =0.10.2=pypi_0
- torch =1.8.1
- torchaudio =0.8.1=pypi_0
- torchmetrics =0.9.1=pypi_0
- torchvision =0.9.1
- tornado =6.1=py38h27cfd23_0
- tqdm =4.62.3=pypi_0
- traitlets =5.1.1=pyhd3eb1b0_0
- typing-extensions =4.0.0=pypi_0
- urllib3 =1.26.7=pyhd3eb1b0_0
- virtualenv =20.10.0=pypi_0
- wcwidth =0.2.5=pyhd3eb1b0_0
- webencodings =0.5.1=py38_1
- werkzeug =2.1.2=pypi_0
- wheel =0.37.0=pyhd3eb1b0_1
- widgetsnbextension =3.5.2=pypi_0
- wrapt =1.14.0=pypi_0
- xz =5.2.5=h7b6447c_0
- yarl =1.7.2=pypi_0
- zeromq =4.3.4=h2531618_0
- zipp =3.6.0=pyhd3eb1b0_0
- zlib =1.2.11=h7b6447c_3
- aiosignal =1.3.1=pypi_0
- anyio =2.2.0=py38h06a4308_1
- argon2-cffi =20.1.0=py38h27cfd23_1
- async_generator =1.10=pyhd3eb1b0_0
- attrs =21.2.0=pyhd3eb1b0_0
- babel =2.9.1=pyhd3eb1b0_0
- backcall =0.2.0=pyhd3eb1b0_0
- blas =1.0=mkl
- bleach =4.0.0=pyhd3eb1b0_0
- brotlipy =0.7.0=py38h27cfd23_1003
- ca-certificates =2023.01.10=h06a4308_0
- certifi =2022.12.7=py38h06a4308_0
- cffi =1.14.6=py38h400218f_0
- charset-normalizer =2.0.4=pyhd3eb1b0_0
- click =8.1.3=pypi_0
- cryptography =35.0.0=py38hd23ed53_0
- cudatoolkit =10.1.243=h6bb024c_0
- debugpy =1.5.1=py38h295c915_0
- decorator =5.1.0=pyhd3eb1b0_0
- defusedxml =0.7.1=pyhd3eb1b0_0
- dill =0.3.6=pypi_0
- distlib =0.3.6=pypi_0
- entrypoints =0.3=py38_0
- filelock =3.12.0=pypi_0
- freetype =2.11.0=h70c0345_0
- frozenlist =1.3.3=pypi_0
- giflib =5.2.1=h7b6447c_0
- grpcio =1.51.3=pypi_0
- idna =3.2=pyhd3eb1b0_0
- importlib-metadata =4.8.1=py38h06a4308_0
- importlib_metadata =4.8.1=hd3eb1b0_0
- intel-openmp =2021.4.0=h06a4308_3561
- ipykernel =6.4.1=py38h06a4308_1
- ipython =7.29.0=py38hb070fc8_0
- ipython_genutils =0.2.0=pyhd3eb1b0_1
- jedi =0.18.0=py38h06a4308_1
- jinja2 =3.0.2=pyhd3eb1b0_0
- joblib =1.2.0=pypi_0
- jpeg =9e=h7f8727e_0
- json5 =0.9.6=pyhd3eb1b0_0
- jsonschema =3.2.0=pyhd3eb1b0_2
- jupyter_client =7.0.6=pyhd3eb1b0_0
- jupyter_core =4.9.1=py38h06a4308_0
- jupyter_server =1.4.1=py38h06a4308_0
- jupyterlab =3.2.1=pyhd8ed1ab_0
- jupyterlab_pygments =0.1.2=py_0
- jupyterlab_server =2.8.2=pyhd3eb1b0_0
- lcms2 =2.12=h3be6417_0
- ld_impl_linux-64 =2.35.1=h7274673_9
- libffi =3.3=he6710b0_2
- libgcc-ng =9.3.0=h5101ec6_17
- libgomp =9.3.0=h5101ec6_17
- libpng =1.6.37=hbc83047_0
- libsodium =1.0.18=h7b6447c_0
- libstdcxx-ng =9.3.0=hd4cf53a_17
- libtiff =4.2.0=h85742a9_0
- libuv =1.40.0=h7b6447c_0
- libwebp =1.2.2=h55f646e_0
- libwebp-base =1.2.2=h7f8727e_0
- lz4-c =1.9.3=h295c915_1
- markupsafe =2.0.1=py38h27cfd23_0
- matplotlib-inline =0.1.2=pyhd3eb1b0_2
- mistune =0.8.4=py38h7b6447c_1000
- mkl =2021.4.0=h06a4308_640
- mkl-service =2.4.0=py38h7f8727e_0
- mkl_fft =1.3.1=py38hd3c417c_0
- mkl_random =1.2.2=py38h51133e4_0
- msgpack =1.0.5=pypi_0
- nbclassic =0.2.6=pyhd3eb1b0_0
- nbclient =0.5.3=pyhd3eb1b0_0
- nbconvert =6.1.0=py38h06a4308_0
- nbformat =5.1.3=pyhd3eb1b0_0
- ncurses =6.3=h7f8727e_2
- nest-asyncio =1.5.1=pyhd3eb1b0_0
- ninja =1.10.2=h06a4308_5
- ninja-base =1.10.2=hd09550d_5
- notebook =6.4.6=py38h06a4308_0
- numpy =1.22.3=py38he7a7128_0
- numpy-base =1.22.3=py38hf524024_0
- openssl =1.1.1t=h7f8727e_0
- packaging =23.1=pyhd8ed1ab_0
- pandas =2.0.1=pypi_0
- pandocfilters =1.4.3=py38h06a4308_1
- parso =0.8.2=pyhd3eb1b0_0
- pexpect =4.8.0=pyhd3eb1b0_3
- pickleshare =0.7.5=pyhd3eb1b0_1003
- pillow =9.0.1=py38h22f2fdc_0
- pip =23.1.2=pyhd8ed1ab_0
- platformdirs =3.5.0=pypi_0
- prometheus_client =0.12.0=pyhd3eb1b0_0
- prompt-toolkit =3.0.20=pyhd3eb1b0_0
- protobuf =3.20.3=pypi_0
- ptyprocess =0.7.0=pyhd3eb1b0_2
- pycparser =2.21=pyhd3eb1b0_0
- pygments =2.10.0=pyhd3eb1b0_0
- pyopenssl =21.0.0=pyhd3eb1b0_1
- pyrsistent =0.18.0=py38heee7806_0
- pysocks =1.7.1=py38h06a4308_0
- python =3.8.12=h12debd9_0
- python-dateutil =2.8.2=pyhd3eb1b0_0
- pytorch =1.7.0=py3.8_cuda10.1.243_cudnn7.6.3_0
- pytz =2021.3=pyhd3eb1b0_0
- pyyaml =6.0=pypi_0
- pyzmq =22.2.1=py38h295c915_1
- ray =2.4.0=pypi_0
- readline =8.1=h27cfd23_0
- requests =2.26.0=pyhd3eb1b0_0
- scikit-learn =1.0.1=pypi_0
- scipy =1.10.1=pypi_0
- send2trash =1.8.0=pyhd3eb1b0_1
- setuptools =58.0.4=py38h06a4308_0
- six =1.16.0=pyhd3eb1b0_0
- sniffio =1.2.0=py38h06a4308_1
- sqlite =3.36.0=hc218d9a_0
- tabulate =0.9.0=pypi_0
- tensorboardx =2.6=pypi_0
- terminado =0.9.4=py38h06a4308_0
- testpath =0.5.0=pyhd3eb1b0_0
- threadpoolctl =3.1.0=pypi_0
- tk =8.6.11=h1ccaba5_0
- torchaudio =0.7.0=py38
- torchvision =0.8.0=py38_cu101
- tornado =6.1=py38h27cfd23_0
- tqdm =4.65.0=pypi_0
- traitlets =5.1.1=pyhd3eb1b0_0
- typing_extensions =4.5.0=py38h06a4308_0
- tzdata =2023.3=pypi_0
- urllib3 =1.26.7=pyhd3eb1b0_0
- virtualenv =20.21.0=pypi_0
- wcwidth =0.2.5=pyhd3eb1b0_0
- webencodings =0.5.1=py38_1
- wheel =0.37.0=pyhd3eb1b0_1
- xz =5.2.5=h7b6447c_0
- zeromq =4.3.4=h2531618_0
- zipp =3.6.0=pyhd3eb1b0_0
- zlib =1.2.11=h7b6447c_3
- zstd =1.4.9=haebb681_0