https://github.com/amazon-science/replay-based-recurrent-rl

Code for "Task-Agnostic Continual RL: In Praise of a Simple Baseline"

Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.9%) to scientific vocabulary

Keywords

meta-learning multi-task-learning reinforcement-learning
Last synced: 5 months ago · JSON representation

Repository

Code for "Task-Agnostic Continual RL: In Praise of a Simple Baseline"

Basic Info
  • Host: GitHub
  • Owner: amazon-science
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 100 MB
Statistics
  • Stars: 34
  • Watchers: 2
  • Forks: 6
  • Open Issues: 0
  • Releases: 0
Topics
meta-learning multi-task-learning reinforcement-learning
Created over 3 years ago · Last pushed over 2 years ago
Metadata Files
Readme Contributing License Code of conduct

README.md

Task-Agnostic Continual RL: In Praise of a Simple Baseline

Table of Contents
  1. About The Project
  2. Structure
  3. Installation
  4. Usage
  5. Contributing
  6. License
  7. Contact
  8. Acknowledgments

About The Project

Official codebase for the paper Task-Agnostic Continual Reinforcement Learning: Gaining Insights and Overcoming Challenges. The code can be used to run continual RL (or multi-task RL) experiments in Meta-World (e.g. Continual-World), as well as large-scale studies on the synthetic Quadratic Optimization benchmark. The baselines, including replay-based recurrent RL (3RL) and ER-TX, are modular enough to be ported into other codebases as well.
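The core idea behind 3RL is to combine an off-policy actor-critic (SAC) with an experience replay buffer shared across tasks and a recurrent context encoder over recent transitions, so the agent can infer the current task without task labels. Below is a minimal, hypothetical sketch of the context-encoder piece; the class name, dimensions, and interface are illustrative and not the repo's actual API (see `models/networks.py` for the real implementation).

```python
import torch
import torch.nn as nn

class ContextRNN(nn.Module):
    """Encodes a window of recent (obs, action, reward) transitions
    into a task-context vector that conditions the SAC actor/critic."""

    def __init__(self, obs_dim, act_dim, ctx_dim=64):
        super().__init__()
        # input per step: observation + previous action + previous reward
        self.gru = nn.GRU(obs_dim + act_dim + 1, ctx_dim, batch_first=True)

    def forward(self, obs, act, rew):
        # obs: (B, T, obs_dim), act: (B, T, act_dim), rew: (B, T, 1)
        x = torch.cat([obs, act, rew], dim=-1)
        _, h = self.gru(x)           # h: (num_layers=1, B, ctx_dim)
        return h.squeeze(0)          # one context vector per sequence
```

The actor and critic would then consume the concatenation of the current observation and this context vector instead of the observation alone.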

If you have found the paper or codebase useful, consider citing the work:

```bibtex
@article{caccia2022task,
  title={Task-Agnostic Continual Reinforcement Learning: Gaining Insights and Overcoming Challenges},
  author={Caccia, Massimo and Mueller, Jonas and Kim, Taesup and Charlin, Laurent and Fakoor, Rasool},
  journal={arXiv preprint arXiv:2205.14495},
  year={2022}
}
```



Structure

```
code
    algs/SAC
        sac.py                   # SAC training
    configs
        hparams                  # hyperparameter config files
        methods                  # method config files
        settings                 # setting config files
    misc
        buffer.py                # sampling data from the buffer
        runner_offpolicy.py      # agent sampling data from the env
        sequoia_envs.py          # creating the environments
        utils.py
    models
        networks.py              # creates the neural networks
    scripts
        test_codebase.py         # makes sure the repo runs correctly
    train_and_eval
        eval.py                  # evaluation logic
        train_cl.py              # training logic in CL
        train_mtl.py             # training logic in MTL
    main.py                      # main file for CRL experiments
public_ck
    ...                          # checkpoints in the CW10 benchmark
```
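For orientation on `misc/buffer.py`, a minimal replay buffer of the kind this structure suggests might look like the sketch below. This is a hypothetical illustration, not the repo's actual implementation; the key detail for recurrent methods like 3RL is that the buffer can sample contiguous sequences (to feed the context RNN) rather than only independent transitions.

```python
import random
from collections import deque

class ReplayBuffer:
    """Minimal FIFO replay buffer: stores transitions and samples either
    independent steps or contiguous windows (for a recurrent context)."""

    def __init__(self, capacity=100_000):
        self.storage = deque(maxlen=capacity)

    def add(self, obs, act, rew, next_obs, done):
        self.storage.append((obs, act, rew, next_obs, done))

    def sample(self, batch_size):
        # uniform sampling of independent transitions
        return random.sample(self.storage, batch_size)

    def sample_sequences(self, batch_size, seq_len):
        # pick random start indices, then return contiguous windows
        starts = [random.randrange(len(self.storage) - seq_len + 1)
                  for _ in range(batch_size)]
        return [[self.storage[s + t] for t in range(seq_len)]
                for s in starts]
```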


Installation

Essentially, what you need is:
  • python (3.8)
  • sequoia
  • mujoco_py
  • pytorch

It can be quite tricky to install mujoco_py and to get Meta-World running. For this reason, we've used Sequoia to create the continual reinforcement learning environments.

Here's how you can install the dependencies on macOS (Big Sur):

  1. Create an environment, ideally with conda:

```bash
conda create -n tacrl python=3.8
conda activate tacrl
```

  2. Install Sequoia with the Meta-World add-on:

```bash
pip install "sequoia[metaworld] @ git+https://www.github.com/lebrice/Sequoia.git@pass_seed_to_metaworld_envs"
```

  3. Install the extra requirements:

```bash
pip install -r requirements.txt
```

  4. Install MuJoCo and the MuJoCo key. You will need to install MuJoCo (version >= 2.1).

UPDATE: I haven't reinstalled MuJoCo since the DeepMind acquisition and subsequent refactoring. Best of luck with the installation.

You can follow the RoboSuite installation instructions if you stumble on GCC-related bugs (macOS-specific).

For GCC or GL/glew.h related errors, you can use the instructions here

Contact us if you have any problems!


Usage

Example of running SAC with an RNN in CW10:

```bash
python code/main.py --train_mode cl --context_rnn True --setting_config CW10 --lr 0.0003 --batch_size 1028
```

or with a transformer on Quadratic Optimization in a multi-task regime:

```bash
python code/main.py --train_mode mtl --context_transformer True --env_name Quadratic_opt --lr 0.0003 --batch_size 128
```

You can pass config files and reproduce the paper's results by combining a setting, a method, and a hyperparameter config file in the following manner:

```bash
python code/main.py --train_mode <cl, mtl> --setting_config <setting> --method_config <method> --hparam_config <hyperparameters>
```

e.g. running 3RL in CW10 with the hyperparameters prescribed by Meta-World (for Meta-World v2):

```bash
python code/main.py --train_mode cl --setting_config CW10 --method_config 3RL --hparam_config meta_world
```

For the MTRL experiments, run:

```bash
python code/main.py --train_mode mtl
```

For prototyping, you can use the Ant_direction environment:

```bash
python code/main.py --env_name Ant_direction-v3
```

Note: if you get an error about `yaml.safe_load`, replace it with `yaml.load()`.
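This note concerns PyYAML's loader API: newer PyYAML versions require an explicit `Loader` argument to `yaml.load`, while older pins (this repo's environment pins pyyaml 5.3.1) only warn. A hedged sketch of a pattern that works across versions; the inline config string is illustrative:

```python
import io
import yaml

# Prefer safe_load; fall back to yaml.load with an explicit Loader,
# which newer PyYAML versions require.
raw = io.StringIO("lr: 0.0003\nbatch_size: 128\n")
try:
    cfg = yaml.safe_load(raw)
except yaml.YAMLError:
    raw.seek(0)
    cfg = yaml.load(raw, Loader=yaml.FullLoader)
```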

Paper reproduction

For access to the WandB project w/ the results, please contact me.

For Figure 1, use `analyse_reps.py` (the models are in `public_ck`).

For all synthetic-data experiments, you can create a wandb sweep:

```bash
wandb sweep --project Quadratic_opt code/configs/sweeps/Quadratic_opt.yaml
```

and then launch the wandb agent:

```bash
wandb agent <sweep_id>
```

For Figure 5:

```bash
python main.py --train_mode cl --setting_config <CW10, CL_MW20_500k> --method_config <method> --hparam_config meta_world
```

For Figure 7:

```bash
python main.py --train_mode mtl --setting_config MTL_MW20 --method_config <method> --hparam_config meta_world-noet
```


License

This project is licensed under the Apache-2.0 License.

Contact

Please open an issue on the issue tracker to report problems or ask questions, or send an email to Massimo Caccia (@MassCaccia) and Rasool Fakoor.


Owner

  • Name: Amazon Science
  • Login: amazon-science
  • Kind: organization

GitHub Events

Total
  • Watch event: 3
Last Year
  • Watch event: 3

Committers

Last synced: 8 months ago

All Time
  • Total Commits: 13
  • Total Committers: 5
  • Avg Commits per committer: 2.6
  • Development Distribution Score (DDS): 0.692
Past Year
  • Commits: 0
  • Committers: 0
  • Avg Commits per committer: 0.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Massimo Caccia m****a@g****m 4
rasoolfa r****r@g****m 4
rasoolfakoor 7****r 3
Fabrice Normandin n****f@m****c 1
Amazon GitHub Automation 5****o 1
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 8 months ago

All Time
  • Total issues: 1
  • Total pull requests: 5
  • Average time to close issues: about 1 month
  • Average time to close pull requests: 3 days
  • Total issue authors: 1
  • Total pull request authors: 3
  • Average comments per issue: 7.0
  • Average comments per pull request: 0.4
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 1
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • Lez-3f (1)
Pull Request Authors
  • optimass (3)
  • dependabot[bot] (1)
  • lebrice (1)
Top Labels
Issue Labels
Pull Request Labels
dependencies (1)

Dependencies

environment.yml pypi
  • absl-py ==1.0.0
  • aiohttp ==3.8.1
  • aiosignal ==1.2.0
  • async-timeout ==4.0.2
  • attrs ==21.4.0
  • cachetools ==5.0.0
  • cffi ==1.15.0
  • charset-normalizer ==2.0.12
  • click ==8.1.3
  • cloudpickle ==2.0.0
  • continuum ==1.0.19
  • cycler ==0.11.0
  • cython ==0.29.28
  • docker-pycreds ==0.4.0
  • docopt ==0.6.2
  • execnet ==1.9.0
  • fasteners ==0.17.3
  • fonttools ==4.33.3
  • frozenlist ==1.3.0
  • fsspec ==2022.3.0
  • future ==0.18.2
  • gitdb ==4.0.9
  • gitpython ==3.1.27
  • glfw ==2.5.3
  • google-auth ==2.6.6
  • google-auth-oauthlib ==0.4.6
  • grpcio ==1.44.0
  • gym ==0.21.0
  • h5py ==3.6.0
  • idna ==3.3
  • imageio ==2.18.0
  • importlib-metadata ==4.11.3
  • iniconfig ==1.1.1
  • joblib ==1.1.0
  • kiwisolver ==1.4.2
  • lightning-bolts ==0.5.0
  • markdown ==3.3.6
  • matplotlib ==3.5.1
  • metaworld ==0.1.0
  • mujoco-py ==2.1.2.14
  • multidict ==6.0.2
  • mypy-extensions ==0.4.3
  • nngeometry ==0.1
  • numpy ==1.22.3
  • oauthlib ==3.2.0
  • packaging ==21.3
  • pandas ==1.4.2
  • pathtools ==0.1.2
  • pillow ==9.1.0
  • pipreqs ==0.4.11
  • plotly ==5.7.0
  • pluggy ==1.0.0
  • promise ==2.3
  • protobuf ==3.20.1
  • psutil ==5.9.0
  • py ==1.11.0
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pycparser ==2.21
  • pydeprecate ==0.3.1
  • pyparsing ==3.0.8
  • pytest ==7.1.2
  • pytest-forked ==1.4.0
  • pytest-timeout ==2.1.0
  • pytest-xdist ==2.5.0
  • pytest-xvfb ==2.0.0
  • python-dateutil ==2.8.2
  • pytorch-lightning ==1.5.9
  • pytz ==2022.1
  • pyvirtualdisplay ==3.0
  • pyyaml ==5.3.1
  • requests ==2.27.1
  • requests-oauthlib ==1.3.1
  • rsa ==4.8
  • scikit-learn ==1.0.2
  • scipy ==1.8.0
  • seaborn ==0.11.2
  • sentry-sdk ==1.5.10
  • setproctitle ==1.2.3
  • setuptools ==59.5.0
  • shortuuid ==1.0.8
  • simple-parsing ==0.0.19.post1
  • six ==1.16.0
  • smmap ==5.0.0
  • tenacity ==8.0.1
  • tensorboard ==2.8.0
  • tensorboard-data-server ==0.6.1
  • tensorboard-plugin-wit ==1.8.1
  • threadpoolctl ==3.1.0
  • tomli ==2.0.1
  • torch ==1.8.1
  • torchmetrics ==0.8.1
  • torchvision ==0.9.1
  • tqdm ==4.64.0
  • typing-extensions ==4.2.0
  • typing-inspect ==0.7.1
  • urllib3 ==1.26.9
  • wandb ==0.12.15
  • werkzeug ==2.1.2
  • yarg ==0.1.9
  • yarl ==1.7.2
  • zipp ==3.8.0