pyrddlgym-jax
JAX compilation of RDDL description files, and a differentiable planner in JAX.
Science Score: 67.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found CITATION.cff file
- ✓ codemeta.json file: found codemeta.json file
- ✓ .zenodo.json file: found .zenodo.json file
- ✓ DOI references: found 1 DOI reference(s) in README
- ✓ Academic publication links: links to arxiv.org, sciencedirect.com, springer.com
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (15.6%) to scientific vocabulary
Keywords
Repository
JAX compilation of RDDL description files, and a differentiable planner in JAX.
Basic Info
- Host: GitHub
- Owner: pyrddlgym-project
- License: mit
- Language: Python
- Default Branch: main
- Homepage: https://pyrddlgym.readthedocs.io/en/latest/jax.html
- Size: 34.6 MB
Statistics
- Stars: 8
- Watchers: 3
- Forks: 1
- Open Issues: 0
- Releases: 14
Topics
Metadata Files
README.md
pyRDDLGym-jax
Installation | Run cmd | Run python | Configuration | Dashboard | Tuning | Simulation | Citing
pyRDDLGym-jax (or JaxPlan) is an efficient gradient-based planning algorithm based on JAX.
Purpose:
- automatic translation of RDDL description files into differentiable JAX simulators
- implementation of (highly configurable) operator relaxations for working in discrete and hybrid domains
- flexible policy representations and automated Bayesian hyper-parameter tuning
- interactive dashboard for dynamic visualization and debugging
- hybridization with parameter-exploring policy gradients.
Some demos of problems solved by JaxPlan:
[!WARNING]
Starting in version 1.0 (major release), the weight parameter in the config file was removed and moved to the individual logic components, which each have their own unique weight parameter. Furthermore, the tuning module has been redesigned from the ground up, and supports tuning arbitrary hyper-parameters via config templates! Finally, the terrible visualizer for the planner was removed and replaced with an interactive real-time dashboard (similar to tensorboard, but custom designed for the planner)!

[!NOTE]
While JaxPlan can support some discrete state/action problems through model relaxations, on some discrete problems it can perform poorly (though there is an ongoing effort to remedy this!). If you find it is not making progress, check out the PROST planner (for discrete spaces) or the deep reinforcement learning wrappers.
Installation
To install the bare-bones version of JaxPlan with minimum installation requirements:
```shell
pip install pyRDDLGym-jax
```
To install JaxPlan with the automatic hyper-parameter tuning and rddlrepository:
```shell
pip install pyRDDLGym-jax[extra]
```
(Since version 1.0) To install JaxPlan with the visualization dashboard:
```shell
pip install pyRDDLGym-jax[dashboard]
```
(Since version 1.0) To install JaxPlan with all options:
```shell
pip install pyRDDLGym-jax[extra,dashboard]
```
Running from the Command Line
A basic run script is provided to train JaxPlan on any RDDL problem:
```shell
jaxplan plan <domain> <instance> <method> --episodes <episodes>
```
where:
- domain is the domain identifier as specified in rddlrepository (e.g. Wildfire_MDP_ippc2014), or a path pointing to a valid domain.rddl file
- instance is the instance identifier (e.g. 1, 2, ... 10), or a path pointing to a valid instance.rddl file
- method is the planning method to use (e.g. drp, slp, replan) or a path to a valid .cfg file (see section below)
- episodes is the (optional) number of episodes to evaluate the learned policy.
The method parameter supports four possible modes:
- slp is the basic straight line planner described in this paper
- drp is the deep reactive policy network described in this paper
- replan is the same as slp except the plan is recalculated at every decision time step
- any other argument is interpreted as a file path to a valid configuration file.
For example, the following will train JaxPlan on the Quadcopter domain with 4 drones (with default config):
```shell
jaxplan plan Quadcopter 1 slp
```
Running from Another Python Application
To run JaxPlan from within a Python application, refer to the following example:
```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```
Here, we have used the straight-line controller, although you can configure the combination of planner and policy representation if you wish.
All controllers are instances of pyRDDLGym's BaseAgent class, so they provide the evaluate() function to streamline interaction with the environment.
The **planner_args and **train_args are keyword arguments passed during initialization, but we strongly recommend creating and loading a config file as discussed in the next section.
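If you prefer to drive the environment yourself instead of calling evaluate(), the controller can be used like any other pyRDDLGym agent. The following is a minimal sketch, assuming pyRDDLGym's gymnasium-style reset()/step() interface, the env.horizon attribute, and the BaseAgent sample_action() method; evaluate() remains the recommended entry point:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController

# environment and controller as in the example above
env = pyRDDLGym.make("domain", "instance", vectorized=True)
planner = JaxBackpropPlanner(rddl=env.model)
controller = JaxOfflineController(planner)

# manual rollout, equivalent in spirit to controller.evaluate(env, episodes=1)
state, _ = env.reset()
total_reward = 0.0
for _ in range(env.horizon):  # assumption: env.horizon exposes the RDDL horizon
    action = controller.sample_action(state)
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
print("episode return:", total_reward)
env.close()
```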
Configuring the Planner
The simplest way to configure the planner is to write and pass a configuration file with the necessary hyper-parameters. The basic structure of a configuration file is provided below for a straight-line planner:
```ini
[Model]
logic='FuzzyLogic'
comparison_kwargs={'weight': 20}
rounding_kwargs={'weight': 20}
control_kwargs={'weight': 20}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Training]
key=42
epochs=5000
train_seconds=30
```
The configuration file contains three sections (a short parsing sketch follows this list):
- [Model] specifies the fuzzy logic operations used to relax discrete operations to differentiable approximations; the weight dictates the quality of the approximation, and tnorm specifies the type of fuzzy logic for replacing logical operations in RDDL (e.g. ProductTNorm, GodelTNorm, LukasiewiczTNorm)
- [Optimizer] specifies the optimizer and plan settings: the method specifies the plan/policy representation (e.g. JaxStraightLinePlan, JaxDeepReactivePolicy), along with the gradient descent settings, learning rate, batch size, etc.
- [Training] specifies computation limits, such as total training time and number of iterations, and options for printing or visualizing information from the planner.
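To make the structure concrete, the sketch below shows one way such a file could be parsed into three keyword-argument dictionaries using Python's standard configparser and ast modules. The parse_planner_config helper is hypothetical and purely illustrative of the file layout; the package's actual loader is the load_config function shown in the next example:

```python
import ast
import configparser

def parse_planner_config(path):
    """Illustrative only: read each section into a dict of evaluated key/value pairs."""
    config = configparser.ConfigParser()
    config.read(path)
    sections = {}
    for name in ('Model', 'Optimizer', 'Training'):
        # values such as {'weight': 20}, 'rmsprop' or 5000 are Python literals
        sections[name] = {key: ast.literal_eval(value)
                          for key, value in config[name].items()}
    return sections

settings = parse_planner_config('/path/to/config.cfg')
print(settings['Optimizer']['method'])         # e.g. 'JaxStraightLinePlan'
print(settings['Model']['comparison_kwargs'])  # e.g. {'weight': 20}
```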
For a policy network approach, simply change the [Optimizer] settings like so:
```ini
...
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64], 'activation': 'tanh'}
...
```
The configuration file must then be passed to the planner during initialization. For example, the previous script can be modified to set parameters from a config file:
```python
from pyRDDLGym_jax.core.planner import load_config

# load the config file with planner settings
planner_args, _, train_args = load_config("/path/to/config.cfg")

# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
...
```
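The offline controller above optimizes a single plan before execution. For the replan mode described earlier, the package also provides an online controller that re-optimizes from the current state at every decision step. The sketch below assumes JaxOnlineController accepts the same (planner, **train_args) pattern as JaxOfflineController, so treat the exact arguments as illustrative:

```python
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOnlineController, load_config

# environment and config as in the previous examples
env = pyRDDLGym.make("domain", "instance", vectorized=True)
planner_args, _, train_args = load_config("/path/to/config.cfg")

# an online controller replans from the current state at each decision step
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOnlineController(planner, **train_args)
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
```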
JaxPlan Dashboard
Since version 1.0, JaxPlan has an optional dashboard that allows you to track planner performance across multiple runs, visualize the policy or model, and access other useful debugging features. To run the dashboard, add the following to your config file:
```ini
...
[Training]
dashboard=True
...
```
Tuning the Planner
A basic run script is provided to run automatic Bayesian hyper-parameter tuning for the most sensitive parameters of JaxPlan:
```shell
jaxplan tune <domain> <instance> <method> --trials <trials> --iters <iters> --workers <workers> --dashboard <dashboard> --filepath <filepath>
```
where:
- domain is the domain identifier as specified in rddlrepository
- instance is the instance identifier
- method is the planning method to use (e.g. drp, slp, replan)
- trials is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting
- iters is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform
- workers is the (optional) number of parallel evaluations to perform at each iteration, so the total number of evaluations is iters * workers
- dashboard specifies whether the optimization runs are tracked in the dashboard application
- filepath is the optional file path where a config file with the best hyper-parameter setting will be saved.
It is easy to tune a custom range of the planner's hyper-parameters efficiently. First, create a config file template in which the concrete values of the parameters you want to tune are replaced by patterns, e.g.:
```ini
[Model]
logic='FuzzyLogic'
comparison_kwargs={'weight': TUNABLE_WEIGHT}
rounding_kwargs={'weight': TUNABLE_WEIGHT}
control_kwargs={'weight': TUNABLE_WEIGHT}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}

[Training]
train_seconds=30
print_summary=False
print_progress=False
train_on_reset=True
```
This template would allow tuning the sharpness of the model relaxations and the learning rate of the optimizer.
Next, you must link the patterns in the config with concrete hyper-parameter ranges the tuner will understand, and run the optimizer:
```python
import pyRDDLGym
from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter

# set up the environment
env = pyRDDLGym.make(domain, instance, vectorized=True)

# load the config file template with planner settings
with open('path/to/config.cfg', 'r') as file:
    config_template = file.read()

# tune weight from 10^-1 ... 10^5 and lr from 10^-5 ... 10^1
def power10(x):
    return 10.0 ** x

hyperparams = [Hyperparameter('TUNABLE_WEIGHT', -1., 5., power10),
               Hyperparameter('TUNABLE_LEARNING_RATE', -5., 1., power10)]

# build the tuner and tune (trials, workers and iters are user-chosen integers)
tuning = JaxParameterTuning(env=env,
                            config_template=config_template,
                            hyperparams=hyperparams,
                            online=False,
                            eval_trials=trials,
                            num_workers=workers,
                            gp_iters=iters)
tuning.tune(key=42, log_file='path/to/log.csv')
```
Simulation
The JAX compiler can be used as a backend for simulating and evaluating RDDL environments:
```python
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

# create the environment
env = pyRDDLGym.make("domain", "instance", backend=JaxRDDLSimulator)

# evaluate the random policy
agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
agent.evaluate(env, verbose=True, render=True)
```
For some domains, the JAX backend can perform better than the default numpy-based one, due to various compiler optimizations. In any event, the simulation results using the JAX backend should (almost) always match those of the numpy backend.
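A quick way to check this on your own domain is to time identical random-policy rollouts under both backends. The sketch below reuses only the APIs shown above plus Python's time module; note that the first JAX episode includes compilation time, so use several episodes for a fair comparison:

```python
import time
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

def time_rollouts(backend=None, episodes=5):
    # backend=None falls back to pyRDDLGym's default numpy-based simulator
    kwargs = {'backend': backend} if backend is not None else {}
    env = pyRDDLGym.make("domain", "instance", **kwargs)
    agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
    start = time.perf_counter()
    agent.evaluate(env, episodes=episodes)
    elapsed = time.perf_counter() - start
    env.close()
    return elapsed

print("numpy backend:", time_rollouts())
print("jax backend:  ", time_rollouts(backend=JaxRDDLSimulator))
```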
Citing JaxPlan
The following citation describes the main ideas of JaxPlan. Please cite it if you found it useful:
```bibtex
@inproceedings{gimelfarb2024jaxplan,
    title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
    author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
    booktitle={34th International Conference on Automated Planning and Scheduling},
    year={2024},
    url={https://openreview.net/forum?id=7IKtmUpLEH}
}
```
Some of the implementation details derive from the following literature, which you may wish to also cite in your research papers:
- A Distributional Framework for Risk-Sensitive End-to-End Planning in Continuous MDPs, AAAI 2022
- Deep reactive policies for planning in stochastic nonlinear domains, AAAI 2019
- Stochastic Planning with Lifted Symbolic Trajectory Optimization, AAAI 2019
- Scalable planning with tensorflow for hybrid nonlinear domains, NeurIPS 2017
- Baseline-Free Sampling in Parameter Exploring Policy Gradients: Super Symmetric PGPE, ANN 2015
The model relaxations in JaxPlan are based on the following works:
- Poisson Variational Autoencoder, NeurIPS 2025
- Analyzing Differentiable Fuzzy Logic Operators, AI 2022
- Learning with algorithmic supervision via continuous relaxations, NeurIPS 2021
- Universally quantized neural compression, NeurIPS 2020
- Generalized Gumbel-Softmax Gradient Estimator for Generic Discrete Random Variables, 2020
- Categorical Reparametrization with Gumbel-Softmax, ICLR 2017
Owner
- Name: pyrddlgym-project
- Login: pyrddlgym-project
- Kind: organization
- Repositories: 1
- Profile: https://github.com/pyrddlgym-project
The official pyRDDLGym Simulator, and everything RDDL related
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Gimelfarb"
    given-names: "Michael"
title: "pyRDDLGym-jax"
version: 1.0
date-released: 2024-01-01
preferred-citation:
  type: conference-paper
  authors:
    - family-names: "Gimelfarb"
      given-names: "Michael"
    - family-names: "Taitler"
      given-names: "Ayal"
    - family-names: "Sanner"
      given-names: "Scott"
  title: "JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains"
  journal: "Proceedings of the International Conference on Automated Planning and Scheduling"
  url: "https://openreview.net/forum?id=7IKtmUpLEH"
  month: 5
  day: 30
  year: 2024
  volume: 34
  start: 230
  end: 238
GitHub Events
Total
- Create event: 24
- Release event: 11
- Issues event: 14
- Watch event: 3
- Issue comment event: 13
- Push event: 266
- Pull request review comment event: 2
- Pull request review event: 3
- Pull request event: 20
Last Year
- Create event: 24
- Release event: 11
- Issues event: 14
- Watch event: 3
- Issue comment event: 13
- Push event: 266
- Pull request review comment event: 2
- Pull request review event: 3
- Pull request event: 20
Issues and Pull Requests
Last synced: 4 months ago
All Time
- Total issues: 5
- Total pull requests: 5
- Average time to close issues: 3 days
- Average time to close pull requests: 2 minutes
- Total issue authors: 3
- Total pull request authors: 2
- Average comments per issue: 1.0
- Average comments per pull request: 0.0
- Merged pull requests: 4
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 5
- Pull requests: 5
- Average time to close issues: 3 days
- Average time to close pull requests: 2 minutes
- Issue authors: 3
- Pull request authors: 2
- Average comments per issue: 1.0
- Average comments per pull request: 0.0
- Merged pull requests: 4
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- nhuet (3)
- mike-gimelfarb (2)
- AlecDong (2)
- danielbdias (2)
- iliathesmirnov (1)
Pull Request Authors
- mike-gimelfarb (8)
- ataitler (1)
- danielbdias (1)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 1
- Total downloads: pypi 1,230 last-month
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 16
- Total maintainers: 1
pypi.org: pyrddlgym-jax
pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
- Homepage: https://github.com/pyrddlgym-project/pyRDDLGym-jax
- Documentation: https://pyrddlgym-jax.readthedocs.io/
- License: MIT License
- Latest release: 2.6 (published 6 months ago)
Rankings
Maintainers (1)
Dependencies
- bayesian-optimization *
- dm-haiku >=0.0.9
- gym >=0.24.0
- jax >=0.3.25
- matplotlib >=3.5.0
- numpy >=1.22
- optax >=0.1.4
- pillow >=9.2.0
- ply *
- pygame *
- tensorflow >=2.11.0
- tensorflow-probability >=0.19.0
- tqdm *
- bayesian-optimization *
- dm-haiku >=0.0.9
- jax >=0.3.25
- optax >=0.1.4
- pyRDDLGym >=2.0.0
- rddlrepository *
- tensorflow >=2.11.0
- tensorflow-probability >=0.19.0
- tqdm *