https://github.com/amazon-science/replay-based-recurrent-rl
Code for "Task-Agnostic Continual RL: In Praise of a Simple Baseline"
Statistics
- Stars: 34
- Watchers: 2
- Forks: 6
- Open Issues: 0
- Releases: 0
README.md
Task-Agnostic Continual RL: In Praise of a Simple Baseline
About The Project
Official codebase for the paper Task-Agnostic Continual Reinforcement Learning: Gaining Insights and Overcoming Challenges. The code can be used to run continual RL (or multi-task RL) experiments in Meta-World (e.g. in Continual-World), as well as large-scale studies in the synthetic benchmark Quadratic Optimization. The baselines, including replay-based recurrent RL (3RL) and ER-TX, are modular enough to be ported to another codebase.
If you have found the paper or codebase useful, consider citing the work:
@article{caccia2022task,
title={Task-Agnostic Continual Reinforcement Learning: Gaining Insights and Overcoming Challenges},
author={Caccia, Massimo and Mueller, Jonas and Kim, Taesup and Charlin, Laurent and Fakoor, Rasool},
journal={arXiv preprint arXiv:2205.14495},
year={2022}
}
Structure
code
    algs/SAC
        sac.py                 # training SAC
    configs
        hparams                # hyperparameter config files
        methods                # methods' config files
        settings               # settings config files
    misc
        buffer.py              # sampling data from buffer
        runner_offpolicy.py    # agent sampling data from the env
        sequoia_envs.py        # creating the environments
        utils.py
    models
        networks.py            # creates the neural networks
    scripts
        test_codebase.py       # makes sure the repo runs correctly
    train_and_eval
        eval.py                # evaluation logic
        train_cl.py            # training logic in CL
        train_mtl.py           # training logic in MTL
    main.py                    # main file for CRL experiments
public_ck
    ...                        # checkpoints in CW10 benchmark
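As a rough illustration of the role a replay buffer plays in a replay-based recurrent method such as 3RL, the sketch below stores transitions from all tasks in one FIFO buffer and samples windows of consecutive transitions that a recurrent context encoder could consume. It is not the repository's `misc/buffer.py`: the class name `SequenceReplayBuffer`, the transition layout, and the windowing scheme are all assumptions made for illustration, and episode boundaries are ignored for brevity.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

# Hypothetical transition layout: (observation, action, reward, done).
Transition = Tuple[List[float], List[float], float, bool]


class SequenceReplayBuffer:
    """Toy FIFO buffer that samples windows of consecutive transitions."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        # A bounded deque drops the oldest transition once capacity is reached.
        self.storage: Deque[Transition] = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.storage)

    def add(self, transition: Transition) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int, context_len: int) -> List[List[Transition]]:
        """Return `batch_size` windows, each holding `context_len` consecutive transitions."""
        if len(self.storage) < context_len:
            raise ValueError("not enough transitions for one context window")
        items = list(self.storage)
        last_start = len(items) - context_len
        batch = []
        for _ in range(batch_size):
            start = self.rng.randint(0, last_start)  # inclusive on both ends
            batch.append(items[start:start + context_len])
        return batch


# Fill the buffer past capacity, then draw a small batch of context windows.
demo_buffer = SequenceReplayBuffer(capacity=10)
for step in range(25):
    demo_buffer.add(([float(step)], [0.0], 0.0, False))
demo_batch = demo_buffer.sample(batch_size=3, context_len=4)
print(len(demo_buffer), len(demo_batch), len(demo_batch[0]))
```

A real implementation would also need to keep windows from crossing episode boundaries and would typically return tensors rather than Python lists.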
Installation
Essentially, what you need is:
- python (3.8)
- sequoia
- mujoco_py
- pytorch
Installing mujoco_py and running Meta-World can be quite tricky. For this reason, we've used Sequoia to create the continual reinforcement learning environments.
Here's how you can install the dependencies on macOS (Big Sur):
create an env, ideally w/ conda
bash
conda create -n tacrl python=3.8
conda activate tacrl
install Sequoia w/ Meta-World add-on
bash
pip install "sequoia[metaworld] @ git+https://www.github.com/lebrice/Sequoia.git@pass_seed_to_metaworld_envs"
extra requirements
bash
pip install -r requirements.txt
install mujoco + mujoco key
You will need to install MuJoCo (version >= 2.1)
UPDATE: I haven't reinstalled Mujoco since DeepMind acquisition and refactoring. Best of luck w/ the installation.
You can follow RoboSuite installation if you stumble on some GCC related bugs (MacOS specific).
For GCC / GL/glew.h related errors, you can use the instructions here
Contact us if you have any problem!
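Before launching an experiment, a quick pre-check of the environment can save some time. The sketch below only tests whether the listed packages can be located and whether the interpreter is Python 3.8; it does not verify that MuJoCo itself works. The import names `sequoia`, `mujoco_py` and `torch` are assumed to match the installed packages, and the helper names are hypothetical.

```python
import importlib.util
import sys
from typing import Iterable, List

# Assumed import names for the dependencies listed above.
REQUIRED_MODULES = ["sequoia", "mujoco_py", "torch"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be located (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_matches(major: int = 3, minor: int = 8) -> bool:
    """Check that the running interpreter has the expected major.minor version."""
    return sys.version_info[:2] == (major, minor)


if __name__ == "__main__":
    print("python 3.8:", python_matches())
    print("missing modules:", missing_modules(REQUIRED_MODULES))
```

For a full end-to-end check, the repository's own `scripts/test_codebase.py` is the better tool.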
Usage
example of running SAC w/ an RNN in CW10
python
python code/main.py --train_mode cl --context_rnn True --setting_config CW10 --lr 0.0003 --batch_size 1028
or w/ a transformer in Quadratic Optimization in a multi-task regime
python
python code/main.py --train_mode mtl --context_transformer True --env_name Quadratic_opt --lr 0.0003 --batch_size 128
You can pass config files and reproduce the paper's results by combining a setting, method and hyperparameters config file in the following manner
python
python code/main.py --train_mode <cl, mtl> --setting_config <setting> --method_config <method> --hparam_config <hyperparameters>
e.g. running 3RL in CW10 w/ the hyperparameters prescribed by Meta-World (for Meta-World v2):
python
python code/main.py --train_mode cl --setting_config CW10 --method_config 3RL --hparam_config meta_world
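One way to picture how a setting, a method and a hyperparameter config could combine is as layered dictionaries where later layers win. The sketch below is only that: the precedence order (setting, then method, then hyperparameters, then command-line overrides) and the example file contents are assumptions, not a description of how `code/main.py` actually resolves its configs.

```python
from typing import Any, Dict, Optional


def merge_configs(
    setting: Dict[str, Any],
    method: Dict[str, Any],
    hparams: Dict[str, Any],
    cli_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge config layers; later layers override earlier ones (assumed order)."""
    merged: Dict[str, Any] = {}
    for layer in (setting, method, hparams):
        merged.update(layer)
    # Only flags that were explicitly passed (not None) override the files.
    for key, value in (cli_overrides or {}).items():
        if value is not None:
            merged[key] = value
    return merged


# Hypothetical contents for the three config files.
setting_cfg = {"train_mode": "cl", "setting_config": "CW10"}
method_cfg = {"context_rnn": True}
hparam_cfg = {"lr": 0.0003, "batch_size": 128}

merged_cfg = merge_configs(
    setting_cfg, method_cfg, hparam_cfg, {"batch_size": 1028, "lr": None}
)
print(merged_cfg)
```

Here `batch_size` comes from the command line while `lr` falls back to the hyperparameter file, because an unset flag is represented as `None`.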
For the MTRL experiments, run
python
python code/main.py --train_mode mtl
for prototyping, you can use the ant_direction environment:
python
python code/main.py --env_name Ant_direction-v3
Note: If you get an error about "yaml.safe_load", replace it with "yaml.load()".
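If you do switch to `yaml.load`, note that PyYAML 6 and later require an explicit `Loader` argument, and older 5.x versions warn without one. A minimal sketch, assuming PyYAML is installed and using a made-up config string (the helper name `load_config` is hypothetical):

```python
import yaml


def load_config(text: str) -> dict:
    """Parse YAML with an explicit loader; an empty document becomes {}."""
    # FullLoader is available from PyYAML 5.1 onwards.
    return yaml.load(text, Loader=yaml.FullLoader) or {}


# Hypothetical config contents, for illustration only.
example_yaml = "lr: 0.0003\nbatch_size: 128\ncontext_rnn: true\n"
parsed_cfg = load_config(example_yaml)
print(parsed_cfg)
```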
Paper reproduction
For access to the WandB project w/ the results, please contact me.
For Figure 1, use analyse_reps.py (the models are in public_ck)
For all synthetic data experiments, you can create a wandb sweep
bash
wandb sweep --project Quadratic_opt code/configs/sweeps/Quadratic_opt.yaml
And then launch the wandb agent
bash
wandb agent <sweep_id>
For Figure 5
python
python main.py --train_mode cl --setting_config <CW10, CL_MW20_500k> --method_config <method> --hparam_config meta_world
For Figure 7
python
python main.py --train_mode mtl --setting_config MTL_MW20 --method_config <method> --hparam_config meta_world-noet
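To launch several of these runs in one go, the command lines can be generated programmatically. The sketch below only builds and prints the commands, using the flags shown above; the helper `build_command` is hypothetical, the script path defaults to `code/main.py` as in the Usage section (adjust it if you run from inside `code`), and the method list contains only `3RL` because that is the one method config name given in this README.

```python
import shlex
from itertools import product
from typing import List


def build_command(
    train_mode: str,
    setting: str,
    method: str,
    hparams: str,
    script: str = "code/main.py",
) -> List[str]:
    """Assemble the argument list for one training run."""
    return [
        "python", script,
        "--train_mode", train_mode,
        "--setting_config", setting,
        "--method_config", method,
        "--hparam_config", hparams,
    ]


# Settings from the Figure 5 command; extend `methods` with other method configs.
settings = ["CW10", "CL_MW20_500k"]
methods = ["3RL"]

commands = [
    build_command("cl", setting, method, "meta_world")
    for setting, method in product(settings, methods)
]
for command in commands:
    print(shlex.join(command))  # shlex.join requires Python 3.8+
```

Each list can be handed to `subprocess.run(command, check=True)` once the printed commands look right.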
License
This project is licensed under the Apache-2.0 License.
Contact
To report problems or ask questions, please open an issue on the issue tracker, or send an email to Massimo Caccia (@MassCaccia) and Rasool Fakoor.
Owner
- Name: Amazon Science
- Login: amazon-science
- Kind: organization
- Website: https://amazon.science
- Twitter: AmazonScience
- Repositories: 80
- Profile: https://github.com/amazon-science
GitHub Events
Total
- Watch event: 3
Last Year
- Watch event: 3
Committers
Last synced: 8 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Massimo Caccia | m****a@g****m | 4 |
| rasoolfa | r****r@g****m | 4 |
| rasoolfakoor | 7****r | 3 |
| Fabrice Normandin | n****f@m****c | 1 |
| Amazon GitHub Automation | 5****o | 1 |
Issues and Pull Requests
Last synced: 8 months ago
All Time
- Total issues: 1
- Total pull requests: 5
- Average time to close issues: about 1 month
- Average time to close pull requests: 3 days
- Total issue authors: 1
- Total pull request authors: 3
- Average comments per issue: 7.0
- Average comments per pull request: 0.4
- Merged pull requests: 4
- Bot issues: 0
- Bot pull requests: 1
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- Lez-3f (1)
Pull Request Authors
- optimass (3)
- dependabot[bot] (1)
- lebrice (1)
Dependencies
- absl-py ==1.0.0
- aiohttp ==3.8.1
- aiosignal ==1.2.0
- async-timeout ==4.0.2
- attrs ==21.4.0
- cachetools ==5.0.0
- cffi ==1.15.0
- charset-normalizer ==2.0.12
- click ==8.1.3
- cloudpickle ==2.0.0
- continuum ==1.0.19
- cycler ==0.11.0
- cython ==0.29.28
- docker-pycreds ==0.4.0
- docopt ==0.6.2
- execnet ==1.9.0
- fasteners ==0.17.3
- fonttools ==4.33.3
- frozenlist ==1.3.0
- fsspec ==2022.3.0
- future ==0.18.2
- gitdb ==4.0.9
- gitpython ==3.1.27
- glfw ==2.5.3
- google-auth ==2.6.6
- google-auth-oauthlib ==0.4.6
- grpcio ==1.44.0
- gym ==0.21.0
- h5py ==3.6.0
- idna ==3.3
- imageio ==2.18.0
- importlib-metadata ==4.11.3
- iniconfig ==1.1.1
- joblib ==1.1.0
- kiwisolver ==1.4.2
- lightning-bolts ==0.5.0
- markdown ==3.3.6
- matplotlib ==3.5.1
- metaworld ==0.1.0
- mujoco-py ==2.1.2.14
- multidict ==6.0.2
- mypy-extensions ==0.4.3
- nngeometry ==0.1
- numpy ==1.22.3
- oauthlib ==3.2.0
- packaging ==21.3
- pandas ==1.4.2
- pathtools ==0.1.2
- pillow ==9.1.0
- pipreqs ==0.4.11
- plotly ==5.7.0
- pluggy ==1.0.0
- promise ==2.3
- protobuf ==3.20.1
- psutil ==5.9.0
- py ==1.11.0
- pyasn1 ==0.4.8
- pyasn1-modules ==0.2.8
- pycparser ==2.21
- pydeprecate ==0.3.1
- pyparsing ==3.0.8
- pytest ==7.1.2
- pytest-forked ==1.4.0
- pytest-timeout ==2.1.0
- pytest-xdist ==2.5.0
- pytest-xvfb ==2.0.0
- python-dateutil ==2.8.2
- pytorch-lightning ==1.5.9
- pytz ==2022.1
- pyvirtualdisplay ==3.0
- pyyaml ==5.3.1
- requests ==2.27.1
- requests-oauthlib ==1.3.1
- rsa ==4.8
- scikit-learn ==1.0.2
- scipy ==1.8.0
- seaborn ==0.11.2
- sentry-sdk ==1.5.10
- setproctitle ==1.2.3
- setuptools ==59.5.0
- shortuuid ==1.0.8
- simple-parsing ==0.0.19.post1
- six ==1.16.0
- smmap ==5.0.0
- tenacity ==8.0.1
- tensorboard ==2.8.0
- tensorboard-data-server ==0.6.1
- tensorboard-plugin-wit ==1.8.1
- threadpoolctl ==3.1.0
- tomli ==2.0.1
- torch ==1.8.1
- torchmetrics ==0.8.1
- torchvision ==0.9.1
- tqdm ==4.64.0
- typing-extensions ==4.2.0
- typing-inspect ==0.7.1
- urllib3 ==1.26.9
- wandb ==0.12.15
- werkzeug ==2.1.2
- yarg ==0.1.9
- yarl ==1.7.2
- zipp ==3.8.0