jaxdp
A Dynamic Programming package for discrete MDPs implemented in JAX
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: Links to arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (9.9%) to scientific vocabulary
Statistics
- Stars: 7
- Watchers: 1
- Forks: 1
- Open Issues: 0
- Releases: 1
Metadata Files
README.md
jaxdp
jaxdp is a Python package providing functional implementations of dynamic programming (DP) algorithms for finite state-action Markov decision processes (MDPs) within the JAX ecosystem. By leveraging JAX transformations, DP algorithms can be accelerated (including on GPUs) through vectorized execution across multiple MDP instances, initial values, and parameters.
Vectorization
jaxdp functions are fully compatible with JAX transformations. They are stateless: any memory an algorithm needs is passed to the functions explicitly as arguments.
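Because such functions are pure, JAX transformations compose with them directly. As a minimal sketch, assuming a tabular layout `transition[a, s, s']`, `reward[a, s]` and Q-values `q[a, s]` (an assumption for illustration, not jaxdp's documented API), the hypothetical helper `bellman_optimality_q` below applies one Bellman optimality backup and is then jit-compiled and vmapped over several discount factors without modification:

```python
import jax
import jax.numpy as jnp


def bellman_optimality_q(transition, reward, q, gamma):
    """One Bellman optimality backup on tabular Q-values (hypothetical helper).

    Assumed layout: transition[a, s, s'], reward[a, s], q[a, s].
    """
    v = q.max(axis=0)                         # greedy state values, shape (S,)
    return reward + gamma * (transition @ v)  # expected backup, shape (A, S)


# A small random MDP with 3 actions and 5 states.
key_p, key_r = jax.random.split(jax.random.PRNGKey(0))
transition = jax.random.uniform(key_p, (3, 5, 5))
transition = transition / transition.sum(axis=-1, keepdims=True)  # rows sum to 1
reward = jax.random.uniform(key_r, (3, 5), minval=-1.0, maxval=1.0)
q0 = jnp.zeros((3, 5))

# All inputs are explicit arguments, so jit and vmap apply unchanged:
# map over gamma only, share the MDP and the Q-values (in_axes=None).
batched = jax.jit(jax.vmap(bellman_optimality_q, in_axes=(None, None, None, 0)))
gammas = jnp.array([0.9, 0.95, 0.99])
q1 = batched(transition, reward, q0, gammas)
print(q1.shape)  # (3, 3, 5): one (A, S) table per gamma
```

Since the function keeps no hidden state, the same call works unchanged under `jax.jit`, `jax.vmap`, or inside `jax.lax.scan`.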
Algorithm Example
The examples directory contains implementations and benchmarks of planning algorithms built with jaxdp. Below is a code snippet for Momentum-accelerated Value Iteration:
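```python
"""Momentum-accelerated Value Iteration."""
from __future__ import annotations  # keeps the `MDP` annotation below unevaluated

import jax.numpy as jnp
from flax import struct  # assumed: `struct.dataclass` comes from Flax

import jaxdp


@struct.dataclass
class State:
    q_val: jnp.ndarray       # current Q-values
    prev_q_val: jnp.ndarray  # Q-values from the previous iteration
    gamma: jnp.ndarray       # discount factor
    beta: jnp.ndarray        # momentum coefficient
    alpha: jnp.ndarray       # step size


def update(s: State, mdp: MDP, step: int) -> State:
    # `MDP` is jaxdp's MDP PyTree type (its import path is not shown here).
    diff = s.q_val - s.prev_q_val  # momentum term
    b_residual = jaxdp.bellman_optimality_operator.q(mdp, s.q_val, s.gamma) - s.q_val  # Bellman residual
    next_q = s.q_val + s.alpha * b_residual + s.beta * diff
    return s.replace(q_val=next_q, prev_q_val=s.q_val)
```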
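The snippet above implements the update q_{k+1} = q_k + alpha * (T q_k - q_k) + beta * (q_k - q_{k-1}), where T is the Bellman optimality operator. The sketch below runs the same update without jaxdp so it can be executed standalone. It is an illustration under assumptions: `bellman_optimality_q` is a hypothetical stand-in for `jaxdp.bellman_optimality_operator.q` with an assumed `[a, s]` layout, and the state is carried as a `(q, prev_q)` tuple through `jax.lax.scan` instead of the `State` dataclass. With alpha = 1 and beta = 0 the update reduces to standard value iteration, which serves as a reference here.

```python
import jax
import jax.numpy as jnp


def bellman_optimality_q(transition, reward, q, gamma):
    # Hypothetical stand-in for jaxdp's operator.
    # Assumed layout: transition[a, s, s'], reward[a, s], q[a, s].
    return reward + gamma * (transition @ q.max(axis=0))


def momentum_vi(transition, reward, gamma, alpha, beta, n_iter):
    """Momentum-accelerated VI from zero Q-values, iterated with lax.scan."""

    def step(carry, _):
        q, prev_q = carry
        b_residual = bellman_optimality_q(transition, reward, q, gamma) - q
        next_q = q + alpha * b_residual + beta * (q - prev_q)
        return (next_q, q), None  # new carry: (q_{k+1}, q_k)

    q0 = jnp.zeros_like(reward)
    (q_final, _), _ = jax.lax.scan(step, (q0, q0), None, length=n_iter)
    return q_final


# A small random MDP with 3 actions and 5 states.
k_p, k_r = jax.random.split(jax.random.PRNGKey(0))
transition = jax.random.uniform(k_p, (3, 5, 5))
transition = transition / transition.sum(axis=-1, keepdims=True)
reward = jax.random.uniform(k_r, (3, 5), minval=-1.0, maxval=1.0)

q_momentum = momentum_vi(transition, reward, gamma=0.9, alpha=1.0, beta=0.01, n_iter=300)
q_plain = momentum_vi(transition, reward, gamma=0.9, alpha=1.0, beta=0.0, n_iter=300)  # plain VI
```

Both runs should reach the same fixed point; the momentum term only changes how the iterates approach it.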
You can vectorize the update function to run across:
- Multiple initial values
- Multiple gamma or beta values
- Multiple MDP instances
Example for multiple gamma values using jax.vmap:
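```python
import jax
import jax.numpy as jnp

# State initialization: one batch entry per gamma value.
# init_qvals needs a leading batch axis of size 4 to match gamma.
init_state = State(
    q_val=init_qvals,
    prev_q_val=init_qvals,
    gamma=jnp.array([0.9, 0.95, 0.99, 0.999]),
    beta=0.01,   # shared across the batch
    alpha=0.1,   # shared across the batch
)

# Iterations: map over q_val, prev_q_val and gamma; share beta and alpha
final_state, all_states = jax.lax.scan(
    jax.vmap(  # vmapped update function
        lambda s, ix: (update(s, mdp, ix), s),
        in_axes=(State(0, 0, 0, None, None), None),
        out_axes=(State(0, 0, 0, None, None), 0),
    ),
    init_state,        # initial state
    jnp.arange(100),   # number of iterations
)
```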
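The first case in the list, multiple initial values, follows the same pattern: give the initial Q-values a leading batch axis and map over it. The sketch below is self-contained and uses a hypothetical stand-in operator with plain value iteration rather than jaxdp's functions. Since the Bellman optimality operator is a contraction for gamma < 1, every run should end at the same fixed point regardless of where it starts.

```python
import jax
import jax.numpy as jnp


def bellman_optimality_q(transition, reward, q, gamma):
    # Hypothetical stand-in operator; layout transition[a, s, s'], reward[a, s], q[a, s].
    return reward + gamma * (transition @ q.max(axis=0))


def run_vi(q0, transition, reward, gamma, n_iter=200):
    # Plain value iteration starting from the given Q-table q0.
    def step(q, _):
        return bellman_optimality_q(transition, reward, q, gamma), None

    q_final, _ = jax.lax.scan(step, q0, None, length=n_iter)
    return q_final


n_actions, n_states, n_inits = 3, 5, 4
k_p, k_r, k_q = jax.random.split(jax.random.PRNGKey(1), 3)
transition = jax.random.uniform(k_p, (n_actions, n_states, n_states))
transition = transition / transition.sum(axis=-1, keepdims=True)
reward = jax.random.uniform(k_r, (n_actions, n_states), minval=-1.0, maxval=1.0)

# Four different initial Q-tables stacked along a new leading axis.
init_qvals = 10.0 * jax.random.normal(k_q, (n_inits, n_actions, n_states))

# Map over the initial values only; the MDP and gamma are shared (in_axes=None).
q_finals = jax.vmap(run_vi, in_axes=(0, None, None, None))(init_qvals, transition, reward, 0.9)
print(q_finals.shape)  # (4, 3, 5)
```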
MDPs
In jaxdp, MDPs are PyTrees and therefore compatible with JAX transformations.
```python
import jax
import jax.numpy as jnp
from jaxdp.mdp.garnet import garnet_mdp as make_garnet

n_mdp = 8
key = jax.random.PRNGKey(42)

# List of random MDPs, each generated from a different key
mdps = [
    make_garnet(state_size=300, action_size=10, key=mdp_key,
                branch_size=4, min_reward=-1, max_reward=1)
    for mdp_key in jax.random.split(key, n_mdp)
]

# Stacked MDP: every leaf gains a leading axis of size n_mdp
stacked_mdp = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *mdps)
```
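Any PyTree can be stacked this way. As a self-contained sketch that does not require jaxdp, the hypothetical `ToyMDP` NamedTuple below stands in for jaxdp's MDP (its field names and `[a, s]` layout are assumptions). `jax.tree_util.tree_map` receives the corresponding leaves of all MDPs at once and stacks them along a new leading axis:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class ToyMDP(NamedTuple):
    # Hypothetical stand-in for jaxdp's MDP; a NamedTuple is a PyTree out of the box.
    transition: jnp.ndarray  # [A, S, S]
    reward: jnp.ndarray      # [A, S]


def random_toy_mdp(key, n_states, n_actions):
    # Random row-stochastic transitions and rewards in [-1, 1].
    k_p, k_r = jax.random.split(key)
    p = jax.random.uniform(k_p, (n_actions, n_states, n_states))
    r = jax.random.uniform(k_r, (n_actions, n_states), minval=-1.0, maxval=1.0)
    return ToyMDP(transition=p / p.sum(axis=-1, keepdims=True), reward=r)


keys = jax.random.split(jax.random.PRNGKey(42), 8)
mdps = [random_toy_mdp(k, n_states=30, n_actions=4) for k in keys]

# tree_map zips the leaves of all MDPs: each field is stacked along a new axis 0.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *mdps)
print(stacked.transition.shape)  # (8, 4, 30, 30)
print(stacked.reward.shape)      # (8, 4, 30)
```

The result is still a `ToyMDP`, so code written for a single MDP can receive the stacked version through `jax.vmap`.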
Once stacked, MDPs can be provided to vectorized functions:
```python
>>> mdps[0].transition.shape
(10, 300, 300)
>>> stacked_mdp.transition.shape
(8, 10, 300, 300)
```
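A vectorized function then maps over the leading axis of every leaf when the MDP argument is given `in_axes=0`. The sketch below uses a hypothetical `ToyMDP` stand-in and a plain value-iteration helper `optimal_values` (neither is part of jaxdp) to compute optimal state values for eight MDPs in a single call:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class ToyMDP(NamedTuple):
    # Hypothetical stand-in for jaxdp's MDP PyTree (not its real class).
    transition: jnp.ndarray  # [A, S, S]
    reward: jnp.ndarray      # [A, S]


def optimal_values(mdp: ToyMDP, gamma: float, n_iter: int = 200) -> jnp.ndarray:
    """Plain value iteration on a single MDP; returns max_a Q(a, s)."""
    def step(q, _):
        return mdp.reward + gamma * (mdp.transition @ q.max(axis=0)), None

    q, _ = jax.lax.scan(step, jnp.zeros_like(mdp.reward), None, length=n_iter)
    return q.max(axis=0)


# Eight random MDPs with identical shapes (2 actions, 6 states), stacked leaf-wise.
mdps = []
for k in jax.random.split(jax.random.PRNGKey(0), 8):
    k_p, k_r = jax.random.split(k)
    p = jax.random.uniform(k_p, (2, 6, 6))
    r = jax.random.uniform(k_r, (2, 6), minval=-1.0, maxval=1.0)
    mdps.append(ToyMDP(transition=p / p.sum(axis=-1, keepdims=True), reward=r))
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *mdps)

# in_axes=0 for the MDP maps over the leading axis of every leaf;
# gamma is shared across all MDPs (in_axes=None).
values = jax.jit(jax.vmap(optimal_values, in_axes=(0, None)))(stacked, 0.9)
print(values.shape)  # (8, 6)
```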
Warning: MDP components must have matching shapes for vectorization; variable action or state sizes are not supported.
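Because stacking requires identical leaf shapes, it can help to check a list of MDPs before stacking it. Below is a minimal sketch with a hypothetical helper `can_stack` and a hypothetical `ToyMDP` stand-in (neither comes from jaxdp). It only compares PyTree structures and leaf shapes; it does not attempt padding or any other workaround for differently sized MDPs.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple, Sequence


class ToyMDP(NamedTuple):
    # Hypothetical stand-in for jaxdp's MDP PyTree.
    transition: jnp.ndarray  # [A, S, S]
    reward: jnp.ndarray      # [A, S]


def can_stack(mdps: Sequence[ToyMDP]) -> bool:
    """True if all MDPs share one PyTree structure and identical leaf shapes."""
    ref_def = jax.tree_util.tree_structure(mdps[0])
    ref_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(mdps[0])]
    return all(
        jax.tree_util.tree_structure(m) == ref_def
        and [leaf.shape for leaf in jax.tree_util.tree_leaves(m)] == ref_shapes
        for m in mdps[1:]
    )


def uniform_mdp(n_states: int, n_actions: int) -> ToyMDP:
    # Uniform transitions and zero rewards; only the shapes matter here.
    return ToyMDP(
        transition=jnp.full((n_actions, n_states, n_states), 1.0 / n_states),
        reward=jnp.zeros((n_actions, n_states)),
    )


same_size = [uniform_mdp(5, 2), uniform_mdp(5, 2)]
mixed_size = [uniform_mdp(5, 2), uniform_mdp(6, 2)]
print(can_stack(same_size))   # True
print(can_stack(mixed_size))  # False
```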
Installation
Requires Python 3.11+.
```bash
pip install -r requirements.txt
pip install -e .
```
Owner
- Name: Tolga Ok
- Login: TolgaOk
- Kind: user
- Website: tolgaok.github.io
- Repositories: 2
- Profile: https://github.com/TolgaOk
Citation (CITATION.cff)
abstract: A Dynamic Programming package for discrete MDPs implemented in JAX
version: 0.3.0
authors:
- family-names: Ok
given-names: Tolga
orcid: "https://orcid.org/0000-0002-3669-6121"
cff-version: 1.2.0
date-released: "2024-02-01"
identifiers:
- type: url
value: "https://github.com/TolgaOk/jaxdp"
description: Latest version
keywords:
- "dynamic programming"
- "reinforcement learning"
- "jax"
license: MIT License
message: If you use this software, please cite it using these metadata.
repository-code: "https://github.com/TolgaOk/jaxdp"
title: A Dynamic Programming package for discrete MDPs in JAX
GitHub Events
Total
- Watch event: 2
- Delete event: 1
- Issue comment event: 2
- Push event: 14
- Pull request event: 10
- Create event: 5
Last Year
- Watch event: 2
- Delete event: 1
- Issue comment event: 2
- Push event: 14
- Pull request event: 10
- Create event: 5
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 0
- Total pull requests: 4
- Average time to close issues: N/A
- Average time to close pull requests: less than a minute
- Total issue authors: 0
- Total pull request authors: 1
- Average comments per issue: 0
- Average comments per pull request: 0.0
- Merged pull requests: 4
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 0
- Pull requests: 4
- Average time to close issues: N/A
- Average time to close pull requests: less than a minute
- Issue authors: 0
- Pull request authors: 1
- Average comments per issue: 0
- Average comments per pull request: 0.0
- Merged pull requests: 4
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Pull Request Authors
- TolgaOk (6)
Dependencies
- distrax ==0.1.3
- jax ==0.4.13
- jaxtyping ==0.2.19