Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.7%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: hardbrah
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Size: 7.38 MB
Statistics
  • Stars: 0
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 9 months ago · Last pushed 9 months ago
Metadata Files
  • Readme
  • Contributing
  • License
  • Code of conduct
  • Citation

README.md

TRL - Transformer Reinforcement Learning




A comprehensive library to post-train foundation models


Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.

Highlights

  • Efficient and scalable:

    • Leverages 🤗 Accelerate to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    • Full integration with PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA.
    • Integrates Unsloth for accelerating training using optimized kernels.
  • Command Line Interface (CLI): A simple interface lets you fine-tune and interact with models without needing to write code.

  • Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, DPOTrainer, RewardTrainer, ORPOTrainer and more.

  • AutoModels: Use pre-defined model classes like AutoModelForCausalLMWithValueHead to simplify reinforcement learning (RL) with LLMs.
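The low-rank adaptation (LoRA) approach mentioned above can be illustrated in plain Python. This is a hypothetical sketch of a single linear layer, not PEFT's actual implementation: the frozen pretrained weight `W` is augmented by a trainable low-rank update `B·A`, scaled by `alpha / r`, so the trainable parameter count drops from `d_out * d_in` to `r * (d_in + d_out)`.

```python
# Hypothetical LoRA forward pass for one linear layer, using plain
# Python lists (no torch): y = W x + (alpha / r) * B (A x).
# W is frozen; only the small matrices A (r x d_in) and B (d_out x r)
# are trained.

def matvec(M, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha=16, r=2):
    base = matvec(W, x)               # frozen pretrained projection
    delta = matvec(B, matvec(A, x))   # rank-r trainable update
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]
```

With a 2x2 identity `W`, rank-1 `A = [[1, 1]]`, `B = [[1], [1]]`, `x = [1, 2]` and `alpha=1, r=1`, the output is `[4, 5]`: the base `[1, 2]` plus the low-rank update `[3, 3]`.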

Installation

Python Package

Install the library using pip:

```bash
pip install trl
```

From source

If you want to use the latest features before an official release, you can install TRL from source:

```bash
pip install git+https://github.com/huggingface/trl.git
```

Repository

If you want to use the examples, you can clone the repository with the following command:

```bash
git clone https://github.com/huggingface/trl.git
```

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:

SFT:

```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT
```

DPO:

```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO
```

Chat:

```bash
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
```

Read more about the CLI in the relevant documentation section or use --help for more details.

How to use

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
```

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

```python
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```
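The objective behind reward-model training can be sketched independently of TRL. This is a hypothetical illustration of the pairwise (Bradley–Terry) loss commonly used for this setup, not TRL's implementation: the model is pushed to score the chosen response above the rejected one.

```python
import math

# Pairwise preference loss: -log sigmoid(score_chosen - score_rejected).
# The loss is log(2) when the scores tie, and shrinks toward zero as the
# chosen response is scored increasingly higher than the rejected one.

def pairwise_reward_loss(score_chosen, score_rejected):
    margin = score_chosen - score_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))
```

For tied scores the loss is `log(2) ≈ 0.693`; a model that scores the chosen response 3 points higher incurs a much smaller loss.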

RLOOTrainer

RLOOTrainer implements a REINFORCE-style optimization for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the RLOOTrainer:

```python
from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")

training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
    config=training_args,
    processing_class=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()
```
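The leave-one-out baseline that gives RLOO its name can be sketched in a few lines of plain Python. This is a hypothetical illustration of the idea, not TRL's implementation: for each of the k completions sampled per prompt, the baseline is the mean reward of the other k-1 samples, which centers the advantages without training a separate value network.

```python
# RLOO advantage estimate: for each sampled completion i, subtract the
# mean reward of the OTHER k-1 completions for the same prompt.

def rloo_advantages(rewards):
    """rewards: list of k scalar rewards for one prompt (k >= 2)."""
    k = len(rewards)
    total = sum(rewards)
    # Leave-one-out baseline for sample i: (total - r_i) / (k - 1)
    return [r - (total - r) / (k - 1) for r in rewards]

print(rloo_advantages([1.0, 0.0, 0.5, 0.5]))  # → [0.666..., -0.666..., 0.0, 0.0]
```

By construction the advantages for each prompt sum to zero, which is what makes this a variance-reducing baseline for the REINFORCE gradient.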

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
```
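The DPO objective itself is simple enough to sketch for a single preference pair. This is a hypothetical illustration, not TRL's implementation: `log_p_*` are the sequence log-probabilities under the policy being trained, `log_ref_*` under the frozen reference model, and `beta` controls how far the policy may drift from the reference.

```python
import math

# DPO loss for one (chosen, rejected) pair:
#   -log sigmoid(beta * [(log_p_c - log_ref_c) - (log_p_r - log_ref_r)])
# The bracketed term is the implicit reward margin: how much more the
# policy prefers the chosen response over the rejected one, relative
# to the reference model.

def dpo_loss(log_p_chosen, log_p_rejected,
             log_ref_chosen, log_ref_rejected, beta=0.1):
    margin = (log_p_chosen - log_ref_chosen) - (log_p_rejected - log_ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))
```

When policy and reference agree, the margin is zero and the loss is `log(2)`; as the policy learns to favor chosen responses, the margin grows and the loss falls toward zero.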

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and do a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```

Citation

```bibtex
@misc{vonwerra2022trl,
  author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title        = {TRL: Transformer Reinforcement Learning},
  year         = {2020},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
```

License

This repository's source code is available under the Apache-2.0 License.

Owner

  • Name: Chen Haotian
  • Login: hardbrah
  • Kind: user
  • Location: Shanghai, China
  • Company: Fudan University

Student of Fudan University

Citation (CITATION.cff)

cff-version: 1.2.0
title: 'TRL: Transformer Reinforcement Learning'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Leandro
    family-names: von Werra
  - given-names: Younes
    family-names: Belkada
  - given-names: Lewis
    family-names: Tunstall
  - given-names: Edward
    family-names: Beeching
  - given-names: Tristan
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
  - given-names: Shengyi
    family-names: Huang
  - given-names: Kashif
    family-names: Rasul
  - given-names: Quentin
    family-names: Gallouédec
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
  - rlhf
  - deep-learning
  - pytorch
  - transformers
license: Apache-2.0
version: 0.11.1

GitHub Events

Total
  • Push event: 4
Last Year
  • Push event: 4

Dependencies

.github/workflows/build_documentation.yml actions
.github/workflows/build_pr_documentation.yml actions
.github/workflows/clear_cache.yml actions
  • actions/checkout v4 composite
.github/workflows/docker-build.yml actions
  • actions/checkout v4 composite
  • docker/build-push-action v4 composite
  • docker/login-action v1 composite
  • docker/setup-buildx-action v1 composite
  • huggingface/hf-workflows/.github/actions/post-slack main composite
.github/workflows/slow-tests.yml actions
  • actions/checkout v4 composite
.github/workflows/stale.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
.github/workflows/tests-main.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
  • huggingface/hf-workflows/.github/actions/post-slack main composite
.github/workflows/tests.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
  • pre-commit/action v3.0.1 composite
.github/workflows/trufflehog.yml actions
  • actions/checkout v4 composite
  • trufflesecurity/trufflehog main composite
.github/workflows/upload_pr_documentation.yml actions
docker/trl-latest-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
docker/trl-source-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
examples/research_projects/stack_llama_2/scripts/requirements.txt pypi
  • accelerate *
  • bitsandbytes *
  • datasets *
  • peft *
  • transformers *
  • trl *
  • wandb *
pyproject.toml pypi
requirements.txt pypi
  • accelerate *
  • datasets >=1.17.0
  • peft >=0.3.0
  • torch >=1.4.0
  • tqdm *
  • transformers >=4.40.0
  • tyro >=0.5.7
setup.py pypi