Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.9%) to scientific vocabulary
Last synced: 10 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: zheng-zf
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Size: 550 KB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created over 1 year ago · Last pushed over 1 year ago
Metadata Files
Readme Contributing License Code of conduct Citation

README.md

TRL - Transformer Reinforcement Learning

Full stack library to fine-tune and align large language models.


What is it?

The trl library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).

The library is built on top of the transformers library, so you can use any model architecture available there.

Highlights

  • Efficient and scalable:
    • accelerate is the backbone of trl and lets you scale model training from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
    • PEFT is fully integrated, so you can train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
    • unsloth is also integrated and can significantly speed up training with dedicated kernels.
  • CLI: With the CLI you can fine-tune and chat with LLMs without writing any code, using a single command and a flexible config system.
  • Trainers: The Trainer classes are an abstraction that makes it easy to apply many fine-tuning methods, such as the SFTTrainer, DPOTrainer, RewardTrainer, PPOTrainer, CPOTrainer, and ORPOTrainer.
  • AutoModels: The AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead classes add an additional value head to the model, which allows it to be trained with RL algorithms such as PPO.
  • Examples: Following the examples, you can train GPT2 to generate positive movie reviews with a BERT sentiment classifier, run full RLHF using adapters only, train GPT-J to be less toxic, reproduce the StackLlama example, and more.
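To make the value-head idea concrete: a value head can be thought of as a linear projection from each token's final hidden state to a single scalar estimate of expected return, which PPO uses as a baseline. The plain-Python sketch below assumes exactly that (one weight per hidden dimension plus a bias). The function name value_head and all numbers are hypothetical, and the library classes implement the head as a PyTorch layer on top of the transformer rather than like this.

```python
from typing import List


def value_head(hidden_states: List[List[float]], weights: List[float], bias: float) -> List[float]:
    """Hypothetical sketch: project each token's hidden state (size d) to one scalar value.

    This is a plain linear map d -> 1, applied independently at every token position.
    """
    values = []
    for h in hidden_states:
        if len(h) != len(weights):
            raise ValueError("hidden size and weight size must match")
        # dot product plus bias gives one value estimate per token position
        values.append(sum(x * w for x, w in zip(h, weights)) + bias)
    return values


# Toy example: 3 token positions, hidden size 2 (illustrative numbers only)
hidden = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
print(value_head(hidden, weights=[2.0, -1.0], bias=0.5))  # → [2.5, 1.0, -1.5]
```

Because the head produces one value per token, PPO can compare per-token rewards against these estimates when computing advantages.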

Installation

Python package

Install the library with pip:

```bash
pip install trl
```

From source

If you want to use the latest features before an official release you can install from source:

```bash
pip install git+https://github.com/huggingface/trl.git
```

Repository

If you want to use the examples you can clone the repository with the following command:

```bash
git clone https://github.com/huggingface/trl.git
```

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), and to test your aligned model with the chat CLI:

SFT:

```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
```

DPO:

```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf
```

Chat:

```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```

Read more about the CLI in the relevant documentation section, or use --help for more details.

How to use

For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.

SFTTrainer

This is a basic example of how to use the SFTTrainer from the library. The SFTTrainer is a light wrapper around the transformers Trainer to easily fine-tune language models or adapters on a custom dataset.

```python
# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("stanfordnlp/imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()
```
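Conceptually, supervised fine-tuning of a causal language model trains it to predict each next token of the text, and the loss is the average next-token cross-entropy. The plain-Python sketch below shows, on made-up token IDs, how a tokenized example can be truncated to a maximum sequence length and paired with its next-token targets, plus the average negative log-likelihood over the correct tokens' probabilities. The helper names make_causal_lm_pair and mean_nll are hypothetical. In practice, Hugging Face causal LM models receive labels equal to the input IDs and apply the one-position shift internally, so this is only an illustration of the objective.

```python
import math
from typing import List, Tuple


def make_causal_lm_pair(token_ids: List[int], max_seq_length: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: truncate an example and build (inputs, next-token targets)."""
    ids = token_ids[:max_seq_length]  # drop tokens beyond the length limit
    return ids[:-1], ids[1:]          # the token at position t is trained to predict token t + 1


def mean_nll(target_probs: List[float]) -> float:
    """Average negative log-likelihood of the correct next tokens (cross-entropy loss)."""
    if not target_probs:
        raise ValueError("need at least one target probability")
    return -sum(math.log(p) for p in target_probs) / len(target_probs)


inputs, targets = make_causal_lm_pair([101, 7, 42, 9, 13, 102], max_seq_length=5)
print(inputs)   # → [101, 7, 42, 9]
print(targets)  # → [7, 42, 9, 13]
```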

RewardTrainer

This is a basic example of how to use the RewardTrainer from the library. The RewardTrainer is a wrapper around the transformers Trainer to easily fine-tune reward models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
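A preference dataset pairs a preferred ("chosen") response with a dispreferred ("rejected") one, and a reward model is commonly trained with a pairwise logistic (Bradley-Terry style) loss, -log sigmoid(r_chosen - r_rejected), which pushes the chosen score above the rejected one. The plain-Python sketch below assumes that objective and works on made-up scalar scores. The function names are hypothetical, and this is not the RewardTrainer source.

```python
import math
from typing import List


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_reward_loss(chosen: List[float], rejected: List[float]) -> float:
    """Mean of -log sigmoid(r_chosen - r_rejected) over a batch of preference pairs."""
    if not chosen or len(chosen) != len(rejected):
        raise ValueError("need equally sized, non-empty batches")
    losses = [-log_sigmoid(c - r) for c, r in zip(chosen, rejected)]
    return sum(losses) / len(losses)


# Made-up reward-model scores for two preference pairs
print(pairwise_reward_loss(chosen=[2.0, 0.5], rejected=[0.0, 1.0]))
```

When both responses get the same score the loss is log 2, and it shrinks toward 0 as the margin in favour of the chosen response grows.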

PPOTrainer

This is a basic example of how to use the PPOTrainer from the library. Based on a query, the language model creates a response, which is then evaluated. The evaluation could come from a human in the loop or from another model's output.

```python
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# initialize trainer
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
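Two ingredients of PPO-based RLHF help explain why the example builds a reference model. First, following Ziegler et al., the score for a response is combined with a penalty on how far the policy's log-probabilities drift from the reference model's. Second, PPO updates the policy with a clipped surrogate objective. The plain-Python sketch below shows both on made-up per-token numbers. The function names, the kl_coef and clip_eps values, and the simple log-ratio KL estimate are illustrative assumptions rather than the PPOTrainer internals, and a full implementation also needs advantage estimation and a value loss.

```python
import math
from typing import List


def kl_penalized_rewards(score: float, logprobs: List[float],
                         ref_logprobs: List[float], kl_coef: float) -> List[float]:
    """Per-token rewards: a KL penalty at every token plus the scalar score at the end."""
    if not logprobs or len(logprobs) != len(ref_logprobs):
        raise ValueError("need equally sized, non-empty log-prob lists")
    # log pi(a|s) - log pi_ref(a|s) is a simple per-token estimate of the KL divergence
    rewards = [-kl_coef * (lp - ref_lp) for lp, ref_lp in zip(logprobs, ref_logprobs)]
    rewards[-1] += score  # the external score (human or model) is credited to the final token
    return rewards


def ppo_clipped_loss(logprobs: List[float], old_logprobs: List[float],
                     advantages: List[float], clip_eps: float = 0.2) -> float:
    """Negated PPO clipped surrogate objective, averaged over tokens (lower is better)."""
    if not logprobs or not (len(logprobs) == len(old_logprobs) == len(advantages)):
        raise ValueError("need equally sized, non-empty inputs")
    total = 0.0
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # keep the pessimistic (smaller) of the unclipped and clipped objectives
        total += -min(ratio * adv, clipped_ratio * adv)
    return total / len(logprobs)


# Made-up numbers for a two-token response that received a score of 1.0
rewards = kl_penalized_rewards(1.0, logprobs=[-1.0, -2.0], ref_logprobs=[-1.0, -2.5], kl_coef=0.2)
loss = ppo_clipped_loss(logprobs=[-0.9, -2.1], old_logprobs=[-1.0, -2.0], advantages=[0.5, 0.5])
print(rewards, loss)
```

Clipping limits how much a single update can move the probability ratio away from 1, which keeps each PPO step close to the policy that generated the responses.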

DPOTrainer

DPOTrainer is a trainer that uses the Direct Preference Optimization algorithm. This is a basic example of how to use the DPOTrainer from the library. The DPOTrainer is a wrapper around the transformers Trainer to easily fine-tune language models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
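DPO replaces the separate reward model with a loss computed directly from the policy's and a frozen reference model's log-probabilities of the chosen (y_w) and rejected (y_l) responses: -log sigmoid(beta * [(log pi(y_w|x) - log pi_ref(y_w|x)) - (log pi(y_l|x) - log pi_ref(y_l|x))]). The plain-Python sketch below evaluates this sigmoid form of the loss for a single pair of summed sequence log-probabilities. The function name, the beta=0.1 default, and the numbers are illustrative assumptions, not the DPOTrainer source.

```python
import math


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """Sigmoid DPO loss for one preference pair, from summed sequence log-probs."""
    chosen_logratio = policy_chosen - ref_chosen        # beta * this is the implicit reward of y_w
    rejected_logratio = policy_rejected - ref_rejected  # beta * this is the implicit reward of y_l
    z = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(z)), written in a numerically stable way for both signs of z
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


# Policy identical to the reference: loss is log 2
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
# Policy favours the chosen response more than the reference does: loss drops below log 2
print(dpo_loss(-8.0, -14.0, -10.0, -12.0))
```

Because the reference log-probabilities enter only as a baseline, the loss rewards the policy for widening the chosen-versus-rejected gap relative to the reference, with beta controlling how sharply.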

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and do a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```

References

Proximal Policy Optimization

The PPO implementation largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [paper, code].

Direct Preference Optimization

DPO is based on the original implementation of "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" by E. Mitchell et al. [paper, code].

Citation

```bibtex
@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
```

trl

Owner

  • Login: zheng-zf
  • Kind: user

Citation (CITATION.cff)

cff-version: 1.2.0
title: 'TRL: Transformer Reinforcement Learning'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Leandro
    family-names: von Werra
  - given-names: Younes
    family-names: Belkada
  - given-names: Lewis
    family-names: Tunstall
  - given-names: Edward
    family-names: Beeching
  - given-names: Tristan
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
  - rlhf
  - deep-learning
  - pytorch
  - transformers
license: Apache-2.0
version: 0.2.1

GitHub Events

Total
  • Push event: 2
  • Create event: 1
Last Year
  • Push event: 2
  • Create event: 1

Dependencies

.github/workflows/build_documentation.yml actions
.github/workflows/build_pr_documentation.yml actions
.github/workflows/clear_cache.yml actions
  • actions/checkout v4 composite
.github/workflows/docker-build.yml actions
  • actions/checkout v4 composite
  • docker/build-push-action v4 composite
  • docker/login-action v1 composite
  • docker/setup-buildx-action v1 composite
  • huggingface/hf-workflows/.github/actions/post-slack main composite
.github/workflows/slow-tests.yml actions
  • actions/checkout v4 composite
.github/workflows/stale.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
.github/workflows/tests-main.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
  • huggingface/hf-workflows/.github/actions/post-slack main composite
.github/workflows/tests.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
  • pre-commit/action v3.0.1 composite
.github/workflows/trufflehog.yml actions
  • actions/checkout v4 composite
  • trufflesecurity/trufflehog main composite
.github/workflows/upload_pr_documentation.yml actions
docker/trl-latest-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
docker/trl-source-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
examples/research_projects/stack_llama_2/scripts/requirements.txt pypi
  • accelerate *
  • bitsandbytes *
  • datasets *
  • peft *
  • transformers *
  • trl *
  • wandb *
pyproject.toml pypi
requirements.txt pypi
  • accelerate *
  • datasets >=1.17.0
  • peft >=0.3.0
  • torch >=1.4.0
  • tqdm *
  • transformers >=4.40.0
  • tyro >=0.5.7
setup.py pypi