iterative-smoothing-bridge

Code for 'Transport with Support: Data-Conditional Diffusion Bridges'

https://github.com/aaltoml/iterative-smoothing-bridge

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.9%) to scientific vocabulary

Keywords

diffusion-models optimal-transport tmlr
Last synced: 4 months ago · JSON representation ·

Repository

Code for 'Transport with Support: Data-Conditional Diffusion Bridges'

Basic Info
  • Host: GitHub
  • Owner: AaltoML
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 1.17 MB
Statistics
  • Stars: 4
  • Watchers: 1
  • Forks: 0
  • Open Issues: 1
  • Releases: 0
Topics
diffusion-models optimal-transport tmlr
Created about 2 years ago · Last pushed about 2 years ago
Metadata Files
Readme License Citation

README.md

Iterative Smoothing Bridge

This repository is the official implementation of the methods in the publication:

  • Ella Tamir, Martin Trapp, and Arno Solin (2023). Transport with support: Data-conditional diffusion bridges. In Transactions on Machine Learning Research (TMLR). [arXiv preprint]

The Iterative Smoothing Bridge

The dynamic Schrödinger bridge problem provides an appealing setting for solving constrained time-series data generation tasks posed as optimal transport problems. It consists of learning non-linear diffusion processes using efficient iterative solvers. Recent works have demonstrated state-of-the-art results (e.g., in modelling single-cell embryo RNA sequences or sampling from complex posteriors) but are limited to learning bridges with only initial and terminal constraints. Our work extends this paradigm by proposing the Iterative Smoothing Bridge (ISB). We integrate Bayesian filtering and optimal control into learning the diffusion process, enabling the generation of constrained stochastic processes governed by sparse observations at intermediate stages and terminal constraints. We assess the effectiveness of our method on synthetic and real-world data generation tasks and show that the ISB generalises well to high-dimensional data, is computationally efficient, and provides accurate estimates of the marginals at intermediate and terminal times.

Create environment

In the root of this repo, run conda env create --file=isb-env.yaml conda activate isb-env pip install -e ./src

Running experiments

The experiments scripts are in the folder tasks. To run (non-image) experiments, run the Python file tasks/isb_tasks/run_iterative_smoother.py. To successfully run an experiment, follow these steps:

Running ISB on flat data

  • If you are using a zero drift initialization, no IPFP run is necessary. If you need an unconstrained Schrödinger bridge model as a reference, run IPFP on flat data first
  • Run conda activate isb-env python3 tasks/isb_tasks/run_iterative_smoother.py --config-name circle_isb where the config name should match the desired config file.

Running ISB on MNIST

  • If you are using a zero drift initialization, no IPFP run is necessary. If you choose to use a NN drift initialization, set the model name of a trained model here after completing the steps in "Running IPFP on MNIST"
  • Run conda activate isb-env python3 tasks/isb_tasks/run_iterative_smoother_img.py --config-name mnist_isb We recommend using a GPU for this experiment.

Running IPFP on flat data

  • Run conda activate isb-env python3 tasks/bridge_tasks/run_iftp_flat.py --config-name circle2d where the config name should match the desired config file.

Running IPFP on image data

  • Run conda activate isb-env python3 tasks/bridge_tasks/run_iftp_image.py --config-name mnist We recommend using a GPU for this experiment.

List of data sets

To train an ISB model for a data set, first check the list below for the config files and modify them as you like

Config files

  • 2D sklearn circle: Bridge config in configs/bridge/circles2d.yaml, ISB config in configs/isb/circle_isb.yaml
  • Single-cell: No bridge config, since pre-training an OT model is not necessary here, ISB config in configs/isb/rna_isb.yaml
  • MNIST: Bridge config in configs/bridge/mnist.yaml, ISB config in configs/isb/mnist_isb.yaml

Output

Running the scripts creates model files and plots, below is a description of the folder contents: - bridge_models: Unconstrained model files, to be loaded in ISB training if necessary - isb_models: Trained ISB models - plots: ordered by dataset name and date, includes videos of the learned dynamics and pickled objects from the trained model. The video starting with "finaltrajectoryfinal_particles" shows the result after learning. - outputs: hydra outputs, saves the config file used for each run

Acknowledgements

We wish to thank Adrien Corenflos for sharing an implementation of differentiable resampling in PyTorch (see transport_map.py and sinkhorn.py).

License

This software is provided under the MIT License.

Owner

  • Name: AaltoML
  • Login: AaltoML
  • Kind: organization
  • Location: Finland

Machine learning group at Aalto University lead by Prof. Solin

Citation (CITATION.cff)

cff-version: 1.2.0
message: "This code implements the methods in this original publication."
authors:
  - family-names: "Tamir"
    given-names: "Ella"
  - family-names: "Trapp"
    given-names: "Martin"
    orcid: "https://orcid.org/0000-0003-1725-3381"
  - family-names: "Solin"
    given-names: "Arno"
    orcid: "https://orcid.org/0000-0002-0958-7886"
title: "Iterative Smoothing Bridge"
date-released: 2023-11-21
url: "https://github.com/AaltoML/iterative-smoothing-bridge"
preferred-citation:
  type: article
  authors:
  - family-names: "Tamir"
    given-names: "Ella"
  - family-names: "Trapp"
    given-names: "Martin"
    orcid: "https://orcid.org/0000-0003-1725-3381"
  - family-names: "Solin"
    given-names: "Arno"
    orcid: "https://orcid.org/0000-0002-0958-7886"
  journal: "Transactions on Machine Learning Research (TMLR)"
  title: "Transport with Support: Data-Conditional Diffusion Bridges"
  month: 11
  year: 2023

GitHub Events

Total
Last Year

Dependencies

src/setup.py pypi