https://github.com/anindex/ssax

Implementation of Sinkhorn Step in JAX, NeurIPS 2023.


Science Score: 13.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.1%) to scientific vocabulary

Keywords

batch optimal-transport optimization
Last synced: 5 months ago

Repository

Implementation of Sinkhorn Step in JAX, NeurIPS 2023.

Basic Info
Statistics
  • Stars: 42
  • Watchers: 2
  • Forks: 3
  • Open Issues: 0
  • Releases: 0
Topics
batch optimal-transport optimization
Created over 2 years ago · Last pushed 11 months ago
Metadata Files
Readme License

README.md

Sinkhorn Step in JAX (ssax)

This ssax repository demonstrates a proof of concept for the Sinkhorn Step, a batch gradient-free optimizer for highly non-convex objectives, implemented in JAX. ssax is heavily inspired by the code structure of OTT-JAX and reuses most of its linear solvers, enabling users to switch easily between different solver flavors.
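To give a rough intuition for the method, here is a heavily simplified, hypothetical sketch of a "Sinkhorn Step"-style update in JAX: each point in a batch probes the objective in a set of directions, an entropic OT problem is solved between points and directions, and each point moves along its plan-weighted mixture of directions. All names and details below (random probe directions instead of polytope vertices, the plain Sinkhorn loop, the cost normalization) are illustrative assumptions, not ssax's actual API or algorithmic details — ssax delegates the OT solve to OTT-JAX.

```python
import jax
import jax.numpy as jnp

def sinkhorn(cost, epsilon=0.05, n_iters=100):
    """Plain Sinkhorn iterations with uniform marginals (illustrative only;
    ssax delegates this step to OTT-JAX's linear solvers)."""
    n, m = cost.shape
    K = jnp.exp(-cost / epsilon)               # Gibbs kernel
    u = jnp.full(n, 1.0 / n)
    for _ in range(n_iters):
        v = (1.0 / m) / (K.T @ u)              # match the uniform column marginal
        u = (1.0 / n) / (K @ v)                # match the uniform row marginal
    return u[:, None] * K * v[None, :]         # transport plan

def sinkhorn_step(x, objective, key, step_radius=0.1, probe_radius=0.15,
                  n_probe=8, ent_epsilon=0.05):
    """One hypothetical Sinkhorn-Step-style update: probe the objective around
    each point, solve entropic OT between points and probe directions, and move
    each point along its plan-weighted mixture of directions."""
    n, d = x.shape
    # Random unit directions stand in for the polytope vertices used in the paper.
    dirs = jax.random.normal(key, (n_probe, d))
    dirs = dirs / jnp.linalg.norm(dirs, axis=1, keepdims=True)
    probes = x[:, None, :] + probe_radius * dirs[None, :, :]       # (n, n_probe, d)
    cost = jax.vmap(jax.vmap(objective))(probes)                   # objective at probes
    cost = (cost - cost.min()) / (cost.max() - cost.min() + 1e-9)  # scale-normalize
    plan = sinkhorn(cost, epsilon=ent_epsilon)
    weights = plan / plan.sum(axis=1, keepdims=True)               # rows sum to 1
    return x + step_radius * weights @ dirs

# Toy usage: a batch of 32 points minimizing a sphere function in 2D.
sphere = lambda p: jnp.sum(p ** 2)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 2)) * 3.0
start = float(jax.vmap(sphere)(x).mean())
for t in range(100):
    x = sinkhorn_step(x, sphere, jax.random.PRNGKey(t))
end = float(jax.vmap(sphere)(x).mean())
print(start, end)  # the mean objective should decrease
```

Note that no gradients of the objective are ever taken — only evaluations at the probe points — which is what makes the method gradient-free and batchable.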

NOTE: This repository implements the Sinkhorn Step optimizer (in JAX) as a general-purpose standalone solver for non-convex optimization problems. The MPOT trajectory optimizer, which applies the Sinkhorn Step (in PyTorch), is released separately in the mpot repository.

Paper Preprint

This work has been accepted to NeurIPS 2023. Please find the pre-print here:

Installation

Simply activate your conda/Python environment and run

pip install -e .

If you want to run the code on GPU for better performance, please install JAX with CUDA support.

Run some demos

An example script is provided in scripts/example.py.

For testing Sinkhorn Step with various synthetic functions, run the following script with hydra settings:

python scripts/run.py experiment=ss-al

and find the result animations in the logs/ folder. You can replace the tag experiment=<exp-filename> with any filename found in the configs/experiment folder. The currently available optimization experiments are:

  • ss-al: Ackley function in 2D
  • ss-al-10d: Ackley function in 10D
  • ss-bk: Bukin function in 2D
  • ss-dw: DropWave function in 2D
  • ss-eh: EggHolder function in 2D
  • ss-ht: Hoelder Table function in 2D
  • ss-lv: Levy function in 2D
  • ss-rb: Rosenbrock function in 2D
  • ss-rg: Rastrigin function in 2D
  • ss-st: Styblinski-Tang function in 2D
  • ss-st-10d: Styblinski-Tang function in 10D
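These are standard synthetic benchmark functions for non-convex optimization. As a concrete example, here is the 2D/ND Ackley function in its standard form (a, b, c are the usual defaults; the exact parametrization used in the repo's configs may differ):

```python
import jax.numpy as jnp

def ackley(x, a=20.0, b=0.2, c=2.0 * jnp.pi):
    """Standard Ackley function; highly multimodal, global minimum f(0) = 0."""
    d = x.shape[-1]
    term1 = -a * jnp.exp(-b * jnp.sqrt(jnp.sum(x ** 2, axis=-1) / d))
    term2 = -jnp.exp(jnp.sum(jnp.cos(c * x), axis=-1) / d)
    return term1 + term2 + a + jnp.e

print(float(ackley(jnp.zeros(2))))   # → 0.0 (up to float precision)
print(float(ackley(jnp.ones(2))))    # strictly positive away from the origin
```

Its many shallow local minima around a single global basin are what make it a good stress test for gradient-free optimizers.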

Note: When tuning new settings, the most sensitive hyperparameters are step_radius, probe_radius, the entropic regularization scalar ent_epsilon, and the parameters of the step-annealing scheme epsilon_scheduler. You can experiment with these, together with the other hyperparameters, on the synthetic functions to get a feeling for how they affect the optimization. For the 10D experiments, the plots are projected onto the first two dimensions for visualization.
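These knobs live in the hydra experiment configs under configs/experiment. A hypothetical config fragment showing where such parameters might sit (field names, nesting, and values here are illustrative assumptions, not copied from the repo's actual schema):

```yaml
# Hypothetical hydra experiment config — illustrative layout only.
# @package _global_
name: ss-al
solver:
  step_radius: 0.15      # how far each point may move per outer iteration
  probe_radius: 0.2      # radius at which the objective is probed around each point
  ent_epsilon: 0.05      # entropic regularization of the inner Sinkhorn solve
  epsilon_scheduler:     # step-annealing scheme: shrink the scale over iterations
    target: 0.01
    decay: 0.97
```

Check the files in configs/experiment for the actual field names used by each experiment.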

We also add benchmarks of gradient-approximation experiments, based on the cosine similarity between the Sinkhorn Step and the true gradient, measured over outer iterations and over the entropic regularization of the Sinkhorn distance. Step annealing is turned off for benchmarking purposes. The currently available gradient approximation experiments are:

  • ss-al-cosin-sim: Ackley function in 10D
  • ss-st-cosin-sim: Styblinski-Tang function in 10D

To run them:

python scripts/benchmark_cosin_similarity_single.py experiment=ss-st-cosin-sim num_seeds=20
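The metric behind these benchmarks is straightforward: compare the direction of an update step against the true gradient from autodiff. A minimal sketch of that measurement (the `step` below is a stand-in for a Sinkhorn Step update; in the real benchmark it would come from the solver, not from the gradient itself):

```python
import jax
import jax.numpy as jnp

def cosine_similarity(u, v, eps=1e-12):
    """Cosine similarity between two flat vectors."""
    return jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v) + eps)

# Styblinski-Tang in 2D, with its true gradient from JAX autodiff.
styblinski_tang = lambda z: 0.5 * jnp.sum(z ** 4 - 16.0 * z ** 2 + 5.0 * z)
x = jnp.array([1.0, -2.0])
true_descent = -jax.grad(styblinski_tang)(x)   # compare against the descent direction

# Stand-in for a Sinkhorn Step update direction (hypothetical; the benchmark
# scripts obtain this from the actual solver).
step = 0.1 * true_descent / jnp.linalg.norm(true_descent)
sim = float(cosine_similarity(step, true_descent))
print(sim)  # → 1.0 for a perfectly aligned step
```

A cosine similarity near 1 means the gradient-free step points in nearly the same direction as the true gradient descent direction.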

Citation

If you find this work useful, please consider citing:

@inproceedings{le2023accelerating,
  title     = {Accelerating Motion Planning via Optimal Transport},
  author    = {Le, An T. and Chalvatzaki, Georgia and Biess, Armin and Peters, Jan},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2023}
}

Owner

  • Name: An Thai Le
  • Login: anindex
  • Kind: user
  • Location: Germany
  • Company: Technische Universität Darmstadt

GitHub Events

Total
  • Watch event: 2
  • Push event: 1
Last Year
  • Watch event: 2
  • Push event: 1

Dependencies

requirements.txt pypi
  • flax *
  • hydra-core *
  • jax *
  • matplotlib *
  • numpy *
  • omegaconf *
  • ott-jax *
  • setuptools *
setup.py pypi