https://github.com/animesh/explainn

ExplaiNN: interpretable and transparent neural networks for genomics



Repository


Basic Info
  • Host: GitHub
  • Owner: animesh
  • License: mit
  • Default Branch: main
  • Size: 6.74 MB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Fork of wassermanlab/ExplaiNN
Created about 4 years ago · Last pushed about 4 years ago


# ExplaiNN



ExplaiNN is an adaptation of neural additive models ([NAMs](https://arxiv.org/abs/2004.13912)) for genomic tasks in which predictions are computed as a linear combination of multiple independent CNNs, each consisting of a single convolutional filter followed by fully connected layers. This approach combines the expressivity of CNNs with the interpretability of linear models, providing both global (cell state level) and local (individual sequence level) insights into the biological processes studied.
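To make the additive structure concrete, here is a minimal PyTorch sketch of the idea. It is not the `explainn` implementation: the class names `Unit` and `AdditiveCNNSketch`, the ReLU activation, the pooling width of 7 and the hidden size of 100 are assumptions chosen for illustration. Each unit sees the whole one-hot sequence (assumed shape `(batch, 4, length)`) through its own single filter and returns one scalar; a final linear layer (named `final` here to echo `model.final.weight`, which is used later in this README) combines the unit outputs into one logit per class.

```python
import torch
from torch import nn


class Unit(nn.Module):
    """One independent unit: a single convolutional filter followed by fully connected layers."""

    def __init__(self, input_length: int, filter_size: int, hidden_size: int = 100, pool_size: int = 7):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=4, out_channels=1, kernel_size=filter_size)
        self.pool = nn.MaxPool1d(kernel_size=pool_size, stride=pool_size)
        pooled_length = (input_length - filter_size + 1) // pool_size
        self.fc = nn.Sequential(
            nn.Linear(pooled_length, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 4, input_length) -> conv output: (batch, 1, input_length - filter_size + 1)
        h = self.pool(torch.relu(self.conv(x)))
        # Flatten to (batch, pooled_length) and reduce to one scalar per sequence: (batch, 1)
        return self.fc(h.flatten(start_dim=1))


class AdditiveCNNSketch(nn.Module):
    """Predictions are a linear combination of the outputs of independent units."""

    def __init__(self, num_cnns: int, input_length: int, num_classes: int, filter_size: int):
        super().__init__()
        self.units = nn.ModuleList([Unit(input_length, filter_size) for _ in range(num_cnns)])
        self.final = nn.Linear(num_cnns, num_classes)

    def unit_outputs(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, num_cnns): one scalar per unit and sequence
        return torch.cat([unit(x) for unit in self.units], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Logits of shape (batch, num_classes); pair with BCEWithLogitsLoss for multi-label targets
        return self.final(self.unit_outputs(x))


sketch_model = AdditiveCNNSketch(num_cnns=5, input_length=200, num_classes=3, filter_size=19)
x = torch.zeros(2, 4, 200)
x[:, 0, :] = 1.0  # toy input: an all-"A" sequence under an assumed A, C, G, T channel order
logits = sketch_model(x)
print(logits.shape)  # torch.Size([2, 3])
```

Because the final layer is linear, each class logit is exactly the bias plus the sum over units of `weight[class, unit] * unit_output`, which is what makes the per-unit weights and importances interpretable.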

## Installation

The ExplaiNN library is available on PyPI and can be installed with:

```
pip install explainn==0.1.5
```

Note that PyTorch (`torch`) must be installed in the environment before explainn. If pip reports an **ERROR: No matching distribution** error, try installing the following pinned versions first:

```
pip install numpy==1.21.6 h5py==3.6.0 tqdm==4.64.0 pandas==1.3.5 matplotlib==3.5.2
```

## Example of training an ExplaiNN model on TF binding data

Here is an example of how to train and interpret an ExplaiNN model that predicts the binding of three transcription factors (TFs): FOXA1, MAX, and JUND. The dataset can be found [here](https://drive.google.com/drive/folders/1tFWWTCUoE2Jg0zrMvKKtTqEBuwkkJ1bl).

### Initialize the model

Imports:

```python
from explainn import tools
from explainn import networks
from explainn import train
from explainn import test
from explainn import interpretation

import os

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from sklearn import metrics
from sklearn.metrics import average_precision_score
from torch import nn
```

Model and parameter initialization:

```python
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

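# Hyperparameters
num_epochs = 15
batch_size = 128
learning_rate = 0.001

# Load the train/valid/test dataloaders and the target (TF) names from the HDF5 file
dataloaders, target_labels, train_out = tools.load_datas("data/tf_peaks_TEST_sparse_Remap.h5", batch_size,
                                                         0, True)

# Target names are stored as bytes; decode them to strings
target_labels = [i.decode("utf-8") for i in target_labels]

# Model architecture
num_cnns = 100                    # number of independent units (one convolutional filter each)
input_length = 200                # length of the input sequences (bp)
num_classes = len(target_labels)  # one output per TF
filter_size = 19                  # width of each convolutional filter

model = networks.ExplaiNN(num_cnns, input_length, num_classes, filter_size).to(device)

# Multi-label binary classification: one logit per TF
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)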
```
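The HDF5 file used here already contains encoded sequences. To score your own sequences with the trained model, they have to be encoded the same way. Below is a minimal NumPy sketch, assuming one-hot inputs of shape `(4, input_length)` with channels in A, C, G, T order; both the layout and the channel order are assumptions, so check them against how the training data were encoded. The helper `one_hot_encode` is hypothetical and not part of `explainn`.

```python
import numpy as np

# Assumed channel order; it must match the encoding of the training data
NUCLEOTIDES = "ACGT"
BASE_TO_ROW = {base: row for row, base in enumerate(NUCLEOTIDES)}


def one_hot_encode(sequence: str) -> np.ndarray:
    """Return a (4, len(sequence)) float32 array; unknown bases such as N become all-zero columns."""
    encoded = np.zeros((4, len(sequence)), dtype=np.float32)
    for position, base in enumerate(sequence.upper()):
        row = BASE_TO_ROW.get(base)
        if row is not None:
            encoded[row, position] = 1.0
    return encoded


# Two toy 200-bp sequences stacked into a batch of shape (2, 4, 200)
batch = np.stack([one_hot_encode("ACGTN" * 40), one_hot_encode("t" * 200)])
print(batch.shape)  # (2, 4, 200)
```

The resulting array can be turned into a model input with `torch.from_numpy(batch).to(device)`; each sequence must be exactly `input_length` (200 here) bases long.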

### Train the model

Train the model and plot the training and validation losses:

```python
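# Folder where the model weights are saved during training
weights_folder = "Output_weights"
os.makedirs(weights_folder, exist_ok=True)
# Identifier passed to train_explainn (assumed to be used in the checkpoint file names; empty here)
name_ind = ""

model, train_error, test_error = train.train_explainn(dataloaders['train'],
                                                      dataloaders['valid'], model,
                                                      device, criterion, optimizer, num_epochs,
                                                      weights_folder, name_ind, verbose=True,
                                                      trim_weights=False)

# Plot the training and validation loss per epoch
plt.plot(train_error, label="Training loss")
plt.plot(test_error, label="Validation loss")
plt.title("Loss trend")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()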
```



### Testing the model

```python
# Load a checkpoint saved during training (adjust the epoch in the file name as needed)
model.load_state_dict(torch.load(os.path.join(weights_folder, "model_epoch_9_.pth"),
                                 map_location=device))

# True labels and model outputs on the test set, one column per TF
labels_E, outputs_E = test.run_test(model, dataloaders['test'], device)
# Average precision over all TFs (macro-averaged)
avg_precision = average_precision_score(labels_E, outputs_E)

# ROC curve of a no-skill classifier that always predicts 0 (baseline for plots)
no_skill_probs = [0 for _ in range(len(labels_E[:, 0]))]
ns_fpr, ns_tpr, _ = metrics.roc_curve(labels_E[:, 0], no_skill_probs)

roc_aucs = {}  # ROC AUC per TF
raw_aucs = {}  # (FPR, TPR) points of each ROC curve
pr_aucs = {}   # PR AUC per TF (trapezoidal rule)
raw_prcs = {}  # (recall, precision) points of each PR curve
for i, label in enumerate(target_labels):
    nn_fpr, nn_tpr, _ = metrics.roc_curve(labels_E[:, i], outputs_E[:, i])
    roc_aucs[label] = metrics.auc(nn_fpr, nn_tpr)
    raw_aucs[label] = nn_fpr, nn_tpr

    precision_nn, recall_nn, _ = metrics.precision_recall_curve(labels_E[:, i], outputs_E[:, i])
    pr_aucs[label] = metrics.auc(recall_nn, precision_nn)
    raw_prcs[label] = recall_nn, precision_nn

print(pr_aucs)
print(roc_aucs)
```

```
{'MAX': 0.825940403552367, 'FOXA1': 0.8932791261118389, 'JUND': 0.749391895435854}
{'MAX': 0.8031278582930756, 'FOXA1': 0.8065550331791597, 'JUND': 0.7463422694967192}
```
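The script above computes the PR AUC with the trapezoidal rule (`metrics.auc(recall, precision)`), which is not the same quantity as scikit-learn's `average_precision_score`: the latter is a step-wise sum, while linear interpolation between PR points can be too optimistic. The sketch below compares both, together with the ROC AUC, on synthetic scores for three toy labels; the data are random and unrelated to the TF dataset.

```python
import numpy as np
from sklearn import metrics

# Synthetic multi-label data: 500 sequences, 3 labels, noisy scores correlated with the labels
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=(500, 3))
scores = labels + rng.normal(scale=0.8, size=labels.shape)

roc_aucs, pr_aucs_trapezoid, average_precisions = [], [], []
for i in range(labels.shape[1]):
    fpr, tpr, _ = metrics.roc_curve(labels[:, i], scores[:, i])
    precision, recall, _ = metrics.precision_recall_curve(labels[:, i], scores[:, i])
    roc_aucs.append(metrics.auc(fpr, tpr))
    pr_aucs_trapezoid.append(metrics.auc(recall, precision))  # linear interpolation
    average_precisions.append(metrics.average_precision_score(labels[:, i], scores[:, i]))  # step-wise

for i in range(labels.shape[1]):
    print(f"label {i}: ROC AUC={roc_aucs[i]:.3f}  PR AUC (trapezoid)={pr_aucs_trapezoid[i]:.3f}  "
          f"AP={average_precisions[i]:.3f}")
```

When reporting PR performance, stating which of the two definitions was used avoids comparing numbers that were computed differently.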

### Interpretation

#### Unit/filter annotation

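Visualizing filters: keep only the sequences whose labels are all predicted correctly, compute the unit activations on them, and convert each unit's filter into a position weight matrix (PWM) saved in MEME format: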

```python
dataset, data_inp, data_out = tools.load_single_data("data/tf_peaks_TEST_sparse_Remap.h5",
                                                     batch_size, 0, False)

predictions, labels = interpretation.get_explainn_predictions(dataset, model, device,
                                                              isSigmoid=True)

# Keep only the sequences for which every TF label is predicted correctly
pred_full_round = np.round(predictions)
arr_comp = np.equal(pred_full_round, labels)
idx = np.where(np.sum(arr_comp, axis=1) == len(target_labels))[0]

data_inp = data_inp[idx, :, :]
data_out = data_out[idx, :]

dataset = torch.utils.data.TensorDataset(data_inp, data_out)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, shuffle=False,
                                          num_workers=0)

# Filter activations of each unit on the well-predicted sequences
activations = interpretation.get_explainn_unit_activations(data_loader, model, device)
# Convert each filter into a PWM and write all PWMs to a MEME file
pwms = interpretation.get_pwms_explainn(activations, data_inp, data_out, filter_size)
interpretation.pwm_to_meme(pwms, "data/explainn_filters.meme")
```
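For readers curious about what happens between activations and PWMs, the sketch below shows one common recipe from CNN filter analyses: for each sequence whose strongest activation exceeds a fraction of the filter's overall maximum, take the `filter_size`-long window at that position and average the one-hot windows into a position frequency matrix. Whether `get_pwms_explainn` uses the same threshold or window selection is not stated here, so treat the function `pwm_from_activations`, its `threshold=0.5` default and its input shapes as assumptions.

```python
import numpy as np


def pwm_from_activations(activations, one_hot, filter_size, threshold=0.5):
    """Build a position frequency matrix for one filter.

    activations: (num_seqs, seq_length - filter_size + 1) activation map of one filter
                 (valid convolution without padding is assumed).
    one_hot: (num_seqs, 4, seq_length) one-hot encoded sequences.
    Returns a (filter_size, 4) array whose rows sum to 1 (unless a position only ever
    saw unknown bases), or None if no sequence passes the threshold.
    """
    global_max = activations.max()
    if global_max <= 0:
        return None
    best_positions = activations.argmax(axis=1)
    best_values = activations.max(axis=1)
    counts = np.zeros((4, filter_size))
    for sequence, position, value in zip(one_hot, best_positions, best_values):
        if value >= threshold * global_max:
            # Add the one-hot window that produced this sequence's strongest activation
            counts += sequence[:, position:position + filter_size]
    totals = counts.sum(axis=0)
    if not totals.any():
        return None
    frequencies = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    return frequencies.T


# Toy check: plant "GATA" at position 5 of every sequence and make the activation peak there
rng = np.random.default_rng(0)
bases = rng.integers(0, 4, size=(10, 30))
bases[:, 5:9] = [2, 0, 3, 0]                     # G, A, T, A with A=0, C=1, G=2, T=3
one_hot = np.eye(4)[bases].transpose(0, 2, 1)    # (10, 4, 30)
activations = rng.random((10, 30 - 4 + 1))
activations[:, 5] = 5.0
pwm = pwm_from_activations(activations, one_hot, filter_size=4)
print(pwm.argmax(axis=1))  # [2 0 3 0], i.e. G, A, T, A
```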

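Annotate the filters with Tomtom (MEME suite) by comparing them with the JASPAR 2020 CORE vertebrate non-redundant motifs: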

```bash
./../../meme-5.3.3/src/tomtom -oc MAX_JUND_FOXA1_tomtom explainn_filters.meme ../../tomtom_results/JASPAR2020_CORE_vertebrates_non-redundant_pfms_meme.txt
```

#### Output layer weights visualization

Read the Tomtom results and label each filter with the names of its significant matches (q-value < 0.05, at most five per filter):

```python
tomtom_results = pd.read_csv("data/MAX_JUND_FOXA1_tomtom/tomtom.tsv",
                             sep="\t", comment="#")

# Best (lowest) q-value per filter
filters_with_min_q = tomtom_results.groupby('Query_ID').min()["q-value"]

# Keep significant matches only
tomtom_results = tomtom_results[["Target_ID", "Query_ID", "q-value"]]
tomtom_results = tomtom_results[tomtom_results["q-value"] < 0.05]

# Map JASPAR motif IDs to TF names using the "MOTIF <ID> <name>" lines of the MEME file
jaspar_motifs = {}
with open("data/JASPAR2020_CORE_vertebrates_non-redundant_pfms_meme.txt") as f:
    for line in f:
        if line.startswith("MOTIF"):
            motif_id, motif_name = line.split()[-2:]
            jaspar_motifs[motif_id] = motif_name

# Annotate each filter with the names of (at most) the first five matches listed by Tomtom
annotation = {}
for f in tomtom_results["Query_ID"].unique():
    target_ids = tomtom_results.loc[tomtom_results["Query_ID"] == f, "Target_ID"].head(5)
    annotation[f] = "/".join(jaspar_motifs[i] for i in target_ids)

annotation = pd.Series(annotation)
```
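The same annotation logic can be written in plain Python, which makes the selection rule explicit. In this sketch (the helpers `parse_meme_motif_names` and `annotate_filters` are hypothetical) matches are sorted by q-value before keeping at most five, rather than relying on the row order of `tomtom.tsv`. It relies on MEME-format motif lines of the form `MOTIF <ID> <name>`, as the code above does; the IDs and names in the toy example are illustrative.

```python
def parse_meme_motif_names(lines):
    """Map motif IDs to names from MEME-format 'MOTIF <id> <name>' lines."""
    names = {}
    for line in lines:
        fields = line.split()
        if len(fields) >= 3 and fields[0] == "MOTIF":
            names[fields[1]] = fields[2]
    return names


def annotate_filters(hits, motif_names, q_cutoff=0.05, max_names=5):
    """hits: iterable of (query_id, target_id, q_value) tuples, e.g. rows of tomtom.tsv.

    Returns {query_id: "NAME1/NAME2/..."} listing the best-scoring significant matches first.
    """
    by_query = {}
    for query_id, target_id, q_value in hits:
        if q_value < q_cutoff:
            by_query.setdefault(query_id, []).append((q_value, target_id))
    annotation = {}
    for query_id, matches in by_query.items():
        matches.sort()  # lowest q-value first
        top = [motif_names.get(target_id, target_id) for _, target_id in matches[:max_names]]
        annotation[query_id] = "/".join(top)
    return annotation


meme_lines = ["MEME version 4", "", "MOTIF MA0058.3 MAX", "MOTIF MA0059.1 MAX::MYC"]
names = parse_meme_motif_names(meme_lines)
hits = [("filter67", "MA0059.1", 0.001), ("filter67", "MA0058.3", 0.0001), ("filter3", "MA0058.3", 0.2)]
print(annotate_filters(hits, names))  # {'filter67': 'MAX/MAX::MYC'}
```

To apply it to the Tomtom table loaded above, pass its rows in `(Query_ID, Target_ID, q-value)` order, for example `tomtom_results[["Query_ID", "Target_ID", "q-value"]].itertuples(index=False, name=None)`.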

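Retrieve the weights of the final linear layer (one row per TF, one column per unit) and label the columns with the filter annotations: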

```python
weights = model.final.weight.detach().cpu().numpy()

filters = ["filter"+str(i) for i in range(num_cnns)]
for i in annotation.keys():
    filters[int(i.split("filter")[-1])] = annotation[i]

weight_df = pd.DataFrame(weights, index=target_labels, columns=filters)
# focusing on annotated filters only
weight_df = weight_df[[i for i in weight_df.columns if not i.startswith("filter")]]
```

Visualizing the weights:

```python
# clustermap creates its own figure, so its size is set through the figsize argument
sns.clustermap(weight_df, cmap=sns.diverging_palette(145, 10, s=60, as_cmap=True),
               row_cluster=False, figsize=(30, 20), vmax=0.5, vmin=-0.5)
plt.show()
```



#### Individual unit importance

Visualizing the importance of one of the MYC/MAX filters (unit #67) for each TF:

```python
unit = 67

# Output of every unit for each sequence
unit_outputs = interpretation.get_explainn_unit_outputs(data_loader, model, device)

# Importance of the selected unit for each TF
unit_importance = interpretation.get_specific_unit_importance(activations, model, unit_outputs, unit, target_labels)

filter_key = "filter" + str(unit)
title = annotation[filter_key] if filter_key in annotation.index else filter_key

fig, ax = plt.subplots(figsize=(18.5, 10.5))
ax.boxplot(list(unit_importance), notch=True, patch_artist=True,
           boxprops=dict(facecolor="#228833", color="#228833"))
ax.set_title(title)
ax.set_ylabel("Unit importance")
ax.set_xticks(range(1, len(target_labels) + 1))
ax.set_xticklabels(target_labels, rotation=90)
plt.show()
```
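Because the output layer is linear, a natural way to quantify how much a unit contributes to a TF's prediction is the unit's output on a sequence multiplied by the final-layer weight linking that unit to the TF. The NumPy sketch below computes these per-class contributions for one unit, optionally keeping only sequences in which the unit's filter is strongly activated. It is an illustration under assumptions: the helper `unit_importance`, its `threshold=0.5` rule and the input shapes are hypothetical, and `get_specific_unit_importance` may differ in detail. The toy check at the end confirms that summing all units' contributions and adding the bias recovers the logits.

```python
import numpy as np


def unit_importance(unit_outputs, final_weights, unit_index, max_activations=None, threshold=0.5):
    """Per-class contributions of one unit to the logits.

    unit_outputs: (num_seqs, num_units) unit outputs before the final linear layer.
    final_weights: (num_classes, num_units) weights of the final linear layer.
    max_activations: optional (num_seqs,) maximum filter activation of this unit per sequence;
        if given, only sequences above threshold * its maximum are kept.
    Returns a list with one 1-D array per class, ready for plt.boxplot.
    """
    # (num_seqs, num_classes): output of the unit times its weight for each class
    contributions = np.outer(unit_outputs[:, unit_index], final_weights[:, unit_index])
    if max_activations is not None:
        keep = max_activations > threshold * max_activations.max()
        contributions = contributions[keep]
    return [contributions[:, c] for c in range(final_weights.shape[0])]


# Toy check with random values: 8 sequences, 5 units, 3 classes
rng = np.random.default_rng(0)
outputs = rng.normal(size=(8, 5))
weights = rng.normal(size=(3, 5))
bias = rng.normal(size=3)
logits = outputs @ weights.T + bias

per_unit = [np.stack(unit_importance(outputs, weights, k), axis=1) for k in range(5)]
print(np.allclose(sum(per_unit) + bias, logits))  # True
```

With the objects from this README, `final_weights` would be `model.final.weight.detach().cpu().numpy()`, and `unit_outputs` would be the result of `get_explainn_unit_outputs` converted to NumPy, assuming it has shape `(num_seqs, num_cnns)`.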

Owner

  • Name: Ani
  • Login: animesh
  • Kind: user
  • Location: Norway
  • Company: Norwegian University of Science and Technology

A medical graduate from Delhi University with post-graduation in bioinformatics from Jawaharlal Nehru University, India.
