crm

Compositional Relational Machines (CRMs): Constructing deep neural networks that are logically explainable by design

https://github.com/tirtharajdash/crm

Science Score: 49.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 5 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.3%) to scientific vocabulary

Keywords

deep-learning deep-neural-networks drug-discovery explainable-ai logic-programming

Repository


Basic Info
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 1
  • Open Issues: 0
  • Releases: 0
Created almost 3 years ago · Last pushed over 2 years ago
Metadata Files
Readme License Citation

README.md

Compositional Relational Machine (CRM)


Code and Data repository for our paper titled Composition of Relational Features with an Application to Explaining Black-Box Predictors.

The paper is under review (status: Revision submitted) at the Machine Learning Journal (MLJ). We will update the link to the paper once the paper is accepted and published officially.

Our repository of CRM here might not be maintained. The current repository is intended to be up-to-date with our newer results and further research extensions. Keep watching!

[Nov 14, 2023] This paper was also presented at the IJCLR 2023: PPT.

[Oct 26, 2023] The paper is now officially published by the Machine Learning Journal.

What is a CRM?

CRMs are explainable deep neural networks, a neurosymbolic architecture in which each node in a neural network has an associated relational feature. While being independently explainable in nature, CRMs can also be used for generating structured proxy explanations for a black-box predictor.

A CRM consists of CRM-nodes (or CRM-neurons), where each node is associated with a valuation of a relational feature. An illustrative figure is shown below:

Every neuron has an arithmetic part $\Sigma$ and a logical part $f_i(a)$. The arithmetic part acts as it would in any standard neural network: it is formed by $w_{ji}h_j(a_j) + w_{ki}h_k(a_k)$. The logical part is derived as a deterministic function of the data instance $x$. Hence, for a neuron ($n_i$) with two ancestor neurons ($n_j, n_k$), the final output would be

$$ h_i(a_i) = g_i(w_{ji}h_j(a_j) + w_{ki}h_k(a_k)) \times f_i(a) $$
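As a minimal sketch of the formula above (not the repository's implementation), the output of one CRM-neuron can be computed as an activation of the weighted sum of its ancestors, multiplied by the 0/1 value of its relational feature. The function name `crm_neuron_output`, the choice of ReLU for $g_i$, and the numeric values are assumptions made for illustration.

```python
def relu(z: float) -> float:
    """Assumed activation g_i; the formula itself does not fix a choice."""
    return max(0.0, z)


def crm_neuron_output(parent_outputs, weights, feature_value, activation=relu):
    """Compute h_i = g_i(sum_j w_ji * h_j) * f_i for a single neuron.

    feature_value is the logical part f_i and must be 0 or 1.
    """
    if feature_value not in (0, 1):
        raise ValueError("feature_value must be 0 or 1")
    weighted_sum = sum(w * h for w, h in zip(weights, parent_outputs))
    return activation(weighted_sum) * feature_value


# Hypothetical ancestor outputs h_j, h_k and weights w_ji, w_ki.
h_j, h_k = 0.8, 0.5
w_ji, w_ki = 0.6, -0.2

active = crm_neuron_output([h_j, h_k], [w_ji, w_ki], feature_value=1)
gated = crm_neuron_output([h_j, h_k], [w_ji, w_ki], feature_value=0)
print(active, gated)
```

When the feature is false for the instance (`feature_value=0`), the neuron contributes nothing downstream, whatever its weighted sum is.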

Let us consider an example problem from the paper to illustrate the workings of a CRM.

The Train Problem

In the "East-West Challenge" (Michie et al., 1994), participants were asked to distinguish eastbound trains from westbound ones using properties of the carriages and their loads (the engine's properties are not used), working from pictorial descriptions like these (T1 is eastbound and T2 is westbound):

Our CRM aims to take these features as input, predict the direction of the train, and generate explanations for the prediction.

The first layer of the CRM is composed of primitive features, like

$$
\begin{align*}
C_1 &: p(X) \leftarrow (has\_car(X,Y), short(Y)) \\
C_2 &: p(X) \leftarrow (has\_car(X,Y), closed(Y)) \\
C_3 &: p(X) \leftarrow (has\_car(X,Y), short(Y), closed(Y)) \\
C_4 &: p(X) \leftarrow (has\_car(X,Y), has\_car(X,Z), short(Y), closed(Z))
\end{align*}
$$

Here, we will assume that predicates like $has\_car/2$, $short/1$, $closed/1$, $long/1$, $has\_load/3$, etc., are defined as part of the background $B$, and capture the situation shown diagrammatically. That is,

$$
\begin{align*}
B = \{~ & has\_car(t1,c1\_1), has\_car(t1,c1\_2), \ldots \\
& long(c1\_1), closed(c1\_1), has\_load(c1\_1,square,3), \ldots \\
& has\_car(t2,c2\_1), has\_car(t2,c2\_2), \ldots ~\}.
\end{align*}
$$

Then the corresponding feature-function values are:

$$
\begin{align}
f_{C_1,B}(t1) = f_1(t1) = 1; &~ f_1(t2) = 1; \\
f_{C_2,B}(t1) = f_2(t1) = 1; &~ f_2(t2) = 1; \\
f_{C_3,B}(t1) = f_3(t1) = 1; &~ f_3(t2) = 0; \\
f_{C_4,B}(t1) = f_4(t1) = 1; &~ f_4(t2) = 1.
\end{align}
$$
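The feature-function values can be reproduced with a small sketch that evaluates each clause body against a set of ground facts. The background $B$ above is only partially listed, so the facts below are a hypothetical completion chosen to agree with the listed ones (for example, `c1_1` is closed and not short); the set and function names are likewise illustrative and not taken from the repository.

```python
from itertools import product

# Hypothetical ground facts, invented for illustration only.
HAS_CAR = {("t1", "c1_1"), ("t1", "c1_2"), ("t2", "c2_1"), ("t2", "c2_2")}
SHORT = {"c1_2", "c2_1"}
CLOSED = {"c1_1", "c1_2", "c2_2"}


def cars_of(train):
    """All Y such that has_car(train, Y) holds."""
    return sorted(car for (t, car) in HAS_CAR if t == train)


def f1(train):  # C1: has_car(X,Y), short(Y)
    return int(any(y in SHORT for y in cars_of(train)))


def f2(train):  # C2: has_car(X,Y), closed(Y)
    return int(any(y in CLOSED for y in cars_of(train)))


def f3(train):  # C3: has_car(X,Y), short(Y), closed(Y)
    return int(any(y in SHORT and y in CLOSED for y in cars_of(train)))


def f4(train):  # C4: has_car(X,Y), has_car(X,Z), short(Y), closed(Z)
    cars = cars_of(train)
    return int(any(y in SHORT and z in CLOSED
                   for y, z in product(cars, repeat=2)))


def feature_vector(train):
    return [f(train) for f in (f1, f2, f3, f4)]


print(feature_vector("t1"))  # [1, 1, 1, 1]
print(feature_vector("t2"))  # [1, 1, 0, 1]
```

Under these assumed facts, $f_3(t2) = 0$ because no single car of `t2` is both short and closed, while $f_4(t2) = 1$ because the two properties may be satisfied by different cars.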

Given a way to generate the predicate part of the CRM, the arithmetic part is learned in the same way as in a standard neural network, using backpropagation, with each neuron ($n_i$) being _dropped out_ based on its predicate function ($f_i$).

Explanations Generated by CRM

Given a (relational) data instance representing a train with some cars and their properties, a CRM generates structured explanations like the one shown below.

We use Layer-wise Relevance Propagation (LRP) (Bach et al., 2015; Binder et al., 2016) to calculate the most relevant neurons in the penultimate layer of the CRM, just before the final two-neuron classification layer. We then find the ancestral graph of each such neuron to explain the prediction.

Explanation ancestral graph
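The two steps described above can be sketched as follows. This is an illustration under stated assumptions, not the code in this repository: it uses the LRP-epsilon rule for a bias-free linear layer (the text does not say which LRP rule is used), and the neuron names, activations, weights, and parent links are all hypothetical.

```python
def lrp_epsilon(activations, weights, relevance_out, eps=1e-6):
    """Redistribute output relevance onto the inputs of a linear layer.

    weights[j][k] connects input neuron j to output neuron k.
    Implements R_j = sum_k a_j * w_jk / (z_k + eps * sign(z_k)) * R_k.
    """
    n_in, n_out = len(activations), len(relevance_out)
    relevance_in = [0.0] * n_in
    for k in range(n_out):
        z_k = sum(activations[j] * weights[j][k] for j in range(n_in))
        denom = z_k + (eps if z_k >= 0 else -eps)  # stabilised denominator
        for j in range(n_in):
            relevance_in[j] += (activations[j] * weights[j][k] / denom
                                * relevance_out[k])
    return relevance_in


def ancestors(node, parents):
    """Collect every node reachable from `node` by following parent links."""
    seen, stack = set(), [node]
    while stack:
        current = stack.pop()
        for parent in parents.get(current, ()):
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return seen


# Hypothetical penultimate layer (n5, n6, n7) feeding two output neurons.
names = ["n5", "n6", "n7"]
activations = [0.0, 0.9, 0.4]
weights = [[0.3, -0.1], [0.8, 0.2], [-0.5, 0.6]]
relevance_out = [1.0, 0.0]  # explain the first output class

relevance = lrp_epsilon(activations, weights, relevance_out)
top = names[max(range(len(names)), key=lambda j: relevance[j])]

# Hypothetical parent links between CRM-neurons.
parents = {"n6": ["n2", "n3"], "n2": ["n1"], "n3": ["n1"],
           "n5": ["n1"], "n7": ["n3", "n4"]}
print(top, sorted(ancestors(top, parents)))
```

Because every CRM-neuron carries a relational feature, the returned ancestor set can be read as the chain of features composed to reach the most relevant neuron.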

Authors

  • CRM's core codebase is developed jointly with Devanshu Shah @Devanshu24.

How to run our code

Environment setup: requirements.txt

```console
$ python3 main.py -h
usage: main.py [-h] -f FILE -o OUTPUT -n NUM_EPOCHS [-s SAVED_MODEL] [-e] [-v] [-g]

CRM; Example: python3 main.py -f inp.file -o out.file -n 20 -s saved_model.pt -e -v

optional arguments:
  -h, --help            show this help message and exit
  -f FILE, --file FILE  input file
  -o OUTPUT, --output OUTPUT
                        output file
  -n NUM_EPOCHS, --num-epochs NUM_EPOCHS
                        number of epochs
  -s SAVED_MODEL, --saved-model SAVED_MODEL
                        location of saved model
  -e, --explain         get explanations for predictions
  -v, --verbose         get verbose outputs
  -g, --gpu             run model on gpu
```

For reference, some example commands are in command.hist.
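To make the required and optional flags in the help text explicit, here is a sketch of an `argparse` parser that would accept the same command line. It is reconstructed from the help output only and is not the parser in `main.py`; in particular, treating `--num-epochs` as an integer and the function name `build_parser` are assumptions.

```python
import argparse


def build_parser():
    """Parser mirroring the documented flags (illustrative reconstruction)."""
    parser = argparse.ArgumentParser(prog="main.py", description="CRM")
    # Required flags: shown without brackets in the usage line.
    parser.add_argument("-f", "--file", required=True, help="input file")
    parser.add_argument("-o", "--output", required=True, help="output file")
    parser.add_argument("-n", "--num-epochs", required=True, type=int,
                        help="number of epochs")  # int type is an assumption
    # Optional flags: shown in brackets in the usage line.
    parser.add_argument("-s", "--saved-model",
                        help="location of saved model")
    parser.add_argument("-e", "--explain", action="store_true",
                        help="get explanations for predictions")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="get verbose outputs")
    parser.add_argument("-g", "--gpu", action="store_true",
                        help="run model on gpu")
    return parser


# The example invocation from the help text, parsed without running the model.
args = build_parser().parse_args(
    ["-f", "inp.file", "-o", "out.file", "-n", "20",
     "-s", "saved_model.pt", "-e", "-v"]
)
print(args.file, args.output, args.num_epochs, args.explain, args.gpu)
```

Omitting any of `-f`, `-o`, or `-n` makes `argparse` exit with a usage error, which matches the unbracketed flags in the usage line.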

Caution! Some result or model files might get overwritten. Check the directory structure before running. Honestly, we did not properly care about this aspect. Over time, we will fix this simple thing. Also, our results get saved in the ./data directory, which is not good. We will work on these :-)

How to cite our paper

Bibtex:

@article{srinivasan2023composition,
  title={Composition of relational features with an application to explaining black-box predictors},
  author={Srinivasan, Ashwin and Baskar, A and Dash, Tirtharaj and Shah, Devanshu},
  journal={Machine Learning},
  pages={1--42},
  year={2023},
  doi={10.1007/s10994-023-06399-6},
  publisher={Springer},
  url={https://doi.org/10.1007/s10994-023-06399-6}
}

Owner

  • Name: Tirtharaj Dash
  • Login: tirtharajdash
  • Kind: user
  • Location: California, USA
  • Company: University of California, San Diego

I work on Deep Learning, Neuro-Symbolic Models, Graph Representation Learning, and Machine Learning for Science.

Dependencies

requirements.txt pypi
  • absl-py =1.1.0=pypi_0
  • aiohttp =3.8.1=pypi_0
  • aiosignal =1.2.0=pypi_0
  • alembic =1.7.7=pypi_0
  • antlr4-python3-runtime =4.9.3=pypi_0
  • anyio =2.2.0=py38h06a4308_1
  • argon2-cffi =20.1.0=py38h27cfd23_1
  • async-timeout =4.0.2=pypi_0
  • async_generator =1.10=pyhd3eb1b0_0
  • attrs =21.2.0=pyhd3eb1b0_0
  • autopage =0.5.0=pypi_0
  • autopep8 =1.6.0=pypi_0
  • babel =2.9.1=pyhd3eb1b0_0
  • backcall =0.2.0=pyhd3eb1b0_0
  • backports-entry-points-selectable =1.1.1=pypi_0
  • bleach =4.0.0=pyhd3eb1b0_0
  • brotlipy =0.7.0=py38h27cfd23_1003
  • ca-certificates =2021.10.26=h06a4308_2
  • cachetools =5.0.0=pypi_0
  • captum =0.4.1=pypi_0
  • certifi =2021.10.8=py38h06a4308_0
  • cffi =1.14.6=py38h400218f_0
  • cfgv =3.3.1=pypi_0
  • charset-normalizer =2.0.4=pyhd3eb1b0_0
  • click =8.0.4=pypi_0
  • cliff =3.10.1=pypi_0
  • cloudpickle =2.1.0=pypi_0
  • cmaes =0.8.2=pypi_0
  • cmd2 =2.4.0=pypi_0
  • colorlog =6.6.0=pypi_0
  • cryptography =35.0.0=py38hd23ed53_0
  • cycler =0.11.0=pypi_0
  • dataclasses =0.6=pypi_0
  • debugpy =1.5.1=py38h295c915_0
  • decorator =5.1.0=pyhd3eb1b0_0
  • defusedxml =0.7.1=pyhd3eb1b0_0
  • deprecated =1.2.13=pypi_0
  • dill =0.3.4=pypi_0
  • distlib =0.3.3=pypi_0
  • entrypoints =0.3=py38_0
  • filelock =3.4.0=pypi_0
  • frozenlist =1.3.0=pypi_0
  • fsspec =2022.5.0=pypi_0
  • future =0.18.2=pypi_0
  • google-auth =2.8.0=pypi_0
  • google-auth-oauthlib =0.4.6=pypi_0
  • greenlet =1.1.2=pypi_0
  • grpcio =1.43.0=pypi_0
  • hydra-core =1.2.0=pypi_0
  • identify =2.4.0=pypi_0
  • idna =3.2=pyhd3eb1b0_0
  • importlib-metadata =4.8.1=py38h06a4308_0
  • importlib-resources =5.4.0=pypi_0
  • importlib_metadata =4.8.1=hd3eb1b0_0
  • iniconfig =1.1.1=pypi_0
  • ipykernel =6.4.1=py38h06a4308_1
  • ipython =7.29.0=py38hb070fc8_0
  • ipython_genutils =0.2.0=pyhd3eb1b0_1
  • ipywidgets =7.6.5=pypi_0
  • jedi =0.18.0=py38h06a4308_1
  • jinja2 =3.0.2=pyhd3eb1b0_0
  • joblib =1.1.0=pypi_0
  • json5 =0.9.6=pyhd3eb1b0_0
  • jsonschema =3.2.0=pyhd3eb1b0_2
  • jupyter-http-over-ws =0.0.8=pypi_0
  • jupyter_client =7.0.6=pyhd3eb1b0_0
  • jupyter_core =4.9.1=py38h06a4308_0
  • jupyter_server =1.4.1=py38h06a4308_0
  • jupyterlab =3.2.1=pyhd8ed1ab_0
  • jupyterlab-widgets =1.0.2=pypi_0
  • jupyterlab_pygments =0.1.2=py_0
  • jupyterlab_server =2.8.2=pyhd3eb1b0_0
  • kiwisolver =1.3.2=pypi_0
  • ld_impl_linux-64 =2.35.1=h7274673_9
  • libffi =3.3=he6710b0_2
  • libgcc-ng =9.3.0=h5101ec6_17
  • libgomp =9.3.0=h5101ec6_17
  • libsodium =1.0.18=h7b6447c_0
  • libstdcxx-ng =9.3.0=hd4cf53a_17
  • mako =1.2.0=pypi_0
  • markdown =3.3.7=pypi_0
  • markupsafe =2.0.1=py38h27cfd23_0
  • matplotlib =3.4.3=pypi_0
  • matplotlib-inline =0.1.2=pyhd3eb1b0_2
  • mistune =0.8.4=py38h7b6447c_1000
  • msgpack =1.0.3=pypi_0
  • multidict =6.0.2=pypi_0
  • nbclassic =0.2.6=pyhd3eb1b0_0
  • nbclient =0.5.3=pyhd3eb1b0_0
  • nbconvert =6.1.0=py38h06a4308_0
  • nbformat =5.1.3=pyhd3eb1b0_0
  • ncurses =6.3=h7f8727e_2
  • nest-asyncio =1.5.1=pyhd3eb1b0_0
  • networkx =2.6.3=pypi_0
  • nodeenv =1.6.0=pypi_0
  • notebook =6.4.6=py38h06a4308_0
  • numpy =1.21.4=pypi_0
  • nvidia-ml-py =11.450.51=pypi_0
  • nvitop =0.5.3=pypi_0
  • oauthlib =3.2.0=pypi_0
  • omegaconf =2.2.2=pypi_0
  • openssl =1.1.1l=h7f8727e_0
  • optuna =2.10.0=pypi_0
  • packaging =21.2=pypi_0
  • pandas =1.3.4=pypi_0
  • pandocfilters =1.4.3=py38h06a4308_1
  • parso =0.8.2=pyhd3eb1b0_0
  • pbr =5.8.1=pypi_0
  • pexpect =4.8.0=pyhd3eb1b0_3
  • pickleshare =0.7.5=pyhd3eb1b0_1003
  • pillow =8.4.0=pypi_0
  • pip =22.1.2=pypi_0
  • platformdirs =2.4.0=pypi_0
  • pluggy =1.0.0=pypi_0
  • pre-commit =2.15.0=pypi_0
  • prettytable =3.2.0=pypi_0
  • prometheus_client =0.12.0=pyhd3eb1b0_0
  • prompt-toolkit =3.0.20=pyhd3eb1b0_0
  • protobuf =3.19.4=pypi_0
  • psutil =5.9.0=pypi_0
  • ptyprocess =0.7.0=pyhd3eb1b0_2
  • py =1.10.0=pypi_0
  • pyasn1 =0.4.8=pypi_0
  • pyasn1-modules =0.2.8=pypi_0
  • pycodestyle =2.8.0=pypi_0
  • pycparser =2.21=pyhd3eb1b0_0
  • pydeprecate =0.3.0=pypi_0
  • pygments =2.10.0=pyhd3eb1b0_0
  • pyopenssl =21.0.0=pyhd3eb1b0_1
  • pyparsing =2.4.7=pypi_0
  • pyperclip =1.8.2=pypi_0
  • pyrsistent =0.18.0=py38heee7806_0
  • pysocks =1.7.1=py38h06a4308_0
  • pytest =6.2.5=pypi_0
  • python =3.8.12=h12debd9_0
  • python-dateutil =2.8.2=pyhd3eb1b0_0
  • pytorch-lightning =1.6.4=pypi_0
  • pytz =2021.3=pyhd3eb1b0_0
  • pyyaml =5.4.1=pypi_0
  • pyzmq =22.2.1=py38h295c915_1
  • ray =1.11.0=pypi_0
  • readline =8.1=h27cfd23_0
  • redis =4.1.4=pypi_0
  • requests =2.26.0=pyhd3eb1b0_0
  • requests-oauthlib =1.3.1=pypi_0
  • rsa =4.8=pypi_0
  • scikit-learn =1.0.1=pypi_0
  • scipy =1.7.2=pypi_0
  • seaborn =0.11.2=pypi_0
  • send2trash =1.8.0=pyhd3eb1b0_1
  • setuptools =58.0.4=py38h06a4308_0
  • six =1.16.0=pyhd3eb1b0_0
  • sniffio =1.2.0=py38h06a4308_1
  • sqlalchemy =1.4.32=pypi_0
  • sqlite =3.36.0=hc218d9a_0
  • stevedore =3.5.0=pypi_0
  • tabulate =0.8.9=pypi_0
  • tensorboard =2.9.1=pypi_0
  • tensorboard-data-server =0.6.1=pypi_0
  • tensorboard-plugin-wit =1.8.1=pypi_0
  • tensorboardx =2.5=pypi_0
  • termcolor =1.1.0=pypi_0
  • terminado =0.9.4=py38h06a4308_0
  • testpath =0.5.0=pyhd3eb1b0_0
  • threadpoolctl =3.0.0=pypi_0
  • tk =8.6.11=h1ccaba5_0
  • toml =0.10.2=pypi_0
  • torch =1.8.1
  • torchaudio =0.8.1=pypi_0
  • torchmetrics =0.9.1=pypi_0
  • torchvision =0.9.1
  • tornado =6.1=py38h27cfd23_0
  • tqdm =4.62.3=pypi_0
  • traitlets =5.1.1=pyhd3eb1b0_0
  • typing-extensions =4.0.0=pypi_0
  • urllib3 =1.26.7=pyhd3eb1b0_0
  • virtualenv =20.10.0=pypi_0
  • wcwidth =0.2.5=pyhd3eb1b0_0
  • webencodings =0.5.1=py38_1
  • werkzeug =2.1.2=pypi_0
  • wheel =0.37.0=pyhd3eb1b0_1
  • widgetsnbextension =3.5.2=pypi_0
  • wrapt =1.14.0=pypi_0
  • xz =5.2.5=h7b6447c_0
  • yarl =1.7.2=pypi_0
  • zeromq =4.3.4=h2531618_0
  • zipp =3.6.0=pyhd3eb1b0_0
  • zlib =1.2.11=h7b6447c_3
requirements_dell5810.txt pypi
  • aiosignal =1.3.1=pypi_0
  • anyio =2.2.0=py38h06a4308_1
  • argon2-cffi =20.1.0=py38h27cfd23_1
  • async_generator =1.10=pyhd3eb1b0_0
  • attrs =21.2.0=pyhd3eb1b0_0
  • babel =2.9.1=pyhd3eb1b0_0
  • backcall =0.2.0=pyhd3eb1b0_0
  • blas =1.0=mkl
  • bleach =4.0.0=pyhd3eb1b0_0
  • brotlipy =0.7.0=py38h27cfd23_1003
  • ca-certificates =2023.01.10=h06a4308_0
  • certifi =2022.12.7=py38h06a4308_0
  • cffi =1.14.6=py38h400218f_0
  • charset-normalizer =2.0.4=pyhd3eb1b0_0
  • click =8.1.3=pypi_0
  • cryptography =35.0.0=py38hd23ed53_0
  • cudatoolkit =10.1.243=h6bb024c_0
  • debugpy =1.5.1=py38h295c915_0
  • decorator =5.1.0=pyhd3eb1b0_0
  • defusedxml =0.7.1=pyhd3eb1b0_0
  • dill =0.3.6=pypi_0
  • distlib =0.3.6=pypi_0
  • entrypoints =0.3=py38_0
  • filelock =3.12.0=pypi_0
  • freetype =2.11.0=h70c0345_0
  • frozenlist =1.3.3=pypi_0
  • giflib =5.2.1=h7b6447c_0
  • grpcio =1.51.3=pypi_0
  • idna =3.2=pyhd3eb1b0_0
  • importlib-metadata =4.8.1=py38h06a4308_0
  • importlib_metadata =4.8.1=hd3eb1b0_0
  • intel-openmp =2021.4.0=h06a4308_3561
  • ipykernel =6.4.1=py38h06a4308_1
  • ipython =7.29.0=py38hb070fc8_0
  • ipython_genutils =0.2.0=pyhd3eb1b0_1
  • jedi =0.18.0=py38h06a4308_1
  • jinja2 =3.0.2=pyhd3eb1b0_0
  • joblib =1.2.0=pypi_0
  • jpeg =9e=h7f8727e_0
  • json5 =0.9.6=pyhd3eb1b0_0
  • jsonschema =3.2.0=pyhd3eb1b0_2
  • jupyter_client =7.0.6=pyhd3eb1b0_0
  • jupyter_core =4.9.1=py38h06a4308_0
  • jupyter_server =1.4.1=py38h06a4308_0
  • jupyterlab =3.2.1=pyhd8ed1ab_0
  • jupyterlab_pygments =0.1.2=py_0
  • jupyterlab_server =2.8.2=pyhd3eb1b0_0
  • lcms2 =2.12=h3be6417_0
  • ld_impl_linux-64 =2.35.1=h7274673_9
  • libffi =3.3=he6710b0_2
  • libgcc-ng =9.3.0=h5101ec6_17
  • libgomp =9.3.0=h5101ec6_17
  • libpng =1.6.37=hbc83047_0
  • libsodium =1.0.18=h7b6447c_0
  • libstdcxx-ng =9.3.0=hd4cf53a_17
  • libtiff =4.2.0=h85742a9_0
  • libuv =1.40.0=h7b6447c_0
  • libwebp =1.2.2=h55f646e_0
  • libwebp-base =1.2.2=h7f8727e_0
  • lz4-c =1.9.3=h295c915_1
  • markupsafe =2.0.1=py38h27cfd23_0
  • matplotlib-inline =0.1.2=pyhd3eb1b0_2
  • mistune =0.8.4=py38h7b6447c_1000
  • mkl =2021.4.0=h06a4308_640
  • mkl-service =2.4.0=py38h7f8727e_0
  • mkl_fft =1.3.1=py38hd3c417c_0
  • mkl_random =1.2.2=py38h51133e4_0
  • msgpack =1.0.5=pypi_0
  • nbclassic =0.2.6=pyhd3eb1b0_0
  • nbclient =0.5.3=pyhd3eb1b0_0
  • nbconvert =6.1.0=py38h06a4308_0
  • nbformat =5.1.3=pyhd3eb1b0_0
  • ncurses =6.3=h7f8727e_2
  • nest-asyncio =1.5.1=pyhd3eb1b0_0
  • ninja =1.10.2=h06a4308_5
  • ninja-base =1.10.2=hd09550d_5
  • notebook =6.4.6=py38h06a4308_0
  • numpy =1.22.3=py38he7a7128_0
  • numpy-base =1.22.3=py38hf524024_0
  • openssl =1.1.1t=h7f8727e_0
  • packaging =23.1=pyhd8ed1ab_0
  • pandas =2.0.1=pypi_0
  • pandocfilters =1.4.3=py38h06a4308_1
  • parso =0.8.2=pyhd3eb1b0_0
  • pexpect =4.8.0=pyhd3eb1b0_3
  • pickleshare =0.7.5=pyhd3eb1b0_1003
  • pillow =9.0.1=py38h22f2fdc_0
  • pip =23.1.2=pyhd8ed1ab_0
  • platformdirs =3.5.0=pypi_0
  • prometheus_client =0.12.0=pyhd3eb1b0_0
  • prompt-toolkit =3.0.20=pyhd3eb1b0_0
  • protobuf =3.20.3=pypi_0
  • ptyprocess =0.7.0=pyhd3eb1b0_2
  • pycparser =2.21=pyhd3eb1b0_0
  • pygments =2.10.0=pyhd3eb1b0_0
  • pyopenssl =21.0.0=pyhd3eb1b0_1
  • pyrsistent =0.18.0=py38heee7806_0
  • pysocks =1.7.1=py38h06a4308_0
  • python =3.8.12=h12debd9_0
  • python-dateutil =2.8.2=pyhd3eb1b0_0
  • pytorch =1.7.0=py3.8_cuda10.1.243_cudnn7.6.3_0
  • pytz =2021.3=pyhd3eb1b0_0
  • pyyaml =6.0=pypi_0
  • pyzmq =22.2.1=py38h295c915_1
  • ray =2.4.0=pypi_0
  • readline =8.1=h27cfd23_0
  • requests =2.26.0=pyhd3eb1b0_0
  • scikit-learn =1.0.1=pypi_0
  • scipy =1.10.1=pypi_0
  • send2trash =1.8.0=pyhd3eb1b0_1
  • setuptools =58.0.4=py38h06a4308_0
  • six =1.16.0=pyhd3eb1b0_0
  • sniffio =1.2.0=py38h06a4308_1
  • sqlite =3.36.0=hc218d9a_0
  • tabulate =0.9.0=pypi_0
  • tensorboardx =2.6=pypi_0
  • terminado =0.9.4=py38h06a4308_0
  • testpath =0.5.0=pyhd3eb1b0_0
  • threadpoolctl =3.1.0=pypi_0
  • tk =8.6.11=h1ccaba5_0
  • torchaudio =0.7.0=py38
  • torchvision =0.8.0=py38_cu101
  • tornado =6.1=py38h27cfd23_0
  • tqdm =4.65.0=pypi_0
  • traitlets =5.1.1=pyhd3eb1b0_0
  • typing_extensions =4.5.0=py38h06a4308_0
  • tzdata =2023.3=pypi_0
  • urllib3 =1.26.7=pyhd3eb1b0_0
  • virtualenv =20.21.0=pypi_0
  • wcwidth =0.2.5=pyhd3eb1b0_0
  • webencodings =0.5.1=py38_1
  • wheel =0.37.0=pyhd3eb1b0_1
  • xz =5.2.5=h7b6447c_0
  • zeromq =4.3.4=h2531618_0
  • zipp =3.6.0=pyhd3eb1b0_0
  • zlib =1.2.11=h7b6447c_3
  • zstd =1.4.9=haebb681_0