https://github.com/anindex/ssax
Implementation of Sinkhorn Step in JAX, NeurIPS 2023.
Science Score: 13.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (found)
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity: 14.1%)
Repository
Basic Info
- Host: GitHub
- Owner: anindex
- License: mit
- Language: Python
- Default Branch: main
- Homepage: https://sites.google.com/view/sinkhorn-step/
- Size: 19 MB
Statistics
- Stars: 42
- Watchers: 2
- Forks: 3
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Sinkhorn Step in JAX (ssax)
This ssax repository is a proof of concept for the Sinkhorn Step, a batched, gradient-free optimizer for highly non-convex objectives, implemented in JAX. ssax is heavily inspired by the code structure of OTT-JAX and reuses most of its linear solvers, so users can easily switch between different solver flavors.
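For intuition, here is a minimal, self-contained sketch of the idea in plain NumPy (not the ssax API): sample probe points around the current batch, run Sinkhorn iterations on the resulting cost matrix, and move each point toward a transport-weighted average of the probe directions. All function names and hyperparameter defaults below are illustrative assumptions, not this repository's interface.

```python
import numpy as np

def sinkhorn(C, epsilon=0.1, n_iters=200):
    """Entropic-regularized OT plan between two uniform histograms,
    computed with standard Sinkhorn iterations on the Gibbs kernel."""
    n, m = C.shape
    a, b = np.ones(n) / n, np.ones(m) / m   # uniform marginals
    K = np.exp(-C / epsilon)                # Gibbs kernel
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]      # transport plan P

def sinkhorn_step(x, f, step_radius=0.1, probe_radius=0.15,
                  n_probes=16, rng=None):
    """One conceptual 'Sinkhorn Step': a batched, gradient-free update that
    moves each point along a transport-weighted mix of probe directions."""
    rng = np.random.default_rng(rng)
    batch, dim = x.shape
    # Random unit probe directions, shared across the batch.
    D = rng.normal(size=(n_probes, dim))
    D /= np.linalg.norm(D, axis=1, keepdims=True)
    # Cost matrix: objective value at each probe point around each batch point.
    C = np.stack([f(x + probe_radius * D[j]) for j in range(n_probes)], axis=1)
    P = sinkhorn(C)                          # plan between points and directions
    # Barycentric projection: each row of (batch * P) is a convex combination
    # of directions, biased toward low-cost (descending) probes.
    return x + step_radius * batch * (P @ D)
```

On a simple quadratic objective, repeatedly applying `sinkhorn_step` drives the batch toward the minimum without evaluating a gradient. The actual implementation differs (e.g., it uses structured polytope vertices as probe directions and an annealed epsilon); this sketch uses random directions and a fixed epsilon for brevity.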
NOTE: This repository implements the Sinkhorn Step optimizer (in JAX) as a general-purpose, standalone solver for non-convex optimization problems. The MPOT trajectory optimizer using the Sinkhorn Step (in PyTorch) is released separately in the mpot repository.
Paper Preprint
This work has been accepted to NeurIPS 2023. Please find the preprint here:
Installation
Simply activate your conda/Python environment and run
```shell
pip install -e .
```
Please install JAX with CUDA support if you want to run the code on a GPU for better performance.
Run some demos
An example script is provided in scripts/example.py.
For testing Sinkhorn Step with various synthetic functions, run the following script with hydra settings:
```shell
python scripts/run.py experiment=ss-al
```
and find the result animations in the logs/ folder. You can replace the tag experiment=<exp-filename> with any filename found in the configs/experiment folder. The currently available optimization experiments are:
- ss-al: Ackley function in 2D
- ss-al-10d: Ackley function in 10D
- ss-bk: Bukin function in 2D
- ss-dw: DropWave function in 2D
- ss-eh: EggHolder function in 2D
- ss-ht: Hoelder Table function in 2D
- ss-lv: Levy function in 2D
- ss-rb: Rosenbrock function in 2D
- ss-rg: Rastrigin function in 2D
- ss-st: Styblinski-Tang function in 2D
- ss-st-10d: Styblinski-Tang function in 10D
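For reference, ss-al targets the Ackley function, a classic highly multimodal benchmark with a global minimum of 0 at the origin. A standard NumPy definition (written here from the textbook formula, not taken from this repository's code) is:

```python
import numpy as np

def ackley(x, a=20.0, b=0.2, c=2 * np.pi):
    """Ackley test function: f(0) = 0 is the global minimum, surrounded
    by many shallow local minima that trap gradient-based optimizers."""
    x = np.asarray(x, dtype=float)
    d = x.shape[-1]
    s1 = np.sqrt(np.sum(x**2, axis=-1) / d)       # radial term
    s2 = np.sum(np.cos(c * x), axis=-1) / d       # oscillatory term
    return -a * np.exp(-b * s1) - np.exp(s2) + a + np.e
```

The last axis is treated as the dimension, so the same function covers the 2D and 10D experiments.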
Note: When tuning new settings, the most sensitive hyperparameters are step_radius, probe_radius, the entropic regularization scalar ent_epsilon, and the step-annealing scheme (epsilon_scheduler) parameters. Experimenting with these, together with the other hyperparameters, on the synthetic functions gives a good feeling for how they affect the optimization. For the 10D experiments, the plots are projected onto the first two dimensions for visualization.
We also include gradient approximation benchmarks that measure the cosine similarity between the Sinkhorn Step and the true gradient, both over outer iterations and over the entropic regularization of the Sinkhorn distance. Step annealing is turned off for benchmarking purposes. The currently available gradient approximation experiments are:
- ss-al-cosin-sim: Ackley function in 10D
- ss-st-cosin-sim: Styblinski-Tang function in 10D
To run them:
```shell
python scripts/benchmark_cosin_similarity_single.py experiment=ss-st-cosin-sim num_seeds=20
```
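The reported metric is the standard cosine similarity between the step direction and the true gradient: values near 1 mean the gradient-free step is well aligned with the true descent direction, values near 0 mean it is uninformative. A minimal version (the function name and signature are ours, not the benchmark script's):

```python
import numpy as np

def cosine_similarity(u, v, eps=1e-12):
    """Cosine of the angle between two vectors:
    1 = same direction, 0 = orthogonal, -1 = opposite."""
    u, v = np.ravel(u), np.ravel(v)
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + eps))
```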
Citation
If you find this work useful, please consider citing:
@inproceedings{le2023accelerating,
title={Accelerating Motion Planning via Optimal Transport},
author={Le, An T. and Chalvatzaki, Georgia and Biess, Armin and Peters, Jan},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2023}
}
See also
- The OTT-JAX documentation for more details on the linear solvers.
- The wonderful virtual library of synthetic test functions by Sonja Surjanovic and Derek Bingham (SFU).
Owner
- Name: An Thai Le
- Login: anindex
- Kind: user
- Location: Germany
- Company: Technische Universität Darmstadt
- Website: https://www.ias.informatik.tu-darmstadt.de/Team/AnThaiLe
- Twitter: an_thai_le
- Repositories: 3
- Profile: https://github.com/anindex
GitHub Events
Total
- Watch event: 2
- Push event: 1
Last Year
- Watch event: 2
- Push event: 1
Dependencies
- flax *
- hydra-core *
- jax *
- matplotlib *
- numpy *
- omegaconf *
- ott-jax *
- setuptools *