sbx-rl

SBX: Stable Baselines Jax (SB3 + Jax) RL algorithms

https://github.com/araffin/sbx

Science Score: 64.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
    1 of 5 committers (20.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.2%) to scientific vocabulary
Last synced: 6 months ago

Repository

SBX: Stable Baselines Jax (SB3 + Jax) RL algorithms

Basic Info
  • Host: GitHub
  • Owner: araffin
  • License: mit
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 277 KB
Statistics
  • Stars: 497
  • Watchers: 17
  • Forks: 49
  • Open Issues: 19
  • Releases: 15
Created over 3 years ago · Last pushed 6 months ago
Metadata Files
Readme Contributing License Code of conduct Citation Notice

README.md


Stable Baselines Jax (SB3 + Jax = SBX)

Proof of concept version of Stable-Baselines3 in Jax.

Implemented algorithms:
- Soft Actor-Critic (SAC) and SAC-N
- Truncated Quantile Critics (TQC)
- Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ)
- Proximal Policy Optimization (PPO)
- Deep Q Network (DQN)
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)
- Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)

Note: parameter resets for off-policy algorithms can be activated by passing a list of timesteps to the model constructor (e.g., param_resets=[int(1e5), int(5e5)] to reset parameters and optimizers after 100,000 and 500,000 timesteps).
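The reset schedule above can be illustrated with a small standalone sketch; the helper below is hypothetical (not part of the SBX API) and only shows when a reset would trigger given such a list of timesteps:

```python
def should_reset(param_resets, previous_timestep, current_timestep):
    """Return True if a scheduled reset timestep falls in (previous, current]."""
    return any(previous_timestep < t <= current_timestep for t in param_resets)

# Same schedule as in the note: reset after 100,000 and 500,000 timesteps
schedule = [int(1e5), int(5e5)]
print(should_reset(schedule, 99_000, 101_000))   # True: crosses the 1e5 mark
print(should_reset(schedule, 200_000, 300_000))  # False: no reset scheduled here
```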

Install using pip

For the latest master version:

```
pip install git+https://github.com/araffin/sbx
```

or:

```
pip install sbx-rl
```

Example

```python
import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
    vec_env.render()
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)

vec_env.close()
```

Using SBX with the RL Zoo

Since SBX shares the SB3 API, it is compatible with the RL Zoo; you only need to override the algorithm mapping:

```python
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    train()
```

Then you can run this script as you would with the RL Zoo:

```
python train.py --algo sac --env HalfCheetah-v4 -params train_freq:4 gradient_steps:4 -P
```

The same goes for the enjoy script:

```python
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

if __name__ == "__main__":
    enjoy()
```

Note about DroQ

DroQ is a special configuration of SAC.

To have the algorithm with the hyperparameters from the paper, you should use (using RL Zoo config format):

```yaml
HalfCheetah-v4:
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_starts: 10000
  gradient_steps: 20
  policy_delay: 20
  policy_kwargs: "dict(dropout_rate=0.01, layer_norm=True)"
```

and then using the RL Zoo script defined above:

```
python train.py --algo sac --env HalfCheetah-v4 -c droq.yml -P
```

We recommend playing with the policy_delay and gradient_steps parameters for better speed/efficiency. Having a higher learning rate for the q-value function is also helpful: qf_learning_rate: !!float 1e-3.

Note: when using the DroQ configuration with CrossQ, you should set layer_norm=False as there is already batch normalization.
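Following the note above, a sketch of what the CrossQ variant of that config could look like; this is an assumed fragment built from the keys shown in the DroQ config, not a file from the repository:

```yaml
# Hypothetical DroQ-style config for CrossQ: layer_norm is set to False,
# since CrossQ already applies batch normalization.
HalfCheetah-v4:
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_starts: 10000
  gradient_steps: 20
  policy_delay: 20
  policy_kwargs: "dict(dropout_rate=0.01, layer_norm=False)"
```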

Note about SimBa

SimBa is a special network architecture for off-policy algorithms (SAC, TQC, ...).

Some recommended hyperparameters (tested on MuJoCo and PyBullet environments):

```python
import optax

default_hyperparams = dict(
    n_envs=1,
    n_timesteps=int(1e6),
    policy="SimbaPolicy",
    learning_rate=3e-4,
    # qf_learning_rate=1e-3,
    policy_kwargs={
        "optimizer_class": optax.adamw,
        # "optimizer_kwargs": {"weight_decay": 0.01},
        # Note: here [128] represents a residual block, not just a single layer
        "net_arch": {"pi": [128], "qf": [256, 256]},
        "n_critics": 2,
    },
    learning_starts=10000,
    # Important: input normalization using VecNormalize
    normalize={"norm_obs": True, "norm_reward": False},
)

hyperparams = {}

# You can also loop over gym.registry
for env_id in [
    "HalfCheetah-v4",
    "HalfCheetahBulletEnv-v0",
    "Ant-v4",
]:
    hyperparams[env_id] = default_hyperparams
```

and then using the RL Zoo script defined above:

```
python train.py --algo tqc --env HalfCheetah-v4 -c simba.py -P
```

Benchmark

A partial benchmark can be found on OpenRL Benchmark where you can also find several reports.

Citing the Project

To cite this repository in publications:

```bibtex
@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}
```

Maintainers

Stable-Baselines3 is currently maintained by Ashley Hill (aka @hill-a), Antonin Raffin (aka @araffin), Maximilian Ernestus (aka @ernestum), Adam Gleave (@AdamGleave), Anssi Kanervisto (@Miffyli) and Quentin Gallouédec (@qgallouedec).

Important Note: we do not provide technical support or consulting, and we do not answer personal questions by email. Please post your question on the RL Discord, Reddit, or Stack Overflow instead.

How To Contribute

For anyone interested in making the baselines better, there is still some documentation that needs to be done. If you want to contribute, please read the CONTRIBUTING.md guide first.

Contributors

We would like to thank our contributors: @jan1854.

Owner

  • Name: Antonin RAFFIN
  • Login: araffin
  • Kind: user
  • Location: Munich
  • Company: @DLR-RM

Research Engineer in Robotics and Machine Learning, with a focus on Reinforcement Learning.

Citation (CITATION.bib)

@article{stable-baselines3,
  author  = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title   = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {268},
  pages   = {1-8},
  url     = {http://jmlr.org/papers/v22/20-1364.html}
}

GitHub Events

Total
  • Create event: 13
  • Release event: 4
  • Issues event: 16
  • Watch event: 134
  • Delete event: 3
  • Issue comment event: 53
  • Push event: 48
  • Pull request review comment event: 20
  • Pull request review event: 22
  • Pull request event: 16
  • Fork event: 15
Last Year
  • Create event: 13
  • Release event: 4
  • Issues event: 16
  • Watch event: 134
  • Delete event: 3
  • Issue comment event: 53
  • Push event: 48
  • Pull request review comment event: 20
  • Pull request review event: 22
  • Pull request event: 16
  • Fork event: 15

Committers

Last synced: 12 months ago

All Time
  • Total Commits: 65
  • Total Committers: 5
  • Avg Commits per committer: 13.0
  • Development Distribution Score (DDS): 0.092
Past Year
  • Commits: 10
  • Committers: 5
  • Avg Commits per committer: 2.0
  • Development Distribution Score (DDS): 0.4
Top Committers
Name Email Commits
Antonin Raffin a****n@e****g 59
Jan Schneider 3****4 3
Théo VINCENT 4****t 1
Paolo p****9@g****m 1
JamesHeald j****d@u****k 1
Committer Domains (Top 20 + Academic)
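The Development Distribution Score reported above can be reproduced from the commit counts in this table. A minimal sketch, assuming the score is one minus the top committer's share of all commits (an assumption inferred from the reported numbers, not a documented formula):

```python
def dds(commits_per_committer):
    """Development Distribution Score: 1 minus the share of commits held by
    the most active committer. 0 means a single person wrote everything."""
    total = sum(commits_per_committer)
    return 1 - max(commits_per_committer) / total

# All-time commit counts from the committer table: 59, 3, 1, 1, 1
all_time = [59, 3, 1, 1, 1]
print(round(dds(all_time), 3))  # 0.092, matching the reported all-time DDS
```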

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 36
  • Total pull requests: 36
  • Average time to close issues: 2 months
  • Average time to close pull requests: 14 days
  • Total issue authors: 32
  • Total pull request authors: 10
  • Average comments per issue: 3.0
  • Average comments per pull request: 1.72
  • Merged pull requests: 26
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 12
  • Pull requests: 12
  • Average time to close issues: 6 days
  • Average time to close pull requests: 28 days
  • Issue authors: 12
  • Pull request authors: 4
  • Average comments per issue: 2.83
  • Average comments per pull request: 2.17
  • Merged pull requests: 6
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • edmund735 (2)
  • jamesheald (2)
  • Robokan (2)
  • Deepakgthomas (2)
  • araffin (1)
  • theovincent (1)
  • LabChameleon (1)
  • thomashirtz (1)
  • bascat139 (1)
  • tobiasmerkt (1)
  • Raa23 (1)
  • LennertEvens (1)
  • alexpalms (1)
  • ZzzihaoGuo (1)
  • joaogui1 (1)
Pull Request Authors
  • araffin (30)
  • jan1854 (6)
  • naumix (3)
  • theovincent (2)
  • danielpalen (2)
  • tonyspumoni (2)
  • paolodelia99 (1)
  • corentinlger (1)
  • ClintonOffor (1)
  • jamesheald (1)
  • suijth (1)
  • joaogui1 (1)
Top Labels
Issue Labels
question (18) enhancement (9) bug (8) help wanted (4) documentation (1) good first issue (1)
Pull Request Labels
enhancement (1)

Packages

  • Total packages: 2
  • Total downloads:
    • pypi 915 last-month
  • Total dependent packages: 0
    (may contain duplicates)
  • Total dependent repositories: 1
    (may contain duplicates)
  • Total versions: 40
  • Total maintainers: 1
proxy.golang.org: github.com/araffin/sbx
  • Versions: 15
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 5.5%
Average: 5.7%
Dependent repos count: 5.9%
Last synced: 6 months ago
pypi.org: sbx-rl

Jax version of Stable Baselines, implementations of reinforcement learning algorithms.

  • Versions: 25
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 915 Last month
Rankings
Stargazers count: 5.0%
Forks count: 9.1%
Dependent packages count: 10.1%
Downloads: 11.0%
Average: 11.4%
Dependent repos count: 21.6%
Maintainers (1)
Last synced: 6 months ago

Dependencies

setup.py pypi
  • flax *
  • jax *
  • jaxlib *
  • optax *
  • rich *
  • stable_baselines3 *
  • tensorflow_probability *
  • tqdm *
.github/workflows/ci.yml actions
  • actions/checkout v2 composite
  • actions/setup-python v2 composite
pyproject.toml pypi