BordAX: A High-Performance JAX Framework for Programmatic Reinforcement Learning

BordAX: A High-Performance JAX Framework for Programmatic Reinforcement Learning - Published in JOSS (2026)

https://github.com/synthesislab/bordax

Science Score: 87.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
    Found 1 DOI reference(s) in JOSS metadata
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software

Keywords

jax reinforcement-learning rl

Repository

A High-Performance JAX Framework for Programmatic Reinforcement Learning

Basic Info
  • Host: GitHub
  • Owner: SynthesisLab
  • License: mit
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 2.68 MB
Statistics
  • Stars: 3
  • Watchers: 2
  • Forks: 1
  • Open Issues: 1
  • Releases: 1
Created over 1 year ago · Last pushed 17 days ago

README.md

BordAX

**A High-Performance JAX Framework for Programmatic Reinforcement Learning** [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) [![JAX](https://img.shields.io/badge/JAX-0.8.0+-orange.svg)](https://github.com/google/jax) [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) [![CI](https://github.com/SynthesisLab/bordax/actions/workflows/ci.yml/badge.svg)](https://github.com/SynthesisLab/bordax/actions/workflows/ci.yml) [![Coverage](https://codecov.io/gh/SynthesisLab/bordax/branch/main/graph/badge.svg)](https://codecov.io/gh/SynthesisLab/bordax) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Docs](https://img.shields.io/badge/docs-mkdocs-blue.svg)](https://synthesislab.github.io/bordax)

Overview

BordAX is a research-focused framework for Programmatic Reinforcement Learning (PRL) that combines the speed of JAX with support for structured, interpretable policies including neural networks, boolean functions, and decision trees.

Key Features

  • High Performance — Fully JIT-compiled training pipelines leveraging JAX's XLA compilation
  • Modular Architecture — Clean separation between agents, algorithms, environments, and training
  • Multiple Policy Types — MLPs, boolean functions (HyperBool), and decision trees (DTSemNet)
  • Flexible Algorithms — Built-in PPO (on-policy) and DQN (off-policy) with easy extensibility
  • Environment Agnostic — Supports both Gymnax (JIT-compiled) and Gymnasium environments
  • Production Ready — Checkpointing, logging, WandB integration, and comprehensive tests

Performance

BordAX achieves high performance through:

  • Full JIT compilation for jittable environments (Gymnax)
  • Vectorized environments via jax.vmap
  • Efficient loops using jax.lax.scan
  • Pure functional design compatible with XLA optimization
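As a rough illustration of how the techniques listed above fit together, the sketch below batches a toy environment with `jax.vmap`, unrolls it with `jax.lax.scan`, and compiles the whole rollout with `jax.jit`. The environment, the policy, and the constants `NUM_ENVS` and `NUM_STEPS` are made up for this example; they are not part of BordAX or Gymnax.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 4   # hypothetical number of parallel environments
NUM_STEPS = 8  # hypothetical rollout length


def env_step(state, action):
    """Toy 1-D environment: the action shifts the state; reward is -|state|."""
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward


def policy(params, state):
    """Toy linear policy that pushes the state back toward zero."""
    return -params * state


@jax.jit
def rollout(params, init_states):
    # vmap turns the single-environment step into a batched step.
    batched_step = jax.vmap(env_step)

    def body(states, _):
        actions = policy(params, states)
        next_states, rewards = batched_step(states, actions)
        return next_states, rewards  # (carry, per-step output)

    # scan replaces a Python for-loop, so the rollout is traced once.
    return jax.lax.scan(body, init_states, None, length=NUM_STEPS)


init_states = jnp.linspace(-1.0, 1.0, NUM_ENVS)
final_states, rewards = rollout(jnp.float32(0.5), init_states)
print(rewards.shape)  # (NUM_STEPS, NUM_ENVS)
```

Because every function here is pure, the same `rollout` can be differentiated or vmapped again without modification.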

JIT Compilation Strategy

| Environment | Algorithm | JIT Scope |
|-------------|-----------|-----------|
| Gymnax | On-policy | Entire train_step |
| Gymnax | Off-policy | update only |
| Gymnasium | Any | update only |
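The "update only" rows can be pictured with the following sketch, in which data collection runs in ordinary Python and only the gradient step is compiled. `PythonEnv`, `update`, and `LEARNING_RATE` are hypothetical stand-ins written for this example, not BordAX's collector or updater.

```python
import random

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary value for the toy problem


class PythonEnv:
    """Stand-in for a Gymnasium-style environment: plain Python, not traceable."""

    def reset(self):
        self.state = random.uniform(-1.0, 1.0)
        return self.state

    def step(self, action):
        self.state += action
        return self.state, -abs(self.state)


@jax.jit
def update(params, states, targets):
    """The only compiled piece: one gradient step on a squared error."""
    def loss_fn(p):
        return jnp.mean((p * states - targets) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(params)
    return params - LEARNING_RATE * grad, loss


random.seed(0)
env = PythonEnv()
params = jnp.float32(0.0)
for _ in range(20):
    # Collection stays in the Python interpreter because env.step cannot be traced.
    state = env.reset()
    states, targets = [], []
    for _ in range(16):
        action = -0.5 * state  # toy behaviour the linear model should imitate
        states.append(state)
        targets.append(action)
        state, _reward = env.step(action)
    # The collected batch is handed to the jitted update in one call.
    params, loss = update(params, jnp.array(states), jnp.array(targets))
```

The batch shapes stay fixed across iterations, so `update` is compiled once and reused.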

Benchmark: BordAX vs Stable-Baselines3

PPO on CartPole-v1 with identical hyperparameters (5 seeds, 51k timesteps):

| Framework | Training Time | Throughput | Speedup |
|-----------|---------------|------------|---------|
| BordAX + Gymnax (Full JIT) | 3.73s ± 0.06s | 13,709 steps/s | 2.8x |
| BordAX + Gymnasium | 5.41s ± 0.11s | 9,468 steps/s | 1.9x |
| Stable-Baselines3 | 10.40s ± 0.05s | 4,923 steps/s | 1.0x |

[Figure: Benchmark Comparison]

With Gymnax (fully JIT-compiled), BordAX is 2.8x faster than Stable-Baselines3. Even with Gymnasium (Python environment), BordAX is 1.9x faster.
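The speedup figures follow directly from the mean training times in the benchmark table. A short check, using only those reported means:

```python
# Mean training times in seconds, copied from the benchmark table.
times = {
    "BordAX + Gymnax (Full JIT)": 3.73,
    "BordAX + Gymnasium": 5.41,
    "Stable-Baselines3": 10.40,
}

baseline = times["Stable-Baselines3"]
# Speedup is the baseline time divided by the framework's time.
speedups = {name: round(baseline / t, 1) for name, t in times.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup}x")
```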

Measured on Apple M3 Pro (2023) with JAX 0.9.0, Stable-Baselines3 2.8.0, PyTorch 2.11.0.

Run the benchmark yourself:

```bash
uv sync --extra benchmark
uv run python compare_sb3.py
```


Installation

```bash
# Clone the repository
git clone https://github.com/SynthesisLab/bordax.git
cd bordax

# Install with uv (recommended)
uv sync

# Or with pip
pip install -e .

# With optional dependencies (WandB, visualization)
pip install -e ".[all]"
```

Verify Installation

```bash
python -c "from bordax.training.trainer import Trainer; print('BordAX installed successfully')"
```


Quick Start

Train PPO on CartPole

```bash
python train_ppo.py
```

  • Solves CartPole-v1 (reward = 500) in ~400k steps
  • Training time: ~18 seconds on CPU
  • Throughput: ~23,000 steps/s

[Figure: PPO Training]

Train DQN on CartPole

```bash
python train_dqn.py
```

  • Solves CartPole-v1 in ~50k steps
  • Training time: ~36 seconds on CPU
  • Includes 1,000 step warmup phase

[Figure: DQN Training]

Custom Training Script

```python
import jax
from bordax.training.trainer import Trainer, TrainerConfig
from bordax.algorithms.utils import make_algo
from bordax.environments.utils import make_env
from bordax.agents.utils import make_agent

# Setup environments
env = make_env("gymnax/CartPole-v1", {"init_config": {}, "reset_config": {}}, num_envs=4)
eval_env = make_env("gymnax/CartPole-v1", {"init_config": {}, "reset_config": {}}, num_envs=1)

# Create agent with MLP policy and value networks
agent = make_agent("mlp/mlp", env, {
    "policy_layers": [64, 64],
    "value_layers": [64, 64],
})

# Configure PPO algorithm
algorithm = make_algo("ppo", {
    "lr": 3e-4,
    "rollout_length": 2048,
    "gamma": 0.99,
    "lambda": 0.95,
    "clip_schedule": lambda _: 0.2,
    "vf_schedule": lambda _: 0.5,
    "ent_schedule": lambda _: 0.01,
    "num_minibatches": 16,
    "num_sgd_steps": 10,
})

# Setup trainer
config = TrainerConfig(
    num_checkpoints=100,
    epochs_per_checkpoint=1,
    evaluation_episodes=32,
    debug=True,
)

trainer = Trainer(env, eval_env, agent, algorithm, config)

# Train
key = jax.random.PRNGKey(0)
init_key, train_key = jax.random.split(key)
trainer.init(init_key)
eval_data = trainer.run(train_key)
```


Architecture

BordAX uses a modular pipeline architecture that cleanly separates concerns:

Trainer
└─> Algorithm (Collector + BatchBuilder + Updater)
    ├─> Collector: Generates environment transitions
    ├─> BatchBuilder: Constructs training batches
    └─> Updater: Computes gradients and updates parameters

Core Components

| Component | Purpose | Examples |
|-----------|---------|----------|
| Agent | Defines policy and value networks | MLPPolicyValue, BooleanPolicyValue, DTPolicy, DQNAgent |
| Algorithm | Bundles training pipeline components | ppo_algo(), dqn_algo() |
| Collector | Generates transitions via env interaction | OnPolicyCollector, EpsGreedyCollector |
| BatchBuilder | Transforms data into training batches | FullBufferBatch, MiniBatch, UniformReplayBatch |
| Updater | Updates parameters using gradients | SGDUpdate, DQNUpdater |
| Trainer | Orchestrates full training loop | Trainer |
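To make the division of labour concrete, here is a minimal sketch of a Collector, BatchBuilder, and Updater chained into one training step. The `Algorithm` dataclass and the three toy functions are hypothetical and written only for this illustration; BordAX's real classes have their own interfaces.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

Transition = Tuple[int, float]  # (step index, error signal); deliberately tiny


@dataclass
class Algorithm:
    """Hypothetical bundle of the three stages; not BordAX's actual class."""
    collect: Callable[[float, int], List[Transition]]
    build_batch: Callable[[List[Transition]], List[Transition]]
    update: Callable[[float, List[Transition]], float]

    def train_step(self, params: float, num_steps: int) -> float:
        transitions = self.collect(params, num_steps)  # Collector
        batch = self.build_batch(transitions)          # BatchBuilder
        return self.update(params, batch)              # Updater


def collect(params: float, num_steps: int) -> List[Transition]:
    # Toy "environment": each step reports how far params is from the optimum 1.0.
    return [(t, 1.0 - params) for t in range(num_steps)]


def full_buffer(transitions: List[Transition]) -> List[Transition]:
    # Pass every collected transition through unchanged.
    return transitions


def update(params: float, batch: List[Transition]) -> float:
    # Move params by half of the mean error observed in the batch.
    mean_error = sum(err for _, err in batch) / len(batch)
    return params + 0.5 * mean_error


algo = Algorithm(collect, full_buffer, update)
params = 0.0
for _ in range(10):
    params = algo.train_step(params, num_steps=4)
print(params)
```

Swapping `full_buffer` for a function that samples a subset would change the batching strategy without touching collection or the update rule, which is the point of keeping the stages separate.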

Supported Algorithms

| Algorithm | Type | Collector | Batch Strategy |
|-----------|------|-----------|----------------|
| PPO | On-policy | OnPolicyCollector | FullBufferBatch, MiniBatch |
| DQN | Off-policy | EpsGreedyCollector | UniformReplayBatch |
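For readers unfamiliar with PPO, the sketch below shows the standard clipped surrogate objective, with the clip range defaulting to the 0.2 used in the custom training script. It is a generic textbook formulation, not a copy of BordAX's `PPOLoss`; the function name and arguments are assumptions made for this example.

```python
import jax.numpy as jnp


def ppo_clip_loss(new_log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Clipped surrogate objective, negated so that it can be minimized."""
    ratio = jnp.exp(new_log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    # Taking the minimum removes the incentive to move the ratio outside the clip range.
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# A probability ratio of 1.5 with a positive advantage is capped at 1 + clip_eps.
old_lp = jnp.array([0.0])
new_lp = jnp.log(jnp.array([1.5]))
loss_pos = ppo_clip_loss(new_lp, old_lp, jnp.array([1.0]))
# With a negative advantage the unclipped term is smaller, so it is kept.
loss_neg = ppo_clip_loss(new_lp, old_lp, jnp.array([-1.0]))
```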


Policy Representations

Standard Neural Networks

MLP Policy-Value (mlp/mlp):

```python
agent = make_agent("mlp/mlp", env, {
    "policy_layers": [128, 128, 64],
    "value_layers": [128, 128, 64],
})
```

Programmatic Policies

HyperBool (mlp/bool), Boolean function-based policies:

```python
agent = make_agent("mlp/bool", env, {
    "n": 4,  # Number of boolean variables
    "value_layers": [128, 64, 32],
})
```

DTSemNet (mlp/dt), decision tree policies:

```python
agent = make_agent("mlp/dt", env, {
    "tree_depth": 4,
    "value_layers": [64, 64],
})
```
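To give a feel for what a tree-shaped policy computes, the sketch below evaluates a plain oblique decision tree: each internal node applies a linear test to the observation, and each leaf names an action. This is a generic traversal and not the DTSemNet encoding; the function `tree_policy` and its heap-ordered parameter layout are assumptions made for this example. A complete tree of depth d has 2^d - 1 internal nodes and 2^d leaves, so a depth of 4 corresponds to 15 tests and 16 leaves.

```python
from typing import List, Sequence


def tree_policy(obs: Sequence[float], weights: List[Sequence[float]],
                biases: List[float], leaf_actions: List[int], depth: int) -> int:
    """Walk a complete binary tree stored in heap order and return the leaf's action."""
    node = 0
    for _ in range(depth):
        score = sum(w * x for w, x in zip(weights[node], obs)) + biases[node]
        # Go right when the linear test is positive, left otherwise.
        node = 2 * node + (2 if score > 0 else 1)
    num_internal = 2 ** depth - 1
    return leaf_actions[node - num_internal]


# Depth-2 example: 3 internal nodes, 4 leaves, 2-dimensional observations.
weights = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
biases = [0.0, 0.0, 0.0]
leaf_actions = [0, 1, 2, 3]

action_a = tree_policy([1.0, -1.0], weights, biases, leaf_actions, depth=2)
action_b = tree_policy([-1.0, 1.0], weights, biases, leaf_actions, depth=2)
```

Each decision can be read back as a short list of inequalities on the observation, which is what makes this family of policies interpretable.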

DQN Agent

Q-Network (dqn/mlp):

```python
agent = make_agent("dqn/mlp", env, {
    "q_layers": [64, 64],
})
```
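A Q-network is usually paired with epsilon-greedy exploration during data collection. The following is a minimal JAX sketch of that rule; `eps_greedy` is a hypothetical helper written for this example and may differ from BordAX's `EpsGreedyCollector`.

```python
import jax
import jax.numpy as jnp


def eps_greedy(key, q_values, epsilon):
    """Return argmax_a Q(s, a) with probability 1 - epsilon, else a uniform random action."""
    explore_key, action_key = jax.random.split(key)
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[-1])
    greedy_action = jnp.argmax(q_values, axis=-1)
    explore = jax.random.uniform(explore_key) < epsilon
    # jnp.where keeps the function traceable, so it can be jitted or vmapped.
    return jnp.where(explore, random_action, greedy_action)


key = jax.random.PRNGKey(0)
q_values = jnp.array([0.1, 0.9])
greedy = eps_greedy(key, q_values, epsilon=0.0)    # never explores
explored = eps_greedy(key, q_values, epsilon=1.0)  # always explores
```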


Project Structure

bordax/
├── bordax/
│   ├── agents/                 # Agent implementations
│   │   ├── base.py             # MLPPolicyValue, BooleanPolicyValue, DTPolicy, DQNAgent
│   │   ├── components.py       # Neural modules (MLP, DTSemNet, BooleanFunction)
│   │   └── utils.py            # make_agent() factory
│   ├── algorithms/             # RL algorithms
│   │   ├── base.py             # Algorithm class, ppo_algo(), dqn_algo()
│   │   ├── losses.py           # PPOLoss, DQNLoss
│   │   └── utils.py            # make_algo() factory
│   ├── data/                   # Data collection and batching
│   │   ├── collectors.py       # OnPolicyCollector, EpsGreedyCollector
│   │   ├── batchbuilders.py    # Batch transformations
│   │   └── buffer.py           # ReplayBuffer
│   ├── environments/           # Environment adapters
│   │   └── utils.py            # EnvAdapter, make_env()
│   ├── training/               # Training infrastructure
│   │   ├── trainer.py          # Main Trainer class
│   │   ├── evaluation.py       # Evaluator
│   │   ├── logging.py          # Logger with WandB support
│   │   ├── checkpointing.py    # Model checkpointing (Orbax)
│   │   └── updaters.py         # SGDUpdate, DQNUpdater
│   └── types.py                # Core type definitions
├── tests/                      # Test suite (48 tests, 77% coverage)
│   ├── unit/                   # Fast component tests
│   ├── integration/            # Pipeline tests
│   └── slow/                   # Learning verification tests
├── train_ppo.py                # PPO training example
├── train_dqn.py                # DQN training example
└── compare_sb3.py              # Stable-Baselines3 benchmark


Testing

BordAX has a comprehensive test suite with 48 tests achieving 77% code coverage.

```bash
# Install test dependencies
uv sync --extra dev

# Run all tests (excluding slow)
uv run python -m pytest tests/ -m "not slow" -v

# Run slow learning tests
uv run python -m pytest tests/ -m slow -v

# Run with coverage
uv run python -m pytest tests/ --cov=bordax --cov-report=term-missing
```

Test Categories

| Category | Tests | Purpose |
|----------|-------|---------|
| Unit | 44 | Fast component tests |
| Integration | 2 | Full pipeline verification |
| Slow | 2 | Learning verification |
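The `-m slow` and `-m "not slow"` selections rely on pytest markers. A minimal sketch of how a test can be tagged is shown below; the test names and the 475.0 threshold are invented for illustration and are not taken from the BordAX test suite.

```python
import pytest


@pytest.mark.slow
def test_learning_reaches_threshold():
    """Placeholder for a long-running learning check (hypothetical test)."""
    final_return = 500.0  # stand-in for the result of a real training run
    assert final_return >= 475.0  # arbitrary threshold chosen for this sketch


def test_shapes_are_consistent():
    """Unmarked tests are still selected by `-m "not slow"`."""
    assert len([0.0] * 4) == 4
```

Custom markers such as `slow` are normally registered in the pytest configuration; otherwise pytest emits an unknown-mark warning.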


Dependencies

| Package | Version | Purpose |
|---------|---------|---------|
| JAX | >=0.8.0 | Core computation |
| Flax | >=0.12.0 | Neural networks |
| Optax | >=0.2.6 | Optimizers |
| Gymnax | >=0.0.9 | JAX environments |
| Gymnasium | >=1.2.0 | Standard environments |
| Distrax | >=0.1.7 | Distributions |
| Orbax | >=0.11.32 | Checkpointing |

Optional: WandB (experiment tracking), Matplotlib/Seaborn (visualization)


Restoring from Checkpoints

```bash
# Restore last checkpoint and continue training
python train_ppo.py --restore-last
```
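How `train_ppo.py` implements this flag is not described here. As one possible shape, the sketch below parses `--restore-last` with `argparse` and picks the highest-numbered checkpoint directory. The `--checkpoint-dir` option, the `latest_checkpoint` helper, and the assumption that checkpoints live in integer-named sub-directories are all hypothetical.

```python
import argparse
from pathlib import Path
from typing import Optional


def latest_checkpoint(root: Path) -> Optional[Path]:
    """Return the sub-directory with the largest integer name, or None if there is none."""
    if not root.is_dir():
        return None
    steps = [p for p in root.iterdir() if p.is_dir() and p.name.isdigit()]
    return max(steps, key=lambda p: int(p.name), default=None)


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore-last", action="store_true",
                        help="resume from the most recent checkpoint")
    # Hypothetical option; the real script may locate checkpoints differently.
    parser.add_argument("--checkpoint-dir", type=Path, default=Path("checkpoints"))
    return parser.parse_args(argv)


args = parse_args(["--restore-last"])
if args.restore_last:
    checkpoint = latest_checkpoint(args.checkpoint_dir)
    print("resuming from", checkpoint if checkpoint else "scratch (no checkpoint found)")
```

Comparing directory names as integers rather than strings matters here: a lexicographic sort would rank `9` above `10`.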


License

BordAX is released under the MIT License.

Support

For questions, bug reports, and feature requests, please open a GitHub Issue. See CONTRIBUTING.md for guidelines.


Acknowledgments

BordAX builds on excellent work from the JAX ecosystem:

  • JAX — High-performance numerical computing
  • Flax — Neural network library
  • Gymnax — JAX-compatible RL environments
  • Optax — Gradient processing and optimization
  • Distrax — Probability distributions

Built with JAX for speed and interpretability

Owner

  • Name: SynthesisLab
  • Login: SynthesisLab
  • Kind: organization
  • Location: France

Synthesis research team in LaBRI, Bordeaux, working on Program Synthesis, Reinforcement Learning, Specification Mining...

JOSS Publication

BordAX: A High-Performance JAX Framework for Programmatic Reinforcement Learning
Published
June 25, 2026
Volume 11, Issue 122, Page 10470
Authors
Roman Kniazev
CNRS, LaBRI, University of Bordeaux, France
Nathanaël Fijalkow
CNRS, LaBRI, University of Bordeaux, France
Editor
Wentao Ye
Tags
JAX, reinforcement learning, programmatic policies, decision trees, interpretable machine learning

GitHub Events

Total
  • Delete event: 3
  • Pull request event: 5
  • Fork event: 1
  • Watch event: 1
  • Issue comment event: 1
  • Push event: 12
  • Create event: 1
Last Year
  • Delete event: 3
  • Pull request event: 5
  • Fork event: 1
  • Watch event: 1
  • Issue comment event: 1
  • Push event: 12
  • Create event: 1

Issues and Pull Requests


All Time
  • Total issues: 0
  • Total pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Total issue authors: 0
  • Total pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0