Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 2 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org, zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (16.2%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: charles-zhng
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Size: 12.4 MB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created almost 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme Contributing License Citation

README.md

I threw this together to try V-MPO with rodent imitation learning.

  • First working run: python3 ff_vmpo_continuous.py env=brax/tracking logger.checkpointing.save_model=True arch.total_num_envs=256 arch.num_eval_episodes=16
  • The environment here is out of date; some bugs have been fixed since.

Distributed Single-Agent Reinforcement Learning End-to-End in JAX

**_stoic - a person who can endure pain or hardship without showing their feelings or complaining._**

Welcome to Stoix! 🏛️

Stoix provides simplified code for quickly iterating on ideas in single-agent reinforcement learning. It includes implementations of popular single-agent RL algorithms in JAX that can be parallelised across devices with JAX's pmap. All implementations are fully compilable with JAX's jit, which makes training and environment execution very fast, but this requires environments written in JAX. The algorithms and their default hyperparameters have not been tuned for any specific environment; they are intended as a starting point for research and for initial baselines.
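As a rough illustration of the jit/vmap/pmap pattern described above, the sketch below computes discounted returns for a batch of episodes spread across devices. It is not taken from the Stoix source: `discounted_return`, the reward tensor and its shape are hypothetical, and only standard JAX calls (`jax.lax.scan`, `jax.vmap`, `jax.pmap`, `jax.local_device_count`) are used.

```python
import jax
import jax.numpy as jnp


def discounted_return(rewards, discount=0.99):
    """Discounted return of one episode, computed with a reverse scan."""

    def step(running_return, reward):
        running_return = reward + discount * running_return
        return running_return, None

    initial = jnp.zeros((), dtype=rewards.dtype)
    final_return, _ = jax.lax.scan(step, initial, rewards, reverse=True)
    return final_return


# vmap vectorises the function over a batch of environments on one device.
per_device = jax.vmap(discounted_return)
# pmap compiles the function (as jit does) and runs one copy per device.
distributed = jax.pmap(per_device)

num_devices = jax.local_device_count()
# Hypothetical rewards: (devices, environments per device, timesteps).
rewards = jnp.ones((num_devices, 4, 3))
returns = distributed(rewards)
print(returns.shape)
```

On a machine with a single device the leading axis has size 1, so the same code runs unchanged on a laptop CPU and on a multi-accelerator host.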

To join us in these efforts, please feel free to reach out, raise issues or read our contribution guidelines (or just star 🌟 to stay up to date with the latest developments)!

Stoix is written entirely in JAX and offers a substantial speed improvement over other popular libraries. We currently provide native support for the Jumanji environment API and wrappers for popular JAX-based RL environments.

Code Philosophy 🧘

The current code in Stoix was initially largely taken and subsequently adapted from Mava. As Mava develops, Stoix will hopefully adopt those of its optimisations that are relevant for single-agent RL. Like Mava, Stoix is not designed to be a highly modular library and is not meant to be imported. Our repository focuses on simplicity and clarity in its implementations while utilising the advantages offered by JAX such as pmap and vmap, making it an excellent resource for researchers and practitioners to build upon.

Stoix follows a similar design philosophy to CleanRL and PureJaxRL, where we allow for some code duplication to enable readability, easy reuse, and fast adaptation. A notable difference between Stoix and other single-file libraries is that Stoix makes use of abstraction where relevant. It is not intended to be purely educational; research utility is the primary focus. In particular, abstraction is currently used for network architectures, environments, logging, and evaluation.

Overview 🦜

Stoix currently offers the following building blocks for Single-Agent RL research:

Implementations of Algorithms 🥑

  • Deep Q-Network (DQN) - Paper
  • Double DQN (DDQN) - Paper
  • Dueling DQN - Paper
  • Categorical DQN (C51) - Paper
  • Munchausen DQN (M-DQN) - Paper
  • Quantile Regression DQN (QR-DQN) - Paper
  • DQN with Regularized Q-learning (DQN-Reg) - Paper
  • REINFORCE With Baseline - Paper
  • Deep Deterministic Policy Gradient (DDPG) - Paper
  • Twin Delayed DDPG (TD3) - Paper
  • Distributed Distributional DDPG (D4PG) - Paper
  • Soft Actor-Critic (SAC) - Paper
  • Proximal Policy Optimization (PPO) - Paper
  • Discovered Policy Optimization (DPO) - Paper
  • Maximum a Posteriori Policy Optimisation (MPO) - Paper
  • On-Policy Maximum a Posteriori Policy Optimisation (V-MPO) - Paper
  • Advantage-Weighted Regression (AWR) - Paper
  • AlphaZero - Paper
  • MuZero - Paper
  • Sampled Alpha/Mu-Zero - Paper

Environment Wrappers 🍬

Stoix offers wrappers for Gymnax, Jumanji, Brax, XMinigrid, Craftax and even JAXMarl (although using Centralised Controllers).

Statistically Robust Evaluation 🧪

Stoix natively supports logging to json files which adhere to the standard suggested by Gorsane et al. (2022). This enables easy downstream experiment plotting and aggregation using the tools found in the MARL-eval library.
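To make the idea of nested JSON evaluation logs concrete, here is a minimal sketch using only the standard library. The nesting (environment, task, algorithm, seed, evaluation step) and every key name below are assumptions chosen for illustration; the exact schema expected by MARL-eval should be checked against that library's documentation, and `record_eval` is a hypothetical helper, not a Stoix function.

```python
import json


def record_eval(log, env, task, algo, seed, step_index, step_count, episode_returns):
    """Store one evaluation step under env -> task -> algo -> seed -> step."""
    run = (
        log.setdefault(env, {})
        .setdefault(task, {})
        .setdefault(algo, {})
        .setdefault(f"seed_{seed}", {})
    )
    run[f"step_{step_index}"] = {
        "step_count": step_count,
        "episode_return": list(episode_returns),
    }
    return log


log = {}
# Hypothetical numbers, only to show the structure.
record_eval(log, "brax", "ant", "ff_ppo", 0, 0, 10_000, [12.5, 14.0])
record_eval(log, "brax", "ant", "ff_ppo", 0, 1, 20_000, [30.0, 28.5])

serialised = json.dumps(log, indent=2)
print(serialised)
```

Keeping raw per-episode values for each seed, rather than a single mean, is what lets downstream tools compute aggregate statistics and confidence intervals across runs.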

Performance and Speed 🚀

Because Stoix was, at the time of creation, essentially a port of Mava, we point to the Mava repo for further speed comparisons. We also refer readers to the PureJaxRL blog post, which discusses the speed benefits of end-to-end JAX systems.

Below we provide some plots illustrating that Stoix performs on par with PureJaxRL, with the added benefit that the code is already set up for pmap distribution over devices, alongside the other features provided (algorithm implementations, logging, config system, etc.).

[Plots: ppo, dqn]

I've also included a plot of the training time for 5e5 steps of PPO as one scales the number of environments. PureJaxRL does not pmap and thus runs on a single device.

[Plot: env_scaling]

Lastly, please keep in mind for practical use that the current networks and algorithm hyperparameters have not been tuned.

Installation 🎬

At the moment Stoix is not meant to be installed as a library, but rather to be used as a research tool.

You can use Stoix by cloning the repo and pip installing as follows:

    git clone https://github.com/EdanToledo/Stoix.git
    cd Stoix
    pip install -e .

We have tested Stoix on Python 3.10. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the official installation guide).

Quickstart ⚡

To get started with training your first Stoix system, simply run one of the system files, e.g.:

    python stoix/systems/ppo/ff_ppo.py

Stoix makes use of Hydra for config management. Our default system configs are in the stoix/configs/ directory. A benefit of Hydra is that configs can either be set in config yaml files or overridden from the terminal on the fly. For example, to run a system on the CartPole environment, the above command can be adapted as follows:

    python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole

Contributing 🤝

Please read our contributing docs for details on how to submit pull requests, our Contributor License Agreement and community guidelines.

Roadmap 🛤️

We plan to iteratively expand Stoix in the following increments:

  • 🌴 Support for more environments as they become available.
  • 🔁 More robust recurrent systems.
    • [ ] Add recurrent variants of all systems
    • [ ] Allow easy interchangeability of recurrent cells/architecture via config
  • 📊 Benchmarks on more environments.
    • [ ] Create leaderboard of algorithms
  • 🦾 More algorithm implementations:
  • 🎮 Self-play 2-player Systems for board games.

Please do follow along as we develop this next phase!

Citing Stoix 📚

If you use Stoix in your work, please cite us:

    @software{toledo2024stoix,
      author = {Toledo, Edan},
      doi = {10.5281/zenodo.10916258},
      month = apr,
      title = {{Stoix: Distributed Single-Agent Reinforcement Learning End-to-End in JAX}},
      url = {https://github.com/EdanToledo/Stoix},
      version = {v0.0.1},
      year = {2024}
    }

Acknowledgements 🙏

We would like to thank the authors and developers of Mava, as Stoix was essentially a port of their repo at the time of creation. This helped set up much of the infrastructure for logging, evaluation and other utilities.

See Also 🔎

Related JAX Libraries

In particular, we suggest users check out the following repositories:

  • 🦁 Mava: Distributed Multi-Agent Reinforcement Learning in JAX.
  • 🔌 OG-MARL: datasets with baselines for offline MARL in JAX.
  • 🌴 Jumanji: a diverse suite of scalable reinforcement learning environments in JAX.
  • 😎 Matrax: a collection of matrix games in JAX.
  • 🔦 Flashbax: accelerated replay buffers in JAX.
  • 📈 MARL-eval: standardised experiment data aggregation and visualisation for MARL.
  • 🦊 JaxMARL: accelerated MARL environments with baselines in JAX.
  • 🌀 DeepMind Anakin for the Anakin podracer architecture to train RL agents at scale.
  • ♟️ Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • 🔼 Minimax: JAX implementations of autocurricula baselines for RL.

Disclaimer: This is not an official InstaDeep product, nor is any of the work put forward associated with InstaDeep in any official capacity.

Owner

  • Name: Charles Zhang
  • Login: charles-zhng
  • Kind: user

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Toledo"
  given-names: "Edan"
title: "Stoix: Distributed Single-Agent Reinforcement Learning End-to-End in JAX"
version: v0.0.1
date-released: 2024-04-04
doi: 10.5281/zenodo.10916258
url: "https://github.com/EdanToledo/Stoix"

Dependencies

.github/workflows/test_linters.yaml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
Dockerfile docker
  • nvidia/cuda 11.8.0-cudnn8-runtime-ubuntu22.04 build
pyproject.toml pypi
requirements/requirements-dev.txt pypi
  • black ==24.3.0 development
  • coverage * development
  • flake8 ==6.1.0 development
  • importlib-metadata <5.0 development
  • isort ==5.11.5 development
  • livereload * development
  • mkdocs ==1.2.3 development
  • mkdocs-git-revision-date-plugin * development
  • mkdocs-include-markdown-plugin * development
  • mkdocs-material ==8.2.7 development
  • mkdocs-mermaid2-plugin ==0.6.0 development
  • mkdocstrings ==0.18.0 development
  • mknotebooks ==0.7.1 development
  • mypy ==0.991 development
  • nbmake * development
  • pre-commit ==3.3.3 development
  • promise * development
  • pymdown-extensions * development
  • pytest ==7.0.1 development
  • pytest-cov * development
  • pytest-mock * development
  • pytest-parallel * development
  • pytest-xdist * development
  • pytype * development
  • testfixtures * development
requirements/requirements.txt pypi
  • brax >=0.9.0
  • chex *
  • colorama *
  • craftax *
  • flax *
  • gymnax >=0.0.6
  • huggingface_hub *
  • hydra-core ==1.3.2
  • jax >=0.4.10
  • jaxlib *
  • jaxmarl *
  • jumanji ==1.0.0
  • mctx *
  • neptune *
  • numpy *
  • omegaconf *
  • pgx *
  • protobuf ==3.20.2
  • rlax *
  • tdqm *
  • tensorboard_logger *
  • tensorflow_probability *