Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (14.1%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: jurgen-paul
  • License: apache-2.0
  • Language: Python
  • Default Branch: add_spin
  • Size: 41 MB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 8 months ago · Last pushed 8 months ago
Metadata Files
Readme Contributing License Citation Security

README.md

TRL - Transformer Reinforcement Learning

Full stack transformer language models with reinforcement learning.


What is it?

trl is a full stack library that provides a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning (SFT) step and the Reward Modeling (RM) step to the Proximal Policy Optimization (PPO) step. The library is built on top of the transformers library by 🤗 Hugging Face, so pre-trained language models can be loaded directly via transformers. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the examples/ folder for example code snippets and how to run these tools.

Highlights:

  • SFTTrainer: A light and friendly wrapper around transformers Trainer to easily fine-tune language models or adapters on a custom dataset.
  • RewardTrainer: A light wrapper around transformers Trainer to easily fine-tune language models for human preferences (Reward Modeling).
  • PPOTrainer: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
  • AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
  • Examples: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-J to be less toxic, Stack-Llama example, etc.

How PPO works

Fine-tuning a language model via PPO consists of roughly three steps:

  1. Rollout: The language model generates a response or continuation based on a query, which could be the start of a sentence.
  2. Evaluation: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
  3. Optimization: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.

This process is illustrated in the sketch below:

Figure: Sketch of the workflow.
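
The KL term from the optimization step can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not TRL's actual implementation: the function name `kl_penalized_rewards`, the fixed coefficient `beta`, and the choice to add the scalar score to the final token are all assumptions made for the example. The per-token difference in log-probabilities serves as a simple estimate of the KL-divergence between the active and the reference model.

```python
from typing import List


def kl_penalized_rewards(
    logprobs: List[float],
    ref_logprobs: List[float],
    score: float,
    beta: float = 0.2,
) -> List[float]:
    """Combine a scalar score with a per-token KL penalty.

    logprobs / ref_logprobs: log-probabilities of the response tokens under
    the active model and the reference model (hypothetical inputs).
    score: scalar from the evaluation step for the whole query/response pair.
    beta: assumed fixed KL coefficient.
    """
    if not logprobs or len(logprobs) != len(ref_logprobs):
        raise ValueError("log-probability lists must be non-empty and equal in length")

    # Penalize each token by how much more likely the active model finds it
    # than the reference model does.
    rewards = [-beta * (lp - ref_lp) for lp, ref_lp in zip(logprobs, ref_logprobs)]

    # The scalar score is credited to the last token of the response.
    rewards[-1] += score
    return rewards


if __name__ == "__main__":
    print(kl_penalized_rewards([-1.0, -2.0], [-1.5, -2.0], score=1.0, beta=0.1))
```

A larger `beta` keeps the fine-tuned model closer to the reference model, at the cost of a weaker push toward high scores.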

Installation

Python package

Install the library with pip:

    pip install trl

From source

If you want to run the examples in the repository, a few additional libraries are required. Clone the repository and install it with pip:

    git clone https://github.com/huggingface/trl.git
    cd trl/
    pip install .

If you wish to develop TRL, you should install in editable mode:

    pip install -e .
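
After either install route, a quick way to confirm that the package is visible to the current interpreter is to query its installed version. The snippet below is a small sketch that uses only the standard library (`importlib.metadata`, Python 3.8+); the helper name `installed_version` is hypothetical and not part of TRL.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("trl")
    print(f"trl version: {found}" if found else "trl is not installed")
```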

How to use

SFTTrainer

This is a basic example of how to use the SFTTrainer from the library. The SFTTrainer is a light wrapper around the transformers Trainer to easily fine-tune language models or adapters on a custom dataset.

```python
# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()
```

RewardTrainer

This is a basic example of how to use the RewardTrainer from the library. The RewardTrainer is a wrapper around the transformers Trainer to easily fine-tune reward models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
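
Reward models trained on preference data are commonly fit with a pairwise ranking loss: the model should score the preferred ("chosen") response higher than the "rejected" one. The sketch below shows that loss in plain Python. It illustrates the general technique, assuming each example is a (chosen, rejected) pair of scalar scores; `pairwise_preference_loss` is a hypothetical helper and not part of TRL's API.

```python
import math
from typing import Sequence


def pairwise_preference_loss(
    chosen_scores: Sequence[float],
    rejected_scores: Sequence[float],
) -> float:
    """Mean of -log(sigmoid(chosen - rejected)) over all pairs."""
    if not chosen_scores or len(chosen_scores) != len(rejected_scores):
        raise ValueError("score lists must be non-empty and equal in length")

    total = 0.0
    for chosen, rejected in zip(chosen_scores, rejected_scores):
        margin = chosen - rejected
        # -log(sigmoid(margin)) equals softplus(-margin); this form avoids
        # overflow for large positive or negative margins.
        total += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return total / len(chosen_scores)


if __name__ == "__main__":
    # The loss shrinks as the chosen response is scored further above the rejected one.
    print(pairwise_preference_loss([2.0, 0.5], [0.0, 1.0]))
```

The loss depends only on the score difference, so the reward model is free to pick any offset for its scalar output.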

PPOTrainer

This is a basic example of how to use the PPOTrainer from the library. Based on a query, the language model creates a response, which is then evaluated. The evaluation could be a human in the loop or another model's output.

```python
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
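
The reward can come from any function that maps a response (or a query/response pair) to a scalar. As a toy illustration, the sketch below scores a response by counting sentiment keywords. Everything in it is invented for the example: `keyword_reward` and its word lists are hypothetical and not part of TRL, and a real setup would more likely use a trained classifier or human feedback.

```python
from typing import Iterable


def keyword_reward(
    response: str,
    positive: Iterable[str] = ("great", "good", "love"),
    negative: Iterable[str] = ("bad", "awful", "hate"),
) -> float:
    """Return (#positive words - #negative words) as a float reward."""
    positive_set, negative_set = set(positive), set(negative)

    # Lowercase each word and strip surrounding punctuation before matching.
    words = [word.strip(".,!?").lower() for word in response.split()]

    score = sum(w in positive_set for w in words) - sum(w in negative_set for w in words)
    return float(score)


if __name__ == "__main__":
    responses = ["I love this great movie!", "Awful, bad plot."]
    print([keyword_reward(r) for r in responses])
```

Each float would then be wrapped with `torch.tensor(...)` before being passed to the trainer, mirroring `reward = [torch.tensor(1.0)]` in the example above.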

References

Proximal Policy Optimisation

The PPO implementation largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [paper, code].

Language models

The language models utilize the transformers library by 🤗 Hugging Face.

Citation

    @misc{vonwerra2022trl,
      author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
      title = {TRL: Transformer Reinforcement Learning},
      year = {2020},
      publisher = {GitHub},
      journal = {GitHub repository},
      howpublished = {\url{https://github.com/huggingface/trl}}
    }

Owner

  • Login: jurgen-paul
  • Kind: user

Citation (CITATION.cff)

cff-version: 1.2.0
title: 'TRL: Transformer Reinforcement Learning'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Leandro
    family-names: von Werra
  - given-names: Younes
    family-names: Belkada
  - given-names: Lewis
    family-names: Tunstall
  - given-names: Edward
    family-names: Beeching
  - given-names: Tristan
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
  - rlhf
  - deep-learning
  - pytorch
  - transformers
license: Apache-2.0
version: 0.2.1

GitHub Events

Total
  • Push event: 1
  • Create event: 20
Last Year
  • Push event: 1
  • Create event: 20

Dependencies

.github/workflows/benchmark.yml actions
  • actions/checkout v3 composite
  • actions/github-script v6 composite
  • actions/setup-node v3 composite
  • actions/setup-python v4 composite
  • myrotvorets/set-commit-status-action master composite
  • xt0rted/pull-request-comment-branch v1 composite
.github/workflows/build_documentation.yml actions
.github/workflows/build_pr_documentation.yml actions
.github/workflows/clear_cache.yml actions
  • actions/checkout v3 composite
.github/workflows/docker-build.yml actions
  • actions/checkout v3 composite
  • docker/build-push-action v4 composite
  • docker/login-action v1 composite
  • docker/setup-buildx-action v1 composite
.github/workflows/python-package-conda.yml actions
.github/workflows/slow-tests.yml actions
  • actions/checkout v3 composite
.github/workflows/stale.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
.github/workflows/tests.yml actions
  • actions/checkout v2 composite
  • actions/checkout v3 composite
  • actions/setup-python v2 composite
  • actions/setup-python v4 composite
  • pre-commit/action v2.0.3 composite
.github/workflows/upload_pr_documentation.yml actions
docker/trl-latest-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
docker/trl-source-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
examples/research_projects/stack_llama_2/scripts/requirements.txt pypi
  • accelerate *
  • bitsandbytes *
  • datasets *
  • peft *
  • transformers *
  • trl *
  • wandb *
pyproject.toml pypi
requirements.txt pypi
  • accelerate *
  • datasets >=1.17.0
  • peft >=0.3.0
  • torch >=1.4.0
  • tqdm *
  • transformers *
  • tyro >=0.5.7
setup.py pypi