https://github.com/berenslab/interpretable-deep-survival-analysis

An interpretable end-to-end CNN for disease progression modeling that predicts late AMD onset (MICCAI 2024)


Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
    Found 5 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, medrxiv.org, ncbi.nlm.nih.gov, zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.4%) to scientific vocabulary

Keywords

cnn coxph deep-learning disease-prediction image-based interpretability survival-analysis
Last synced: 5 months ago

Repository

An interpretable end-to-end CNN for disease progression modeling that predicts late AMD onset (MICCAI 2024)

Basic Info
  • Host: GitHub
  • Owner: berenslab
  • Language: Python
  • Default Branch: main
  • Size: 3.1 MB
Statistics
  • Stars: 0
  • Watchers: 4
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Topics
cnn coxph deep-learning disease-prediction image-based interpretability survival-analysis
Created over 1 year ago · Last pushed over 1 year ago
Metadata Files
Readme

README.md

Interpretable-by-design deep survival analysis for disease progression modeling

Accepted at MICCAI 2024

This repository contains the code for the paper "Interpretable-by-design Deep Survival Analysis for Disease Progression Modeling". The code is based on PyTorch and trains and evaluates a deep learning model for survival analysis on fundus images. The model is interpretable by design: it yields an evidence map of local disease risk, which is then simply aggregated into a final risk prediction. Architecturally, it combines a Sparse BagNet with a Cox proportional hazards model. As a sample use case, it is trained to predict the risk of conversion to late age-related macular degeneration (AMD) at different time points, and its performance is compared to that of black-box SOTA baseline models.
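The core mechanism can be illustrated with a short, hypothetical PyTorch sketch (not the repository's code; class and function names here are ours): a 1x1 convolution turns backbone features into a single-channel evidence map of local risk, simple spatial averaging yields the Cox linear predictor, and a Cox partial likelihood loss trains the pipeline end to end.

```python
import torch
import torch.nn as nn

class EvidenceMapSurvivalHead(nn.Module):
    """Hypothetical sketch of the paper's core idea: local risk evidence,
    then simple averaging into one scalar log-risk per image."""

    def __init__(self, in_channels: int):
        super().__init__()
        # 1x1 convolution: one risk logit per spatial location.
        self.evidence = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, features: torch.Tensor):
        evidence_map = self.evidence(features)       # (B, 1, H, W) local risk evidence
        log_risk = evidence_map.mean(dim=(1, 2, 3))  # (B,) Cox linear predictor
        return log_risk, evidence_map

def cox_partial_likelihood_loss(log_risk, durations, events):
    """Negative Cox partial log-likelihood (ties handled only approximately).
    The risk set of an event contains all samples with duration >= its duration."""
    order = torch.argsort(durations, descending=True)  # risk sets become prefixes
    log_risk, events = log_risk[order], events[order].float()
    log_cumsum = torch.logcumsumexp(log_risk, dim=0)   # log of the risk-set sum
    return -((log_risk - log_cumsum) * events).sum() / events.sum().clamp(min=1.0)
```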


Figure 1: Model architecture.

Table 1: Model performance

Pre-requisites

  • Obtain the AREDS data (see Data)
  • Create the conda environment using conda env create -f requirements.yml (or use requirements_without_R.yml, which installs much faster but lacks AUPRC metric support) and activate it with conda activate amd-surv.
  • Log in to wandb with the command wandb login. Create an account if you don't have one.

Load and evaluate the pre-trained interpretable-by-design survival model on AREDS data

  • Retrieve the model weights 7ufjvvnz_best.pth from here and place it in the checkpoints directory.
  • Follow the instructions in evaluate_survival.ipynb to load the model and evaluate it (see the evaluation sketch after this list).
  • To evaluate the pre-trained baseline models, retrieve the other weights files and repeat the procedure.
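Once risk scores have been produced, ranking performance can be checked with scikit-survival, on which the repository already relies. The snippet below is an illustrative stand-alone example with made-up labels and scores, not code from evaluate_survival.ipynb.

```python
import numpy as np
from sksurv.metrics import concordance_index_censored

# Illustrative, made-up labels and model outputs:
events = np.array([True, False, True, True])    # True = eye converted to late AMD
durations = np.array([4, 10, 2, 6])             # time in half-year visits
risk_scores = np.array([0.8, -0.3, 1.5, 0.2])   # higher = earlier predicted conversion

# Concordance index: fraction of comparable pairs ranked correctly by risk.
cindex, *_ = concordance_index_censored(events, durations, risk_scores)
print(f"concordance index: {cindex:.3f}")
```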

Alternatively, train the model yourself

The model can be trained and evaluated using the following command:
python train_and_evaluate_survival.py --config configs/sparsebagnet_cox.yml

Similarly, the baseline models can be trained and evaluated using the config files in configs/babenko/ and configs/yan/. These configs train one model for each queried prediction time point (years 1 to 5).

Modify the configuration

The training is configured using a yaml file. The following parameters can be set:

```yaml
# Path to the metadata csv file relative to the project directory.
metadata_csv: data/metadata_surv.csv # Input file with image paths etc. Make sure to select one with stereo images if setting use_stereo_pairs: true below.

# Path to image files, relative to the data directory in dirs.yml.
image_dir: images-f2-1024px

# CNN training
cnn:
  project: wandb-project-name # The name of the wandb project. This code requires wandb to be installed and logged in.
  run_id: none # Needed to load a checkpoint, optional
  resume_training: false # Load model from last checkpoint, optional
  load_best_model: false # Load model from best checkpoint, optional
  test_run:
    enabled: false # If true, only use a small subset of the data
    size: 100 # Ignored if enabled=false
  train_set_fraction: 1.0 # For debugging use <1.0. Overwrites test_run.size
  val_set_fraction: 1.0 # For debugging use <1.0
  survival_times: [2, 4, 6, 8, 10] # Times to evaluate the model at, in half-years. You can pass e.g. "[2]" in combination with "loss: clf" to train a classifier model for time 2.
  gpu: 0 # Index of GPU to use
  seed: 123 # Random seed for reproducibility
  batch_size: 8 # Batch size
  num_epochs: 50 # Max. number of epochs
  stop_after_epochs: 10 # Stop training after this number of epochs without improvement
  num_workers: 8 # Number of workers for data loading
  img_size: 350 # Height and width of input images
  network: sparsebagnet-surv # Combination of {resnet, sparsebagnet (or, equally: bagnet), inceptionv3} and {surv}
  lr: 0.000016 # Learning rate
  sparsity_lambda: 0.000006 # Sparsity loss coefficient (only applied to BagNet models)
  loss: cox # cox or clf. Determines model type. clf is for classification models using ce-loss, trained for only one survival_time
  use_stereo_pairs: false # If true, use stereo image pairs and average their predictions after non-linear activation. Use a metadata_csv that includes images of both sides of an eye!
  model_selection_metric: ibs # ibs, aunbc, auc or bce_loss
  eval_sets: [val, test] # (List of) set(s) to evaluate the model on, e.g. "[val]"
  optimizer: adam # adam, adamw or sgd
  weight_decay: 0.0 # Set to >0.0 to use weight decay
  scheduler: none # onecyclelr, cosineannealing, cosineannealingrestarts or none
  scheduler_cosine_len: 0 # Length of a cosine period in epoch units. If not passed, uses epochs/2. Ignored if scheduler is not cosineannealing or cosineannealingrestarts.
  warmup_epochs: 0 # Number of epochs to warm up the learning rate; set to 0 for no warmup
  augmentation: true # If true, use data augmentation as in Huang et al. 2023
  balancing: false # Class balancing, currently not supported. Set to false
  num_classes: 12 # AMD severity scale length for dataloader, do not change
```
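A config like this can be read with plain PyYAML; the snippet below is a generic sketch, not the repository's own config loader.

```python
import yaml  # PyYAML

# Generic sketch: read the training config and inspect a few CNN settings.
with open("configs/sparsebagnet_cox.yml") as f:
    config = yaml.safe_load(f)

cnn = config["cnn"]
print(cnn["network"], cnn["lr"], cnn["survival_times"])
```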


The model takes under 12 hours to train and evaluate on a single NVIDIA GeForce RTX 2080 Ti GPU.

Data

Upon request, the AREDS data can be accessed from the dbGaP repository. The tabular data needs to be parsed into a single metadata file whose rows represent individual macula-centered (F2) images.

The AREDS data directory as defined in configs/dirs.yml should contain:
- the image directory set in the model config (image_dir). The images should be organized such that the image paths in the metadata CSV resolve to image files relative to image_dir.

The <project>/data directory should contain the following files:
- metadata_surv.csv: Survival metadata table that maps screening data (event, duration) to image records. An example entry is provided in data/metadata_surv_example.csv, with the important columns as follows:

Columns

  • patient_id, the eye identifier
  • visit_number, VISNO
  • image_eye, right or left eye
  • image_field, F2
  • image_side, RS or LS
  • image_file, file name
  • image_path, relative path incl. file name from data dir specified in dirs.yml
  • duration, relative time to the first visit where late AMD was diagnosed in units of visits (half years)
  • event, 1 if any record of this eye converted to late AMD, 0 else
  • diagnosis_amd_grade_12c, AMDSEVRE and AMDSEVLE: the AMD severity scale score, starting at 0

- metadata_surv_stereo.csv: Image metadata table with the same structure, but restricted to records where both views (LS and RS) exist (for the Babenko et al. baseline model).

Both datasets can be created from a parsed AREDS metadata file using save_survival_metadata.ipynb.
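As a quick sanity check of a parsed metadata file, the expected layout can be verified with pandas. This is an illustrative snippet based on the column list above, not part of the repository.

```python
import pandas as pd

# Check that the parsed survival metadata matches the layout described above.
df = pd.read_csv("data/metadata_surv.csv")

required = ["patient_id", "visit_number", "image_eye", "image_field",
            "image_side", "image_file", "image_path", "duration", "event"]
missing = [c for c in required if c not in df.columns]
assert not missing, f"missing columns: {missing}"

# 'event' must be binary; 'duration' is non-negative, in half-year visits.
assert df["event"].isin([0, 1]).all()
assert (df["duration"] >= 0).all()
```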

Illustrative figures

Figure 2: Example risk evidence maps (a) and the resulting survival curve predictions (b).

Figure 3: Examples of the most predictive image patches, which could be provided to clinicians (purple: higher risk, green: lower risk).

Credits

This work is mainly based on

Deep CoxPH modeling - Katzman, J. et al. "DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network." BMC Med Res Methodol 18, 24 (2018).

The Sparse BagNet for classification - Djoumessi, K. et al. "Sparse activations for interpretable disease grading." Medical Imaging with Deep Learning (2023).

Baselines: SOTA end-to-end AMD progression models

  • Babenko, B. et al. "Predicting progression of age-related macular degeneration from fundus images using deep learning." arXiv preprint (2019).
  • Yan, Q. et al. "Deep-learning-based prediction of late age-related macular degeneration progression." Nat Mach Intell 2, 141–150 (2020).

Note: We used the authors' model variant that does not rely on gene data and is trained end-to-end on fundus images.

This work includes code adaptations from

  • scikit-survival: We adapted the Breslow estimator to store the baseline hazard and survival functions and to subsequently initialize the estimator from the saved data (see the sketch after this list).
  • auton-survival: We adapted the CoxPH loss implementation from the authors' code.
  • BagNet (Brendel and Bethge, 2019) and Sparse BagNet (Djoumessi et al., 2023): We adapted the BagNet implementations to survival modelling.
  • Huang et al., 2023: We adapted the data augmentation from the authors' code.
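For context, this is roughly how the upstream Breslow estimator is used in scikit-survival (illustrative inputs; our adaptation additionally saves these estimates and restores the estimator from them):

```python
import numpy as np
from sksurv.linear_model.coxph import BreslowEstimator

# Illustrative inputs: Cox linear predictors plus (event, time) labels.
risk_scores = np.array([0.5, -0.2, 1.1, 0.0])
events = np.array([True, True, False, True])
times = np.array([2.0, 6.0, 8.0, 4.0])

breslow = BreslowEstimator().fit(risk_scores, events, times)
surv_fns = breslow.get_survival_function(risk_scores)  # one step function per sample
print(surv_fns[0].x, surv_fns[0].y)  # event times and survival probabilities
```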

Cite

```bibtex
@InProceedings{gervelmeyer2024interpretable,
  author    = {Gervelmeyer, Julius and Müller, Sarah and Djoumessi, Kerol and Merle, David and Clark, Simon J. and Koch, Lisa and Berens, Philipp},
  title     = {{Interpretable-by-design Deep Survival Analysis for Disease Progression Modeling}},
  booktitle = {Proceedings of Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024},
  year      = {2024},
  publisher = {Springer Nature Switzerland},
  volume    = {LNCS 15010},
  month     = {October},
  pages     = {502--512},
  url       = {https://papers.miccai.org/miccai-2024/421-Paper1325.html}
}
```




Owner

  • Name: Berens Lab @ University of Tübingen
  • Login: berenslab
  • Kind: organization
  • Email: philipp.berens@uni-tuebingen.de
  • Location: Tübingen, Germany

Department of Data Science at the Hertie Institute for AI in Brain Health, University of Tübingen

GitHub Events

Total
  • Watch event: 3
  • Push event: 2
Last Year
  • Watch event: 3
  • Push event: 2

Dependencies

environment.yml conda
  • mscorefonts 0.0.1.*
  • pandas 1.4.2.*
  • pip 22.1.2.*
  • python 3.9.5.*
  • r-base
  • r-survival
  • r-timeroc
  • rpy2