pyrddlgym-jax

JAX compilation of RDDL description files, and a differentiable planner in JAX.

https://github.com/pyrddlgym-project/pyrddlgym-jax

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 1 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, sciencedirect.com, springer.com
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.6%) to scientific vocabulary

Keywords

automatic-differentiation backpropagation control controller differentiable-simulations gradient-based-optimisation jax model-based-control nonlinear-control nonlinear-dynamics nonlinear-optimization planning planning-algorithms planning-domain-definition-language policy-gradient rddl reinforcement-learning sgd sgd-optimizer stochastic-gradient-descent
Last synced: 4 months ago

Repository


Basic Info
Statistics
  • Stars: 8
  • Watchers: 3
  • Forks: 1
  • Open Issues: 0
  • Releases: 14
Created almost 2 years ago · Last pushed 5 months ago
Metadata Files
Readme License Citation

README.md

pyRDDLGym-jax



pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.

Purpose:

  1. automatic translation of RDDL description files into differentiable JAX simulators
  2. implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
  3. flexible policy representations and automated Bayesian hyper-parameter tuning
  4. interactive dashboard for dynamic visualization and debugging
  5. hybridization with parameter-exploring policy gradients
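Parameter-exploring policy gradients (PGPE) improve policy parameters using only the returns of randomly perturbed parameter vectors, with no gradient of the model required. As a rough illustration of the idea behind item 5, here is a minimal sketch of basic symmetric-sampling PGPE on a toy quadratic objective. It is an assumption-laden simplification: only the mean `mu` is updated (the exploration noise `sigma` stays fixed), the function name `symmetric_pgpe` is hypothetical, and it is neither JaxPlan's implementation nor the super-symmetric variant cited below. How JaxPlan combines PGPE with backpropagated gradients is not shown here.

```python
import random


def symmetric_pgpe(objective, mu, sigma=0.5, lr=0.1, iters=1000, seed=0):
    """Maximize `objective` over a parameter vector with symmetric-sampling PGPE.

    Only the mean `mu` of the Gaussian search distribution is updated; the
    exploration noise `sigma` is kept fixed to keep the sketch short.
    """
    rng = random.Random(seed)
    mu = list(mu)
    for _ in range(iters):
        # draw one perturbation and evaluate it in both directions (symmetric sampling)
        eps = [rng.gauss(0.0, sigma) for _ in mu]
        r_plus = objective([m + e for m, e in zip(mu, eps)])
        r_minus = objective([m - e for m, e in zip(mu, eps)])
        # move the mean along the perturbation, weighted by the return difference
        mu = [m + lr * e * (r_plus - r_minus) / 2.0 for m, e in zip(mu, eps)]
    return mu


# toy objective with its maximum at theta = (2, 2)
def toy_return(theta):
    return -sum((t - 2.0) ** 2 for t in theta)


print(symmetric_pgpe(toy_return, [0.0, 0.0]))
```

Because each update needs only two evaluations of the return, this kind of estimator also applies where the simulator's gradients are uninformative, which is the usual motivation for mixing it with gradient-based planning.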


[!WARNING]
Starting in version 1.0 (major release), the weight parameter was removed from the config file and moved into the individual logic components, each of which now has its own weight parameter. Furthermore, the tuning module has been redesigned from the ground up and supports tuning arbitrary hyper-parameters via config templates! Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to TensorBoard, but custom-designed for the planner)!

[!NOTE]
While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!). If you find it is not making progress, check out the PROST planner (for discrete spaces) or the deep reinforcement learning wrappers.

Installation

To install the bare-bones version of JaxPlan with minimum installation requirements:

```shell
pip install pyRDDLGym-jax
```

To install JaxPlan with the automatic hyper-parameter tuning and rddlrepository:

```shell
pip install "pyRDDLGym-jax[extra]"
```

Since version 1.0, to install JaxPlan with the visualization dashboard:

```shell
pip install "pyRDDLGym-jax[dashboard]"
```

Since version 1.0, to install JaxPlan with all options:

```shell
pip install "pyRDDLGym-jax[extra,dashboard]"
```

Running from the Command Line

A basic run script is provided to train JaxPlan on any RDDL problem:

```shell
jaxplan plan <domain> <instance> <method> --episodes <episodes>
```

where:
- domain is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014), or a path pointing to a valid domain.rddl file
- instance is the instance identifier (e.g. 1, 2, ... 10), or a path pointing to a valid instance.rddl file
- method is the planning method to use (e.g. drp, slp, replan), or a path to a valid .cfg file (see the configuration section below)
- episodes is the (optional) number of episodes on which to evaluate the learned policy.

The method parameter supports four possible modes:
- slp is the basic straight-line planner described in this paper
- drp is the deep reactive policy network described in this paper
- replan is the same as slp, except that the plan is recalculated at every decision time step
- any other argument is interpreted as a file path to a valid configuration file.

For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):

```shell
jaxplan plan Quadcopter 1 slp
```
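When launching many runs from a script (for example, a sweep over instances), it can help to build the command line programmatically. The sketch below only assembles the argument list for the `jaxplan plan` syntax shown above; `build_plan_command` is a hypothetical helper, and actually running the command assumes the `jaxplan` entry point is installed and on your PATH.

```python
import shlex


def build_plan_command(domain, instance, method, episodes=None):
    """Assemble the argument list for `jaxplan plan <domain> <instance> <method>`."""
    argv = ["jaxplan", "plan", str(domain), str(instance), str(method)]
    if episodes is not None:
        # --episodes is optional; leave it out to use the CLI's default
        argv += ["--episodes", str(episodes)]
    return argv


cmd = build_plan_command("Quadcopter", 1, "slp", episodes=3)
print(shlex.join(cmd))  # jaxplan plan Quadcopter 1 slp --episodes 3
```

Passing `cmd` to `subprocess.run(cmd, check=True)` would then execute the run. Keeping the arguments as a list (rather than one string) avoids shell-quoting problems when a domain or config path contains spaces.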

Running from Another Python Application

To run JaxPlan from within a Python application, refer to the following example:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```

Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish. All controllers are instances of pyRDDLGym's BaseAgent class, so they provide the evaluate() function to streamline interaction with the environment. The **planner_args and **train_args are keyword argument parameters to pass during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
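A straight-line plan is an open-loop sequence of actions, one per decision step, optimized by gradient ascent on the return computed through a differentiable model. To make that idea concrete without JAX, the sketch below plans on a hypothetical one-dimensional integrator (`x_{t+1} = x_t + a_t`, reward `-(x_{t+1} - goal)^2`) and computes the gradient with a hand-written backward pass that stands in for JAX's automatic differentiation. The dynamics, step size, and function names are illustrative assumptions, not JaxPlan's model or API.

```python
def rollout(actions, x0=0.0):
    """Simulate x_{t+1} = x_t + a_t and return the visited states x_1..x_T."""
    states, x = [], x0
    for a in actions:
        x = x + a
        states.append(x)
    return states


def return_and_grad(actions, goal, x0=0.0):
    """Return R = -sum_t (x_{t+1} - goal)^2 and dR/da_t via a manual backward pass."""
    states = rollout(actions, x0)
    ret = -sum((x - goal) ** 2 for x in states)
    grads = [0.0] * len(actions)
    adjoint = 0.0  # holds dR/dx_{t+1}, including effects through all later states
    for t in reversed(range(len(actions))):
        adjoint += -2.0 * (states[t] - goal)
        grads[t] = adjoint  # dx_{t+1}/da_t = 1
    return ret, grads


def plan(horizon, goal, lr=0.05, iters=2000):
    """Optimize a straight-line (open-loop) plan by plain gradient ascent."""
    actions = [0.0] * horizon
    for _ in range(iters):
        _, grads = return_and_grad(actions, goal)
        actions = [a + lr * g for a, g in zip(actions, grads)]
    return actions


print(plan(horizon=5, goal=3.0))
```

The optimal plan here moves straight to the goal on the first step and then stays put, and gradient ascent recovers approximately that sequence. In JaxPlan the model is the compiled RDDL domain, relaxed where needed, and the gradient comes from JAX rather than a hand-derived reverse pass.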

Configuring the Planner

The simplest way to configure the planner is to write and pass a configuration file with the necessary hyper-parameters. The basic structure of a configuration file is provided below for a straight-line planner:

```ini
[Model]
logic='FuzzyLogic'
comparison_kwargs={'weight': 20}
rounding_kwargs={'weight': 20}
control_kwargs={'weight': 20}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Training]
key=42
epochs=5000
train_seconds=30
```

The configuration file contains three sections:
- [Model] specifies the fuzzy logic operations used to relax discrete operations into differentiable approximations; the weight dictates the quality of the approximation, and tnorm specifies the type of fuzzy logic used to replace logical operations in RDDL (e.g. ProductTNorm, GodelTNorm, LukasiewiczTNorm)
- [Optimizer] generally specifies the optimizer and plan settings; method specifies the plan/policy representation (e.g. JaxStraightLinePlan, JaxDeepReactivePolicy), along with the gradient descent settings, learning rate, batch size, etc.
- [Training] specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
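To give a feel for what the [Model] settings control, the sketch below implements textbook fuzzy-logic relaxations: a sigmoid-based soft comparison whose sharpness is set by a `weight`, and the product, Gödel, and Łukasiewicz t-norms (fuzzy AND) with OR derived by De Morgan's law. These are standard formulas chosen for illustration; all function names are hypothetical, and JaxPlan's own logic classes may use different formulas or parameterizations.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def relaxed_greater(x, y, weight):
    """Soft version of the comparison x > y; a larger weight gives a sharper step."""
    return sigmoid(weight * (x - y))


# Three common t-norms (fuzzy AND) on truth values in [0, 1]
def product_and(a, b):
    return a * b


def godel_and(a, b):
    return min(a, b)


def lukasiewicz_and(a, b):
    return max(0.0, a + b - 1.0)


def fuzzy_not(a):
    return 1.0 - a


def fuzzy_or(a, b, t_and):
    """De Morgan dual of a t-norm: OR(a, b) = NOT(AND(NOT a, NOT b))."""
    return fuzzy_not(t_and(fuzzy_not(a), fuzzy_not(b)))


# the same comparison 0.1 > 0 under a soft and a sharp relaxation
print(relaxed_greater(0.1, 0.0, weight=1.0))    # about 0.525
print(relaxed_greater(0.1, 0.0, weight=100.0))  # close to 1
```

A higher weight makes the relaxation closer to the exact discrete operation but also makes its gradient vanish away from the decision boundary, which is one reason the weight is often worth tuning.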

For a policy network approach, simply change the [Optimizer] settings like so:

```ini
...
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
...
```
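Unlike a straight-line plan, a deep reactive policy maps the current state to an action, so it can react to stochastic outcomes. The sketch below shows a plain fully connected network in which `topology` is read as the list of hidden-layer widths and `activation` as the hidden nonlinearity, matching the keys in the snippet above. It is a dependency-free illustration under those assumptions, not JaxPlan's network: the real policy may add further processing (for example for action bounds or boolean actions), which is omitted here, and `init_mlp` and `policy` are hypothetical names.

```python
import math
import random


def init_mlp(n_inputs, topology, n_outputs, seed=0):
    """Randomly initialize a fully connected network.

    `topology` lists the hidden-layer widths, e.g. [128, 64].
    Each layer is stored as (W, b) with W of shape (fan_out, fan_in).
    """
    rng = random.Random(seed)
    sizes = [n_inputs] + list(topology) + [n_outputs]
    layers = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        scale = 1.0 / math.sqrt(fan_in)
        W = [[rng.gauss(0.0, scale) for _ in range(fan_in)] for _ in range(fan_out)]
        b = [0.0] * fan_out
        layers.append((W, b))
    return layers


def policy(layers, state, activation=math.tanh):
    """Map a state vector to an action vector (linear output layer)."""
    h = list(state)
    for i, (W, b) in enumerate(layers):
        z = [sum(w * x for w, x in zip(row, h)) + bj for row, bj in zip(W, b)]
        # apply the nonlinearity on hidden layers only
        h = z if i == len(layers) - 1 else [activation(v) for v in z]
    return h


layers = init_mlp(n_inputs=4, topology=[128, 64], n_outputs=2)
print(policy(layers, [0.1, 0.2, 0.3, 0.4]))
```

Training such a policy means backpropagating the return through both the network and the model over the whole horizon; the network's parameters, rather than the actions themselves, are what the optimizer updates.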

The configuration file must then be passed to the planner during initialization. For example, the previous script can be modified to set parameters from a config file:

```python
from pyRDDLGym_jax.core.planner import load_config

# load the config file with planner settings
planner_args, _, train_args = load_config("/path/to/config.cfg")

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
...
```
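`load_config` already handles config files for JaxPlan. If you only want to inspect or sanity-check a file by hand, for instance to catch a typo in a dictionary value before a long run, a minimal sketch with the standard library is below. It assumes, as in the examples above, that every value is a Python literal; `read_planner_config` is a hypothetical helper and does not reproduce `load_config`'s return values.

```python
import ast
import configparser

EXAMPLE_CONFIG = """
[Model]
logic='FuzzyLogic'
comparison_kwargs={'weight': 20}

[Optimizer]
method='JaxStraightLinePlan'
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Training]
key=42
train_seconds=30
"""


def read_planner_config(text):
    """Parse INI sections into nested dicts, evaluating each value as a Python literal."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.optionxform = str  # keep option names exactly as written
    parser.read_string(text)
    return {
        section: {key: ast.literal_eval(value) for key, value in parser.items(section)}
        for section in parser.sections()
    }


settings = read_planner_config(EXAMPLE_CONFIG)
print(settings["Optimizer"]["optimizer_kwargs"])  # {'learning_rate': 0.001}
```

Using `ast.literal_eval` instead of `eval` means a malformed or malicious value raises an error rather than executing code.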

JaxPlan Dashboard

Since version 1.0, JaxPlan has an optional dashboard for tracking planner performance across multiple runs, visualizing the policy or model, and other useful debugging features. To run the dashboard, add the following to your config file:

```ini
...
[Training]
dashboard=True
...
```

Tuning the Planner

A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:

```shell
jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
```

where:
- domain is the domain identifier as specified in rddlrepository
- instance is the instance identifier
- method is the planning method to use (e.g. drp, slp, replan)
- trials is the (optional) number of trials/episodes to average when evaluating each hyper-parameter setting
- iters is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- workers is the (optional) number of parallel evaluations to be done at each iteration, so the total number of evaluations is iters * workers
- dashboard specifies whether the optimizations are tracked in the dashboard application
- filepath is the (optional) file path where a config file with the best hyper-parameter setting will be saved.

It is easy to tune a custom range of the planner's hyper-parameters efficiently. First create a config file template with patterns replacing concrete parameter values that you want to tune, e.g.:

```ini
[Model]
logic='FuzzyLogic'
comparison_kwargs={'weight': TUNABLE_WEIGHT}
rounding_kwargs={'weight': TUNABLE_WEIGHT}
control_kwargs={'weight': TUNABLE_WEIGHT}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}

[Training]
train_seconds=30
print_summary=False
print_progress=False
train_on_reset=True
```

This template would allow tuning the sharpness of the model relaxations and the learning rate of the optimizer.
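Conceptually, each trial of the tuner turns the template into a concrete config by substituting candidate values for the placeholder names. The sketch below shows one plausible way to do that substitution with whole-word regular-expression matching; JaxPlan's tuner performs this step internally and its exact mechanism may differ, so `fill_template` should be read as a hypothetical illustration.

```python
import re


def fill_template(template, values):
    """Substitute each placeholder name (matched as a whole word) with repr(value)."""
    for name, value in values.items():
        pattern = r"\b" + re.escape(name) + r"\b"
        # a function replacement avoids backslash handling in the substituted text
        template = re.sub(pattern, lambda _match, v=value: repr(v), template)
    return template


TEMPLATE = """[Model]
comparison_kwargs={'weight': TUNABLE_WEIGHT}

[Optimizer]
optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
"""

filled = fill_template(TEMPLATE, {"TUNABLE_WEIGHT": 20.0, "TUNABLE_LEARNING_RATE": 0.001})
print(filled)
```

Matching whole words keeps one placeholder from accidentally replacing part of another whose name it shares a prefix with.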

Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter

# set up the environment
env = pyRDDLGym.make(domain, instance, vectorized=True)

# load the config file template with planner settings
with open('path/to/config.cfg', 'r') as file:
    config_template = file.read()

# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
def power_10(x):
    return 10.0 ** x

hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power_10),
               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power_10)]

# build the tuner and tune
tuning = JaxParameterTuning(env=env, config_template=config_template, hyperparams=hyperparams,
                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
tuning.tune(key=42, log_file='path/to/log.csv')
```
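Each `Hyperparameter` above pairs a search interval with a transform, so the optimizer works in exponent space while the planner receives `10 ** x`. The sketch below illustrates that idea with plain random search, which is swapped in here for the Bayesian optimization JaxPlan actually uses, and with a made-up objective standing in for "train the planner and report its mean return". `SearchDim`, `random_search`, and `toy_objective` are hypothetical names, not part of JaxPlan.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable


@dataclass
class SearchDim:
    """Hypothetical stand-in mirroring Hyperparameter(name, lower, upper, transform)."""
    name: str
    lower: float
    upper: float
    transform: Callable[[float], float]


def power_10(x):
    return 10.0 ** x


def random_search(dims, objective, n_trials, seed=0):
    """Sample each coordinate uniformly in [lower, upper], map it, and keep the best."""
    rng = random.Random(seed)
    best_params, best_score = None, -math.inf
    for _ in range(n_trials):
        params = {d.name: d.transform(rng.uniform(d.lower, d.upper)) for d in dims}
        score = objective(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_objective(params):
    # made-up score that peaks at weight = 10^2 and learning rate = 10^-3
    return (-(math.log10(params["TUNABLE_LEARNING_RATE"]) + 3.0) ** 2
            - (math.log10(params["TUNABLE_WEIGHT"]) - 2.0) ** 2)


DIMS = [SearchDim("TUNABLE_WEIGHT", -1.0, 5.0, power_10),
        SearchDim("TUNABLE_LEARNING_RATE", -5.0, 1.0, power_10)]

best, score = random_search(DIMS, toy_objective, n_trials=2000)
print(best, score)
```

Sampling uniformly in the exponent spreads trials evenly across orders of magnitude; sampling the raw interval [0.1, 10^5] directly would put almost all trials near the top of the range. Bayesian optimization keeps the same search space but chooses each new point based on the results so far, which usually needs far fewer expensive planner runs.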

Simulation

The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:

```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# create the environment
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

# evaluate the random policy
agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
agent.evaluate(env, verbose=True, render=True)
```

For some domains, the JAX backend may run faster than the numpy-based one, due to various compiler optimizations. In any event, the simulation results using the JAX backend should (almost) always match those of the numpy backend.

Citing JaxPlan

The following citation describes the main ideas of JaxPlan. Please cite it if you found it useful:

```bibtex
@inproceedings{gimelfarb2024jaxplan,
    title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
    author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
    booktitle={34th International Conference on Automated Planning and Scheduling},
    year={2024},
    url={https://openreview.net/forum?id=7IKtmUpLEH}
}
```

Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
- A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs, AAAI 2022
- Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019
- Stochastic Planning with Lifted Symbolic Trajectory Optimization, AAAI 2019
- Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017
- Baseline-Free Sampling in Parameter Exploring Policy Gradients: Super Symmetric PGPE, ANN 2015

The model relaxations in JaxPlan are based on the following works:
- Poisson Variational Autoencoder, NeurIPS 2025
- Analyzing Differentiable Fuzzy Logic Operators, AI 2022
- Learning with algorithmic supervision via continuous relaxations, NeurIPS 2021
- Universally quantized neural compression, NeurIPS 2020
- Generalized Gumbel-Softmax Gradient Estimator for Generic Discrete Random Variables, 2020
- Categorical Reparametrization with Gumbel-Softmax, ICLR 2017

Owner

  • Name: pyrddlgym-project
  • Login: pyrddlgym-project
  • Kind: organization

The official pyRDDLGym Simulator, and everything RDDL related

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Gimelfarb"
  given-names: "Michael"
title: "pyRDDLGym-jax"
version: 1.0
date-released: 2024-01-01
preferred-citation:
  type: conference-paper
  authors:
  - family-names: "Gimelfarb"
    given-names: "Michael"
  - family-names: "Taitler"
    given-names: "Ayal"
  - family-names: "Sanner"
    given-names: "Scott"
  title: "JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains"
  journal: "Proceedings of the International Conference on Automated Planning and Scheduling"
  url: "https://openreview.net/forum?id=7IKtmUpLEH"
  month: 5
  day: 30
  year: 2024
  volume: 34
  start: 230
  end: 238

GitHub Events

Total
  • Create event: 24
  • Release event: 11
  • Issues event: 14
  • Watch event: 3
  • Issue comment event: 13
  • Push event: 266
  • Pull request review comment event: 2
  • Pull request review event: 3
  • Pull request event: 20
Last Year
  • Create event: 24
  • Release event: 11
  • Issues event: 14
  • Watch event: 3
  • Issue comment event: 13
  • Push event: 266
  • Pull request review comment event: 2
  • Pull request review event: 3
  • Pull request event: 20

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 5
  • Total pull requests: 5
  • Average time to close issues: 3 days
  • Average time to close pull requests: 2 minutes
  • Total issue authors: 3
  • Total pull request authors: 2
  • Average comments per issue: 1.0
  • Average comments per pull request: 0.0
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 5
  • Pull requests: 5
  • Average time to close issues: 3 days
  • Average time to close pull requests: 2 minutes
  • Issue authors: 3
  • Pull request authors: 2
  • Average comments per issue: 1.0
  • Average comments per pull request: 0.0
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • nhuet (3)
  • mike-gimelfarb (2)
  • AlecDong (2)
  • danielbdias (2)
  • iliathesmirnov (1)
Pull Request Authors
  • mike-gimelfarb (8)
  • ataitler (1)
  • danielbdias (1)
Top Labels
Issue Labels
documentation (1) bug (1)
Pull Request Labels
enhancement (4) major release (1)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 1,230 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 16
  • Total maintainers: 1
pypi.org: pyrddlgym-jax

pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.

  • Versions: 16
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 1,230 Last month
Rankings
Dependent packages count: 9.8%
Average: 37.4%
Dependent repos count: 65.0%
Maintainers (1)
Last synced: 4 months ago

Dependencies

requirements.txt pypi
  • bayesian-optimization *
  • dm-haiku >=0.0.9
  • gym >=0.24.0
  • jax >=0.3.25
  • matplotlib >=3.5.0
  • numpy >=1.22
  • optax >=0.1.4
  • pillow >=9.2.0
  • ply *
  • pygame *
  • tensorflow >=2.11.0
  • tensorflow-probability >=0.19.0
  • tqdm *
setup.py pypi
  • bayesian-optimization *
  • dm-haiku >=0.0.9
  • jax >=0.3.25
  • optax >=0.1.4
  • pyRDDLGym >=2.0.0
  • rddlrepository *
  • tensorflow >=2.11.0
  • tensorflow-probability >=0.19.0
  • tqdm *