https://github.com/kallel-mahdi/supersac
Science Score: 49.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ✓ DOI references: Found 1 DOI reference(s) in README
- ✓ Academic publication links: Links to: zenodo.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (11.8%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: kallel-mahdi
- License: mit
- Language: Jupyter Notebook
- Default Branch: main
- Size: 38.6 MB
Statistics
- Stars: 2
- Watchers: 2
- Forks: 2
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
A JAX Backbone for RL projects
This project serves as a "central backbone" for an RL codebase, designed to accelerate prototyping and diagnosis of new algorithms (although it also contains reference implementations of SAC, CQL, IQL, and BC). It is heavily inspired by Ilya Kostrikov's JaxRL codebase.
The primary goal of the codebase is to make it easy to code up a new algorithm. Toward this goal, the primary philosophy is that
algorithms should be single-file implementations
This means that almost all components of the algorithm (update rule, network choices, hyperparameter choices) are contained in one file (e.g. see BC example or SAC example). This makes the algorithm easy to read and understand, and easy to modify when testing out new ideas. The code is also designed to scale as easily as possible to multi-GPU / TPU setups, with simple abstractions for distributed training.
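To make the distributed-training idea concrete, here is a minimal sketch of how a batch might be split across local devices. It is an illustration only, not the library's shard_batch: the function name shard_batch_sketch and the batch keys are assumptions. It reshapes the leading axis of every array in a nested batch to (n_devices, per_device_batch, ...).

```python
import jax
import numpy as np


def shard_batch_sketch(batch):
    """Split the leading (batch) axis of every array across local devices.

    Hypothetical helper for illustration; not the library's shard_batch.
    """
    n_devices = jax.local_device_count()

    def reshape(x):
        # The batch size must divide evenly across the devices.
        assert x.shape[0] % n_devices == 0
        return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])

    # tree_map applies the reshape to every leaf of the nested dictionary.
    return jax.tree_util.tree_map(reshape, batch)


# Illustrative batch; the keys and shapes are assumptions.
batch_size = 4 * jax.local_device_count()
batch = {
    "observations": np.zeros((batch_size, 3)),
    "actions": np.zeros((batch_size, 2)),
}
sharded = shard_batch_sketch(batch)
```

A batch shaped this way has one slice per device along its first axis, which is the layout that per-device parallel update functions (for example those built with jax.pmap) typically expect.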
Installation
Requires jax, flax, optax, distrax, and optionally wandb for logging. Clone this repository and install it (e.g. pip install -e .) or add it to your Python path.
Usage
The fastest way to understand how to use this skeleton is to see the reference SAC implementation:
Structure
The code contains the following files:
- jaxrl_m.common: Contains the TrainState abstraction (a fork of Flax's TrainState class with some additional syntactic features for ease of use), and some other useful general utilities (target_update, shard_batch)
- jaxrl_m.dataset: Contains the Dataset class (which can store and sample from buffers containing arbitrarily nested dictionaries) and an equivalent ReplayBuffer class
- jaxrl_m.networks: Contains implementations of common RL networks (MLP, Critic, ValueCritic, Policy)
- jaxrl_m.evaluation: Contains code for running evaluation episodes of agents (e.g. with the evaluate(policy, env) function)
- jaxrl_m.wandb: Contains code for easily setting up Weights & Biases for experiments
- jaxrl_m.typing: Useful type aliases
- jaxrl_m.vision: vision.models contains common vision models (e.g. ResNet, ResNetV2, Impala), vision.data_augmentations contains common augmentations (e.g. random crop, random color jitter, gaussian blur)
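As an illustration of what "sampling from buffers containing arbitrarily nested dictionaries" can look like, the sketch below draws one set of indices and applies it to every leaf of a nested dictionary. The class NestedDatasetSketch, its sample signature, and the field names are hypothetical and are not the actual jaxrl_m.dataset API.

```python
import jax
import numpy as np


class NestedDatasetSketch:
    """Hypothetical stand-in for a dataset over nested dictionaries of arrays."""

    def __init__(self, data):
        self.data = data
        # Every leaf must share the same leading dimension (number of transitions).
        sizes = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(data)}
        assert len(sizes) == 1, "all leaves need the same leading dimension"
        self.size = sizes.pop()

    def sample(self, batch_size, rng):
        # Draw one set of indices and apply it to every leaf so fields stay aligned.
        idx = rng.integers(0, self.size, size=batch_size)
        return jax.tree_util.tree_map(lambda arr: arr[idx], self.data)


# Illustrative data; the field names and shapes are assumptions.
data = {
    "observations": {
        "state": np.arange(20.0).reshape(10, 2),
        "goal": np.ones((10, 2)),
    },
    "actions": np.zeros((10, 1)),
    "rewards": np.arange(10.0),
}
dataset = NestedDatasetSketch(data)
batch = dataset.sample(batch_size=4, rng=np.random.default_rng(0))
```

The sampled batch keeps the same nested structure as the stored data, so an update rule can index it as batch["observations"]["state"] regardless of how deeply the fields are nested.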
Examples
Example implementations:
Example Launchers:
Citation
If you use this codebase in an academic work, please cite
@software{jaxrl_minimal,
author = {Dibya Ghosh},
title = {dibyaghosh/jaxrl\_m},
month = April,
year = 2023,
publisher = {Zenodo},
version = {v0.1},
doi = {10.5281/zenodo.7958265},
url = {https://github.com/dibyaghosh/jaxrl_m}
}
Owner
- Login: kallel-mahdi
- Kind: user
- Repositories: 3
- Profile: https://github.com/kallel-mahdi
GitHub Events
Total
- Watch event: 1
- Member event: 2
- Push event: 77
- Fork event: 2
- Create event: 2
Last Year
- Watch event: 1
- Member event: 2
- Push event: 77
- Fork event: 2
- Create event: 2
Dependencies
- distrax *
- flax *
- gymnasium *
- ml_collections *
- optax *
- tqdm *
- wandb *