https://github.com/chenliu-1996/imageflownet

[ICASSP 2025 Oral] ImageFlowNet: Forecasting Multiscale Image-Level Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images

Science Score: 23.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org, sciencedirect.com, nature.com, ieee.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.6%) to scientific vocabulary

Keywords

differential-equations disease-progression icassp icassp2025 image-forecasting image-prediction latent-space medical-image-analysis medical-imaging neural-ode pytorch spatial-temporal time-series-forecasting trajectory-prediction unet
Last synced: 5 months ago

Repository

[ICASSP 2025 Oral] ImageFlowNet: Forecasting Multiscale Image-Level Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images

Basic Info
Statistics
  • Stars: 13
  • Watchers: 1
  • Forks: 1
  • Open Issues: 0
  • Releases: 0
Topics
differential-equations disease-progression icassp icassp2025 image-forecasting image-prediction latent-space medical-image-analysis medical-imaging neural-ode pytorch spatial-temporal time-series-forecasting trajectory-prediction unet
Created almost 3 years ago · Last pushed 8 months ago
Metadata Files
Readme

README.md

[ICASSP 2025 Oral] ImageFlowNet

Forecasting Multiscale Image-Level Trajectories of Disease Progression
with Irregularly-Sampled Longitudinal Medical Images

[![ArXiv](https://img.shields.io/badge/ArXiv-ImageFlowNet-firebrick)](https://arxiv.org/abs/2406.14794) [![Slides](https://img.shields.io/badge/Slides-yellow)](https://chenliu-1996.github.io/slides/ImageFlowNet_slides.pdf) [![ICASSP](https://img.shields.io/badge/ICASSP-blue)](https://ieeexplore.ieee.org/abstract/document/10890535) [![Twitter](https://img.shields.io/twitter/follow/KrishnaswamyLab.svg?style=social)](https://twitter.com/KrishnaswamyLab) [![Twitter](https://img.shields.io/twitter/follow/ChenLiu-1996.svg?style=social)](https://twitter.com/ChenLiu_1996) [![LinkedIn](https://img.shields.io/badge/LinkedIn-ChenLiu-1996?color=blue)](https://www.linkedin.com/in/chenliu1996/) [![Github Stars](https://img.shields.io/github/stars/ChenLiu-1996/ImageFlowNet.svg?style=social&label=Stars)](https://github.com/ChenLiu-1996/ImageFlowNet/)

Krishnaswamy Lab, Yale University

This is the authors' implementation of ImageFlowNet, ICASSP 2025 (Oral).

The official codebase is maintained in the Lab GitHub repo.

A Glimpse into the Methods

Citation

@inproceedings{liu2025imageflownet,
  title={ImageFlowNet: Forecasting Multiscale Image-Level Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images},
  author={Liu, Chen and Xu, Ke and Shen, Liangbo L and Huguet, Guillaume and Wang, Zilong and Tong, Alexander and Bzdok, Danilo and Stewart, Jay and Wang, Jay C and Del Priore, Lucian V and Krishnaswamy, Smita},
  booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  year={2025},
  organization={IEEE}
}

Abstract

Advances in medical imaging technologies have enabled the collection of longitudinal images, which involve repeated scanning of the same patients over time, to monitor disease progression. However, predictive modeling of such data remains challenging due to high dimensionality, irregular sampling, and data sparsity. To address these issues, we propose ImageFlowNet, a novel model designed to forecast disease trajectories from initial images while preserving spatial details. ImageFlowNet first learns multiscale joint representation spaces across patients and time points, then optimizes deterministic or stochastic flow fields within these spaces using a position-parameterized neural ODE/SDE framework. The model leverages a UNet architecture to create robust multiscale representations and mitigates data scarcity by combining knowledge from all patients. We provide theoretical insights that support our formulation of ODEs, and motivate our regularizations involving high-level visual features, latent space organization, and trajectory smoothness. We validate ImageFlowNet on three longitudinal medical image datasets depicting progression in geographic atrophy, multiple sclerosis, and glioblastoma, demonstrating its ability to effectively forecast disease progression and outperform existing methods. Our contributions include the development of ImageFlowNet, its theoretical underpinnings, and empirical validation on real-world datasets.
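A minimal sketch of the core mechanism described above, integrating a learned flow field over a UNet latent feature map from an observed visit to an arbitrary (possibly irregularly spaced) future time, can be written with torchdiffeq, which is among the listed dependencies. This is illustrative only and not the authors' implementation; the LatentFlow module, channel count, and tensor shapes are hypothetical.

```
# Illustrative sketch (not the authors' code): a neural ODE over UNet latent features.
import torch
import torch.nn as nn
from torchdiffeq import odeint

class LatentFlow(nn.Module):
    """Learned vector field dz/dt = f(z) over a latent feature map (hypothetical module)."""
    def __init__(self, channels: int):
        super().__init__()
        self.field = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, t, z):
        # torchdiffeq calls the field as f(t, z); this toy field ignores t.
        return self.field(z)

# z0: latent features of the first visit, e.g., produced by a UNet encoder.
z0 = torch.randn(1, 64, 32, 32)
flow = LatentFlow(channels=64)
# Integrate from the observed time (0.0) to an irregular target time (e.g., 1.7),
# after which a UNet decoder would map z(t1) back to an image.
t = torch.tensor([0.0, 1.7])
z_t1 = odeint(flow, z0, t, method="dopri5")[-1]
```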

Repository Hierarchy

ImageFlowNet
├── comparison: some comparisons are in the `src` folder instead.
|   └── interpolation
├── checkpoints: only for segmentor model weights. Other model weights in `results`.
├── data: folders containing data files.
|   ├── brain_LUMIERE: Brain Glioblastoma
|   ├── brain_MS: Brain Multiple Sclerosis
|   └── retina_ucsf: Retinal Geographic Atrophy
├── external_src: other repositories or code.
├── results: generated results, including training log, model weights, and evaluation results.
└── src
    ├── data_utils
    ├── datasets
    ├── nn
    ├── preprocessing
    ├── utils
    └── *.py: some main scripts

Pre-trained weights

We have uploaded the weights for the retinal images.
  1. The weights for the segmentor can be found in checkpoints/segment_retinaUCSF_seed1.pty.
  2. The weights for the ImageFlowNetODE models can be found in Google Drive. You can put them under results/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1/run_1/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1_best_pred_psnr.pty and results/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1/run_1/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1_best_seg_dice.pty.
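The .pty checkpoints are ordinary PyTorch files, so a quick way to verify a download is to inspect them with torch.load. The snippet below is a hedged sketch; whether a given checkpoint stores a bare state_dict or a wrapped dictionary is an assumption you should verify against the training scripts.

```
# Hedged sketch: inspect a downloaded checkpoint before wiring it into the model.
import torch

ckpt_path = "checkpoints/segment_retinaUCSF_seed1.pty"
state = torch.load(ckpt_path, map_location="cpu")
# Print the first few keys if it is a dict (bare state_dict or wrapped checkpoint).
if isinstance(state, dict):
    print(list(state.keys())[:5])
```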

Reproduce the results

Image registration

cd src/preprocessing
python test_registration.py
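For intuition about what a registration step does, the sketch below aligns two longitudinal scans with ANTsPy (antspyx, listed in the Dependencies section for 3D registration). It is illustrative only and not the repository's test_registration.py; the file names are placeholders.

```
# Illustrative sketch (not test_registration.py): affine registration of two visits with ANTsPy.
import ants

fixed = ants.image_read("visit_00.nii.gz")    # placeholder path: reference visit
moving = ants.image_read("visit_01.nii.gz")   # placeholder path: later visit to align
reg = ants.registration(fixed=fixed, moving=moving, type_of_transform="Affine")
aligned = ants.apply_transforms(fixed=fixed, moving=moving,
                                transformlist=reg["fwdtransforms"])
ants.image_write(aligned, "visit_01_registered.nii.gz")
```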

Training a segmentation network (only for quantitative evaluation purposes)

cd src/
python train_segmentor.py

Training the main network.

```
cd src/

# ImageFlowNet_{ODE}
python train_2pt_all.py --model ImageFlowNetODE --random-seed 1
python train_2pt_all.py --model ImageFlowNetODE --random-seed 1 --mode test --run-count 1

# ImageFlowNet_{SDE}
python train_2pt_all.py --model ImageFlowNetSDE --random-seed 1
python train_2pt_all.py --model ImageFlowNetSDE --random-seed 1 --mode test --run-count 1
```

Some common arguments.

  • --dataset-name: name of the dataset (`retina_ucsf`, `brain_ms`, `brain_gbm`).
  • --segmentor-ckpt: the location of the segmentor model, used both for training and for applying the segmentor.

Ablations.

  1. Flow field formulation.
     python train_2pt_all.py --model ODEUNet
     python train_2pt_all.py --model ImageFlowNetODE

  2. Single-scale vs multiscale ODEs.
     python train_2pt_all.py --model ImageFlowNetODE --ode-location 'bottleneck'
     python train_2pt_all.py --model ImageFlowNetODE --ode-location 'all_resolutions'
     python train_2pt_all.py --model ImageFlowNetODE --ode-location 'all_connections'  # default

  3. Visual feature regularization.
     python train_2pt_all.py --model ImageFlowNetODE --coeff-latent 0.1

  4. Contrastive learning regularization.
     python train_2pt_all.py --model ImageFlowNetODE --coeff-contrastive 0.1

  5. Trajectory smoothness regularization.
     python train_2pt_all.py --model ImageFlowNetODE --coeff-smoothness 0.1
     (A sketch of how these regularization coefficients might enter the loss follows this list.)
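As referenced in item 5 above, the sketch below shows one plausible way the --coeff-latent, --coeff-contrastive, and --coeff-smoothness flags could weight their regularizers alongside the reconstruction objective. This is an assumption for illustration, not the exact loss in train_2pt_all.py.

```
# Assumed structure of the training objective (illustrative, not the authors' exact loss).
def total_loss(recon, latent_reg, contrastive_reg, smoothness_reg,
               coeff_latent=0.0, coeff_contrastive=0.0, coeff_smoothness=0.0):
    return (recon
            + coeff_latent * latent_reg            # visual feature regularization
            + coeff_contrastive * contrastive_reg  # latent space organization
            + coeff_smoothness * smoothness_reg)   # trajectory smoothness
```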

Comparisons

Image interpolation/extrapolation methods.
cd comparison/interpolation
python run_baseline_interp.py --method linear
python run_baseline_interp.py --method cubic_spline
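For reference, a per-pixel linear extrapolation baseline of the kind evaluated by run_baseline_interp.py can be written in a few lines of NumPy. This sketch is illustrative, not the repository's implementation; cubic-spline extrapolation works analogously but requires more than two observed visits per patient.

```
# Illustrative baseline (not run_baseline_interp.py): per-pixel linear extrapolation.
import numpy as np

def linear_extrapolate(img_t0: np.ndarray, img_t1: np.ndarray,
                       t0: float, t1: float, t_target: float) -> np.ndarray:
    # Fit intensity(t) = a*t + b independently at every pixel, then evaluate at t_target.
    slope = (img_t1 - img_t0) / (t1 - t0)
    return img_t0 + slope * (t_target - t0)
```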

Time-conditional UNet.
cd src
python train_2pt_all.py --model T_UNet --random-seed 1 --mode train
python train_2pt_all.py --model T_UNet --random-seed 1 --mode test --run-count 1
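One simple way a UNet can be made time-conditional is to broadcast the scan interval as an extra input channel, as sketched below. This is an assumption for illustration; the repository's T_UNet may condition on time differently (e.g., with a learned time embedding).

```
# Hypothetical sketch: conditioning a conv block on the time gap via an extra channel.
import torch
import torch.nn as nn

class TimeConditionedBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch + 1, out_ch, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, delta_t: torch.Tensor) -> torch.Tensor:
        # delta_t: shape (B,), time gap between the observed and target visits.
        t_map = delta_t.view(-1, 1, 1, 1).expand(-1, 1, x.shape[2], x.shape[3])
        return self.conv(torch.cat([x, t_map], dim=1))
```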

Time-aware diffusion model (Image-to-Image Schrodinger Bridge).
cd src
python train_2pt_all.py --model I2SBUNet --random-seed 1
python train_2pt_all.py --model I2SBUNet --random-seed 1 --mode test --run-count 1

Style-based Manifold Extrapolation (Nat. Mach. Intell. 2022).
```
conda deactivate
conda activate stylegan

cd src/preprocessing
python 04_unpack_retina_UCSF.py

cd ../../comparison/style_manifold_extrapolation/stylegan2-ada-pytorch
python train.py --outdir=../training-runs --data='../../../data/retina_ucsf/UCSF_images_final_unpacked_256x256/' --gpus=1
```

Datasets

  1. Retinal Geographic Atrophy dataset from METforMIN study (UCSF).
    • Paper: https://www.sciencedirect.com/science/article/pii/S2666914523001720.
    • You may contact the authors. Data may or may not be available.
  2. Brain Multiple Sclerosis dataset.
    • Paper: https://www.sciencedirect.com/science/article/pii/S1053811916307819?via%3Dihub
    • Data can be requested here: https://iacl.ece.jhu.edu/index.php?title=MSChallenge/data
    • Or more specifically, here: https://smart-stats-tools.org/lesion-challenge
  3. Brain Glioblastoma dataset.
    • Paper: https://www.nature.com/articles/s41597-022-01881-7
    • Data can be downloaded here: https://springernature.figshare.com/collections/The_LUMIERE_Dataset_Longitudinal_Glioblastoma_MRI_with_Expert_RANO_Evaluation/5904905/1

Data preparation and preprocessing

  1. Retinal Geographic Atrophy dataset.
     Put data under: data/retina_ucsf/Images/
     cd src/preprocessing
     python 01_preprocess_retina_UCSF.py
     python 02_register_retina_UCSF.py
     python 03_crop_retina_UCSF.py

  2. Brain Multiple Sclerosis dataset.
     Put data under: data/brain_MS/brain_MS_images/trainX/ after unzipping.
     cd src/preprocessing
     python 01_preprocess_brain_MS.py

  3. Brain Glioblastoma dataset.
     Put data under: data/brain_LUMIERE/ after unzipping.
     cd src/preprocessing
     python 01_preprocess_brain_GBM.py

Segment Anything Model (SAM)

This is only used for test_registration.py to facilitate visualization. Not used anywhere else.
cd external_src/
mkdir SAM && cd SAM
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
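Once the checkpoint is downloaded, SAM can generate masks as sketched below using the segment-anything package installed in the Dependencies section. The image path is a placeholder; in this repository SAM is only called from test_registration.py for visualization.

```
# Hedged sketch: automatic mask generation with the downloaded SAM checkpoint.
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="external_src/SAM/sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread("example_image.png"), cv2.COLOR_BGR2RGB)  # placeholder path
masks = mask_generator.generate(image)  # list of dicts with 'segmentation', 'area', etc.
```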

Dependencies

We developed the codebase in a miniconda environment. How we created the conda environment:
```
# Optional: Update to libmamba solver.
conda update -n base conda
conda install -n base conda-libmamba-solver
conda config --set solver libmamba

conda create --name imageflownet pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -c nvidia -c anaconda -c conda-forge
conda activate imageflownet
conda install scikit-learn scikit-image pillow matplotlib seaborn tqdm -c pytorch -c anaconda -c conda-forge
conda install read-roi -c conda-forge
python -m pip install -U albumentations
python -m pip install timm
python -m pip install opencv-python
python -m pip install git+https://github.com/facebookresearch/segment-anything.git
python -m pip install monai
python -m pip install torchdiffeq
python -m pip install torch-ema
python -m pip install torchcde
python -m pip install torchsde
python -m pip install phate
python -m pip install psutil
python -m pip install ninja

# For 3D registration
python -m pip install antspyx
```

Acknowledgements

We adapted some of the code from:
  1. I^2SB: Image-to-Image Schrodinger Bridge

Owner

  • Name: Chen Liu
  • Login: ChenLiu-1996
  • Kind: user
  • Location: New Haven
  • Company: Yale University

CS PhD student at @KrishnaswamyLab, @YaleUniversity. Reviewing Committee member at NeurIPS, ICLR, ICML.

GitHub Events

Total
  • Watch event: 7
  • Push event: 16
Last Year
  • Watch event: 7
  • Push event: 16

Dependencies

external_src/SuperRetina/requirements.txt pypi
  • Pillow ==9.2.0
  • PyYAML ==6.0
  • imgaug ==0.4.0
  • matplotlib ==3.5.1
  • numpy ==1.22.3
  • opencv_python ==4.6.0.66
  • scikit_learn ==1.1.1
  • scipy ==1.8.0
  • torch ==1.8.1
  • torchvision ==0.9.1
  • tqdm ==4.64.0