jaxdp

A Dynamic Programming package for discrete MDPs implemented in JAX

https://github.com/tolgaok/jaxdp

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.9%) to scientific vocabulary

Keywords

dynamic-programming jax markov-decision-processes reinforcement-learning
Last synced: 6 months ago

Repository

A Dynamic Programming package for discrete MDPs implemented in JAX

Basic Info
  • Host: GitHub
  • Owner: TolgaOk
  • License: mit
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 473 KB
Statistics
  • Stars: 7
  • Watchers: 1
  • Forks: 1
  • Open Issues: 0
  • Releases: 1
Topics
dynamic-programming jax markov-decision-processes reinforcement-learning
Created over 2 years ago · Last pushed 7 months ago
Metadata Files
Readme Contributing Citation

README.md

jaxdp

jaxdp is a Python package providing functional implementations of dynamic programming (DP) algorithms for finite state-action Markov decision processes (MDPs) within the JAX ecosystem. By leveraging JAX transformations, DP algorithms can be accelerated (including on GPU) through vectorized execution across multiple MDP instances, initial values, and hyperparameters.

Vectorization

jaxdp functions are fully compatible with JAX transformations: they are stateless, and any memory they need (such as value estimates) is passed to them explicitly as arguments.
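As a minimal sketch of this stateless style (using plain `jax.numpy` rather than jaxdp's actual API), a synchronous Bellman backup can be written as a pure function whose entire state is an explicit argument, then jit-compiled:

```python
import jax
import jax.numpy as jnp

# Minimal sketch (not jaxdp's API): a stateless Bellman backup where all
# memory (the Q-table) is an explicit argument, so the function is pure
# and can be freely jit-compiled or vmapped.
def bellman_backup(q, transition, reward, gamma):
    # q: [A, S] action values, transition: [A, S, S], reward: [A, S]
    v = q.max(axis=0)                       # greedy state values, [S]
    return reward + gamma * transition @ v  # synchronous backup, [A, S]

key = jax.random.PRNGKey(0)
n_action, n_state = 3, 5
# Random row-stochastic transition kernel, reward of 1 in the last state
transition = jax.nn.softmax(
    jax.random.normal(key, (n_action, n_state, n_state)), axis=-1)
reward = jnp.zeros((n_action, n_state)).at[:, -1].set(1.0)

step = jax.jit(bellman_backup)
q = jnp.zeros((n_action, n_state))
for _ in range(100):
    q = step(q, transition, reward, 0.9)
print(q.shape)  # (3, 5)
```

Because `bellman_backup` reads nothing but its arguments, repeated calls with the same inputs always produce the same output, which is exactly what `jax.jit` and `jax.vmap` require.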

Algorithm Example

The examples directory contains implementations and benchmarks of planning algorithms using jaxdp. Below is a code snippet for Momentum-accelerated Value Iteration:

```python
""" Momentum accelerated Value Iteration. """


@struct.dataclass
class State:
    q_val: jnp.ndarray
    prev_q_val: jnp.ndarray
    gamma: jnp.ndarray
    beta: jnp.ndarray
    alpha: jnp.ndarray


def update(s: State, mdp: MDP, step: int) -> State:
    diff = s.q_val - s.prev_q_val
    b_residual = jaxdp.bellman_optimality_operator.q(mdp, s.q_val, s.gamma) - s.q_val
    next_q = s.q_val + s.alpha * b_residual + s.beta * diff

    return s.replace(q_val=next_q, prev_q_val=s.q_val)
```
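Written out, the update above is a momentum step on the Bellman residual, where $\mathcal{T}^*$ denotes the Bellman optimality operator:

```math
Q_{k+1} = Q_k + \alpha \left( \mathcal{T}^* Q_k - Q_k \right) + \beta \left( Q_k - Q_{k-1} \right)
```

With $\beta = 0$ this reduces to ordinary (damped) Value Iteration; the $\beta$ term reuses the previous iterate to accelerate convergence.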

You can vectorize the update function to run across:

  • Multiple initial values
  • Multiple gamma or beta values
  • Multiple MDP instances

Example for multiple gamma values using jax.vmap:

```python
# State Initialization
init_state = State(
    q_val=init_q_vals,
    prev_q_val=init_q_vals,
    gamma=jnp.array([0.9, 0.95, 0.99, 0.999]),
    beta=0.01,
    alpha=0.1
)

# Iterations
final_state, all_states = jax.lax.scan(
    jax.vmap(  # vmapped update function
        lambda s, ix: (update(s, mdp, ix), s),
        in_axes=(State(0, 0, 0, None, None), None),
        out_axes=(State(0, 0, 0, None, None), 0)
    ),
    init_state,      # initial state
    jnp.arange(100)  # Number of iterations
)
```
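The `State(0, 0, 0, None, None)` pytree passed to `in_axes`/`out_axes` tells `jax.vmap` which leaves carry a batch axis (`0`) and which are shared across the batch (`None`). A toy, self-contained illustration of the same pattern, with a hypothetical two-field `State` rather than jaxdp's:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Hypothetical two-field State (illustration only, not jaxdp code).
class State(NamedTuple):
    q: jnp.ndarray
    gamma: jnp.ndarray

def decay(s: State) -> jnp.ndarray:
    return s.gamma * s.q

# Both leaves batched along axis 0: one gamma per row of q.
batched = State(q=jnp.ones((4, 3)),
                gamma=jnp.array([0.9, 0.95, 0.99, 0.999]))
out = jax.vmap(decay, in_axes=(State(0, 0),))(batched)

# gamma shared across the batch: axis None for that leaf.
out_shared = jax.vmap(decay, in_axes=(State(0, None),))(
    State(q=jnp.ones((4, 3)), gamma=jnp.asarray(0.9)))
print(out.shape)         # (4, 3)
print(out_shared.shape)  # (4, 3)
```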

MDPs

In jaxdp, MDPs are PyTrees and therefore compatible with JAX transformations.

```python
import jax
import jax.numpy as jnp
from jaxdp.mdp.garnet import garnet_mdp as make_garnet

n_mdp = 8
key = jax.random.PRNGKey(42)

# List of random MDPs with different seeds
mdps = [make_garnet(state_size=300, action_size=10, key=key,
                    branch_size=4, min_reward=-1, max_reward=1)
        for key in jax.random.split(key, n_mdp)]

# Stacked MDP
stacked_mdp = jax.tree_map(lambda *mdps: jnp.stack(mdps), *mdps)
```

Once stacked, MDPs can be provided to vectorized functions:

```python
>>> mdps[0].transition.shape
(10, 300, 300)
>>> stacked_mdp.transition.shape
(8, 10, 300, 300)
```

> [!WARNING]
> MDP components must have matching shapes for vectorization. Variable action or state sizes are not supported.
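The stack-then-vmap pattern can be sketched end to end with a stand-in `MDP` NamedTuple (jaxdp's actual MDP class has different fields; this is only an illustration of stacking a pytree and mapping over its leading axis):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Stand-in MDP pytree for illustration (not jaxdp's actual MDP class).
class MDP(NamedTuple):
    transition: jnp.ndarray  # [A, S, S]
    reward: jnp.ndarray      # [A, S]

def make_random_mdp(key, n_state=4, n_action=2):
    k1, k2 = jax.random.split(key)
    p = jax.nn.softmax(
        jax.random.normal(k1, (n_action, n_state, n_state)), axis=-1)
    r = jax.random.uniform(k2, (n_action, n_state))
    return MDP(p, r)

mdps = [make_random_mdp(k)
        for k in jax.random.split(jax.random.PRNGKey(0), 8)]
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *mdps)

# Each leaf gains a leading MDP axis, so vmap runs one computation per MDP.
values = jax.vmap(lambda m: m.reward.max(axis=0))(stacked)
print(stacked.transition.shape)  # (8, 2, 4, 4)
print(values.shape)              # (8, 4)
```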

Installation

Requires Python 3.11+

```bash
pip install -r requirements.txt
pip install -e .
```

Owner

  • Name: Tolga Ok
  • Login: TolgaOk
  • Kind: user

Citation (CITATION.cff)

version: 0.3.0
authors:
  - family-names: Ok
    given-names: Tolga
    orcid: "https://orcid.org/0000-0002-3669-6121"
cff-version: 1.2.0
date-released: "2024-02-01"
keywords:
  - "dynamic programming"
  - "reinforcement learning"
  - "jax"
license: MIT License
message: If you use this software, please cite it using these metadata.
repository-code: "https://github.com/TolgaOk/jaxdp"
title: A Dynamic Programming package for discrete MDPs in JAX

GitHub Events

Total
  • Watch event: 2
  • Delete event: 1
  • Issue comment event: 2
  • Push event: 14
  • Pull request event: 10
  • Create event: 5
Last Year
  • Watch event: 2
  • Delete event: 1
  • Issue comment event: 2
  • Push event: 14
  • Pull request event: 10
  • Create event: 5

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 0
  • Total pull requests: 4
  • Average time to close issues: N/A
  • Average time to close pull requests: less than a minute
  • Total issue authors: 0
  • Total pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 0.0
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 4
  • Average time to close issues: N/A
  • Average time to close pull requests: less than a minute
  • Issue authors: 0
  • Pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 0.0
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
Pull Request Authors
  • TolgaOk (6)
Top Labels
Issue Labels
Pull Request Labels

Dependencies

pyproject.toml pypi
requirements.txt pypi
  • distrax ==0.1.3
  • jax ==0.4.13
  • jaxtyping ==0.2.19
setup.py pypi