https://github.com/aaltoml/scalable-inference-in-sdes

Methods and experiments for assumed density SDE approximations

Science Score: 10.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.8%) to scientific vocabulary
Last synced: 10 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: AaltoML
  • License: mit
  • Language: Jupyter Notebook
  • Default Branch: master
  • Homepage:
  • Size: 398 KB
Statistics
  • Stars: 10
  • Watchers: 2
  • Forks: 2
  • Open Issues: 1
  • Releases: 0
Created over 4 years ago · Last pushed over 4 years ago

https://github.com/AaltoML/scalable-inference-in-sdes/blob/master/

# Scalable Inference in SDEs by Direct Matching of the Fokker–Planck–Kolmogorov Equation

This repository is the official implementation of the methods in the publication
* Solin, A., Tamir, E., and Verma, P. (2021) **Scalable Inference in SDEs by Direct Matching of the Fokker–Planck–Kolmogorov Equation**. In *Advances in Neural Information Processing Systems 34 (NeurIPS)*. [[arXiv]](https://arxiv.org/abs/2110.15739)

In the paper, we advocate alternative solution concepts for stochastic differential equation (SDE) models in machine learning, where simulation-based techniques such as variants of stochastic Runge–Kutta are currently the *de facto* approach. These methods are convenient and general-purpose, and they can be used with parametric and non-parametric models as well as neural SDEs. Yet, stochastic Runge–Kutta relies on sampling schemes that can be inefficient in high dimensions. We address this issue by revisiting the classical SDE literature and deriving direct approximations to the (typically intractable) Fokker–Planck–Kolmogorov equation by matching moments. The codebase in this repository includes the building blocks for the figures and the code for the experiments in the paper.
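As a rough illustration of what matching moments can look like, the sketch below integrates the standard Gaussian assumed-density moment equations, `dm/dt = E[f(x)]` and `dP/dt = 2 E[(x - m) f(x)] + q`, for a scalar SDE `dx = f(x) dt + dβ`. It is not taken from this repository: the function name, its arguments, the scalar setting, the Gauss–Hermite quadrature, and the plain Euler stepping are all assumptions made for the example.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gaussian_moment_odes(f, m0, P0, q, dt, n_steps, order=10):
    """Euler-integrate Gaussian assumed-density moment equations (scalar case).

    Model (assumed for this sketch): dx = f(x) dt + dbeta, where the Brownian
    motion beta has spectral density q. The state density is approximated by
    N(m, P) at every time step.
    """
    # Gauss-Hermite nodes and weights for the weight function exp(-xi^2 / 2).
    # Dividing by sqrt(2 pi) turns them into expectation weights under N(0, 1).
    xi, w = hermegauss(order)
    w = w / np.sqrt(2.0 * np.pi)

    m, P = float(m0), float(P0)
    for _ in range(n_steps):
        x = m + np.sqrt(P) * xi                      # quadrature points for N(m, P)
        fx = f(x)
        dm = np.sum(w * fx)                          # E[f(x)]
        dP = 2.0 * np.sum(w * (x - m) * fx) + q      # 2 E[(x - m) f(x)] + q
        m, P = m + dt * dm, P + dt * dP              # explicit Euler step
    return m, P


# Usage: an Ornstein-Uhlenbeck process, for which the moment equations are exact.
m_T, P_T = gaussian_moment_odes(lambda x: -0.5 * x, m0=1.0, P0=0.1, q=0.2,
                                dt=1e-3, n_steps=2000)
```

For the linear drift used here, the result can be compared against the closed-form mean and variance of the Ornstein–Uhlenbeck process; for nonlinear drifts the Gaussian assumption is only an approximation.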

## Python environment

The code should be run using Python 3.6. If you are already using Python 3.6, the dependencies can be installed using the requirements file:
```bash
pip install -r requirements.txt
```
Alternatively, a conda virtual environment can be created using the `environment.yml` file:

```bash
conda env create -f environment.yml
conda activate scalable-sde
```

## MOCAP experiment

The code specific to the MOCAP experiment is in `experiments/mocap`. To prepare the data for the experiment, place
the MOCAP MATLAB data file `mocap35.mat` in the folder `[base_folder]/data/mocap_data`, where `base_folder` is the folder passed as input to the
training and test scripts. We use the preprocessed MOCAP data from https://github.com/cagatayyildiz/ODE2VAE.
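Before launching a run, it can help to confirm that the data file sits where the paragraph above describes. The following is a minimal sketch under that assumption; the helper names `mocap_data_path` and `check_mocap_data` are hypothetical and are not part of this repository.

```python
from pathlib import Path


def mocap_data_path(base_folder):
    """Build the expected location of mocap35.mat under the given base folder."""
    return Path(base_folder) / "data" / "mocap_data" / "mocap35.mat"


def check_mocap_data(base_folder):
    """Return the data path, or raise a descriptive error if the file is missing."""
    path = mocap_data_path(base_folder)
    if not path.is_file():
        raise FileNotFoundError(f"Expected MOCAP data at {path}")
    return path
```

Calling `check_mocap_data(BASE_FOLDER)` with the same folder that is later passed to the scripts fails early with a readable message if the file has not been placed yet.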

In order to run the MOCAP training, run
```bash
python experiments/mocap/walking_tf_train.py [-base_folder BASE_FOLDER] [-task TASK] [-decoder_dist DECODER_DIST]
                  [-model_name MODEL_NAME] [-prior_model_name PRIOR_MODEL_NAME] [-vae_name VAE_NAME]
                  [--dt DT] [--latent_dim LATENT_DIM] [--context_dim CONTEXT_DIM]
                  [--epochs EPOCHS] [--start_len START_LEN]
```


For testing a trained MOCAP model, run
```bash
python experiments/mocap/walking_tf_test.py [-base_folder BASE_FOLDER] [-task TASK] [-decoder_dist DECODER_DIST]
                  [-model_name MODEL_NAME] [-vae_name VAE_NAME] [--dt DT]
                  [--latent_dim LATENT_DIM] [--context_dim CONTEXT_DIM] [--start_len START_LEN]
```

See the training and test scripts for further documentation of their input arguments.
To adapt the codebase to another flat dataset (the VAE implementation does not support image data),
modify the utility function `get_data` in `experiments/mocap/walking_tf_functions.py` so that it returns a different dataset class.



### Alternative SDE Approximations
You can run the MOCAP experiment with any new SDE approximator, as long as it inherits from the class
`SDEApprox` in `src/sde_tf/sde_approx/sde_approx.py`.

## Rotating MNIST

The code for the rotating MNIST experiment is available in `experiments/mnist`. In order to run the experiment:

```bash
cd experiments/mnist/
python main.py
```

All experiment-related parameters are defined in `config.py`, where they can be modified. By default, the trained models and inference plots are saved in the output folder `experiments/mnist/output`.

## Notebooks

The code used to generate Figure 3 and Figure 4 of the paper is available in the Jupyter notebooks in `/notebooks/`.

## Citation
If you use the code in this repository for your research, please cite the paper as follows:
```bibtex
@inproceedings{solin2021,
  title={Scalable Inference in SDEs by Direct Matching of the {F}okker--{P}lanck--{K}olmogorov Equation},
  author={Solin, Arno and Tamir, Ella and Verma, Prakhar},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}
```

## Contributing

For all correspondence, please contact arno.solin@aalto.fi, ella.tamir@aalto.fi, or prakhar.verma@aalto.fi.

## License

This software is provided under the [MIT license](LICENSE).

Owner

  • Name: AaltoML
  • Login: AaltoML
  • Kind: organization
  • Location: Finland

Machine learning group at Aalto University led by Prof. Solin

GitHub Events

Total
  • Watch event: 1
  • Issue comment event: 1
  • Fork event: 1
Last Year
  • Watch event: 1
  • Issue comment event: 1
  • Fork event: 1

Dependencies

environment.yml pypi
  • gpflow ==2.3.0
  • scikit-image ==0.15.0
  • scikit-learn ==0.21.3
  • scipy ==1.5.4
  • tensorflow ==2.4
  • tensorflow-estimator ==2.4.0
  • tensorflow-probability ==0.11.0
  • torch ==1.9.1
  • torchsummary ==1.5.1
  • torchvision ==0.10.1
requirements.txt pypi
  • gpflow ==2.3.0
  • jupyter ==1.0.0
  • matplotlib ==3.1.1
  • numpy ==1.19.5
  • scikit-image ==0.15.0
  • scikit-learn ==0.21.3
  • scipy ==1.5.4
  • tensorflow ==2.4
  • tensorflow-estimator ==2.4.0
  • tensorflow-probability ==0.11.0
  • torch ==1.9.1
  • torchsummary ==1.5.1
  • torchvision ==0.10.1