https://github.com/berenslab/dependence-measures-medical-imaging

This repository contains the code for the paper "Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging".

Keywords

disentangled-representations domain-shift shortcut-learning

Statistics
  • Stars: 1
  • Watchers: 2
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created almost 2 years ago · Last pushed 10 months ago

README.md

Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging

This repository contains the code to reproduce the results from the paper "Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging", which was accepted to the 15th International Workshop on Machine Learning in Medical Imaging (MLMI 2024).

We present a comprehensive performance comparison of dependence measures for preventing shortcut learning in medical imaging.

Installation

Set up a Python 3.10 environment. Then download the repository, activate the environment, and install the dependencies with

    cd dependence-measures-medical-imaging
    pip install --editable .

This installs the code in src as an editable package, along with all the dependencies listed in requirements.txt.

Organization of the repo

  • configs: Configuration files for all experiments.
  • scripts: Slurm scripts for model training and hyperparameter sweeps.
  • src: Main source code to run the experiments.
    • data: PyTorch datasets and scripts/info to download data.
    • models: PyTorch Lightning module for training models with different methods to prevent shortcut learning.
    • eval: Model evaluation with kNN classifiers and embedding plots.
  • train.py: Main training script for k-fold cross-validation (and optional hyperparameter sweeps).

Usage

Download public datasets

First, download the two datasets, Morpho-MNIST and CheXpert. For Morpho-MNIST we provide a download script:

    python src/data/download_data/load_morpho_mnist.py -d path-to-dataset-directory -v True

CheXpert requires registration, so we provide additional information on how to register and download the dataset in load_chexpert.txt.

Training

To run k-fold cross-validation for one method, pass a config file to the training script. For example, for MINE with the Morpho-MNIST dataset, the command-line call is

    python src/train.py -tc configs/morpho-mnist/mine.yaml

Note: The dataset_path needs to be adjusted in the config file.

To run the code on a Slurm cluster, we provide a bash script:

    sbatch scripts/train.sh configs/morpho-mnist/mine.yaml
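As an illustration of what k-fold cross-validation does with a dataset, the following minimal sketch splits sample indices into k folds, with each fold serving once as the validation set. It uses only the Python standard library. The function name k_fold_indices and its parameters are hypothetical and not taken from this repository, whose train.py may implement the split differently.

```python
import random


def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_idx, val_idx) index lists for k-fold cross-validation.

    Hypothetical helper for illustration; not part of the repository.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    # Shuffle once with a fixed seed so the folds are reproducible.
    random.Random(seed).shuffle(indices)
    # Spread the remainder so fold sizes differ by at most one sample.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


folds = list(k_fold_indices(n_samples=10, k=5))
for fold, (train_idx, val_idx) in enumerate(folds):
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} validation samples")
```

Every sample appears in exactly one validation fold, so averaging a metric over the k runs uses all of the data for evaluation without training and validating on the same sample within a run.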

Run hyperparameter sweeps (wandb)

Initialize the sweep with

    python src/utils/sweep_init.py -sc configs/example_sweep.yaml

This prints the sweep_id, which you pass to the sweep script to start multiple runs (10 in this case) on a Slurm cluster:

    sh scripts/sweep.sh 10 configs/morpho-mnist/mine.yaml sweep_id

Evaluation

To compute the confusion matrix of kNN classifier accuracies for one trained model, run

    python src/eval/knn_classifier.py -c model_config -ckpts model_checkpoints

To generate the embedding plots of the paper, run

    python src/eval/embeddings.py -cfgs list_of_model_configs -ckpts list_of_model_checkpoints
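To make the idea of kNN-based evaluation concrete, here is a minimal, standard-library sketch of a k-nearest-neighbour classifier applied to embedding vectors, together with an accuracy computation. The functions knn_predict and knn_accuracy and the toy 2D embeddings are hypothetical and chosen for illustration only. The repository's knn_classifier.py may work differently, for example by using scikit-learn, which is listed in requirements.txt.

```python
import math
from collections import Counter


def knn_predict(train_x, train_y, query, k=3):
    """Predict a label for `query` by majority vote among its k nearest training points."""
    # Sort training points by Euclidean distance to the query.
    neighbours = sorted(
        (math.dist(x, query), label) for x, label in zip(train_x, train_y)
    )
    votes = Counter(label for _, label in neighbours[:k])
    return votes.most_common(1)[0][0]


def knn_accuracy(train_x, train_y, test_x, test_y, k=3):
    """Fraction of test points whose kNN prediction matches the true label."""
    correct = sum(
        knn_predict(train_x, train_y, x, k) == y for x, y in zip(test_x, test_y)
    )
    return correct / len(test_y)


# Toy embeddings: two well-separated clusters standing in for model outputs.
train_x = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (1.0, 1.0), (0.9, 1.1), (1.1, 0.9)]
train_y = [0, 0, 0, 1, 1, 1]
test_x = [(0.05, 0.1), (1.0, 0.95)]
test_y = [0, 1]

print("kNN accuracy:", knn_accuracy(train_x, train_y, test_x, test_y, k=3))
```

If a label can be predicted well above chance from an embedding, the embedding still carries information about that label. This is why a classifier of this kind is a useful probe when checking whether shortcut-related information has been removed.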

Cite

If you find our code or paper useful, please consider citing this work:

    @InProceedings{mueller2025benchmarking,
      title = {Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging},
      author = {M\"uller, Sarah and Fay, Louisa and Koch, Lisa M. and Gatidis, Sergios and K\"ustner, Thomas and Berens, Philipp},
      booktitle = {Machine Learning in Medical Imaging},
      year = {2025},
      publisher = {Springer Nature Switzerland},
      pages = {53--62},
      isbn = {978-3-031-73290-4},
    }

Owner

  • Name: Berens Lab @ University of Tübingen
  • Login: berenslab
  • Kind: organization
  • Email: philipp.berens@uni-tuebingen.de
  • Location: Tübingen, Germany

Department of Data Science at the Hertie Institute for AI in Brain Health, University of Tübingen

Dependencies

requirements.txt pypi
  • lightning ==2.2.1
  • matplotlib ==3.9.0
  • numpy ==1.26.4
  • omegaconf ==2.3.0
  • pandas ==2.2.2
  • scikit-learn ==1.5.0
  • torchmetrics ==1.3.2
  • torchvision ==0.17.2
  • wandb ==0.17.0
setup.py pypi