Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.5%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: LiuWenlin595
  • License: apache-2.0
  • Language: Python
  • Default Branch: master
  • Size: 658 KB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 9 months ago · Last pushed 8 months ago
Metadata Files
Readme Contributing License Code of conduct Citation

README.md

TRL - Transformer Reinforcement Learning




A comprehensive library to post-train foundation models


Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities and can be scaled up across various hardware setups.

Highlights

  • Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, GRPOTrainer, DPOTrainer, RewardTrainer, and more.

  • Efficient and scalable:

    • Leverages 🤗 Accelerate to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    • Full integration with 🤗 PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA.
    • Integrates 🦥 Unsloth for accelerating training using optimized kernels.
  • Command Line Interface (CLI): A simple interface lets you fine-tune models without writing code.

Installation

Python Package

Install the library using pip:

```bash
pip install trl
```

From source

If you want to use the latest features before an official release, you can install TRL from source:

```bash
pip install git+https://github.com/huggingface/trl.git
```

Repository

If you want to use the examples, you can clone the repository with the following command:

```bash
git clone https://github.com/huggingface/trl.git
```

Quick Start

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

```python
from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
```
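SFTTrainer infers the dataset format automatically. As a rough sketch of what one training record looks like (the field names below assume TRL's conversational format, which trl-lib/Capybara follows; the content itself is made up):

```python
# An illustrative record in the conversational format: a "messages" list
# of role/content dicts, alternating between user and assistant turns.
example = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
}

# Every message carries a role and its text content.
roles = [m["role"] for m in example["messages"]]
print(roles)  # ['user', 'assistant']
```

During training, the tokenizer's chat template turns each record into a single token sequence.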

GRPOTrainer

GRPOTrainer implements the Group Relative Policy Optimization (GRPO) algorithm, which is more memory-efficient than PPO and was used to train DeepSeek's R1.

```python
from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
```
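The dummy reward function above is just a plain Python callable that maps a batch of completion strings to one score per completion, so its behavior can be checked in isolation (the sample completions below are made up):

```python
def reward_num_unique_chars(completions, **kwargs):
    # One score per completion: the number of distinct characters it contains.
    return [len(set(c)) for c in completions]

print(reward_num_unique_chars(["aaaa", "abcd", "abab"]))  # [1, 4, 2]
```

GRPO normalizes these raw scores within each group of sampled completions, so only their relative ordering matters.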

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```
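Under the hood, DPO optimizes a simple per-pair objective: the negative log-sigmoid of a scaled difference between the policy's and the reference model's log-probability ratios on the chosen versus rejected completion. A minimal numeric sketch (the log-probabilities below are made-up values, not real model outputs):

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_ratio - rejected_ratio)
    return -math.log(1.0 / (1.0 + math.exp(-logits)))

# When the policy still matches the reference, both ratios are zero
# and the loss sits at log(2).
print(round(dpo_loss(-10.0, -12.0, -10.0, -12.0), 4))  # 0.6931
```

Training pushes the chosen log-ratio above the rejected one, driving the loss toward zero without ever fitting an explicit reward model.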

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

```python
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```
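With num_labels=1 the model emits one scalar score per sequence, and reward-model training typically minimizes a pairwise Bradley–Terry loss over (chosen, rejected) score pairs. A toy numeric sketch (the scores below are invented, not model outputs):

```python
import math

def pairwise_reward_loss(chosen_score, rejected_score):
    # Bradley-Terry style objective: -log sigmoid(r_chosen - r_rejected)
    margin = chosen_score - rejected_score
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# Equal scores give the maximum-uncertainty loss of log(2); a positive
# margin in favor of the chosen response drives the loss toward zero.
print(round(pairwise_reward_loss(0.0, 0.0), 4))   # 0.6931
print(round(pairwise_reward_loss(2.0, -1.0), 4))  # 0.0486
```

The trained scalar scores can then serve as the reward signal for policy-optimization methods like PPO.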

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):

SFT:

```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT
```

DPO:

```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO
```

Read more about the CLI in the relevant documentation section, or run any command with --help for more details.

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and do a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
```

Citation

```bibtex
@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
```

License

This repository's source code is available under the Apache-2.0 License.

Owner

  • Name: 刘文林 (Liu Wenlin)
  • Login: LiuWenlin595
  • Kind: user

Citation (CITATION.cff)

cff-version: 1.2.0
title: 'TRL: Transformer Reinforcement Learning'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Leandro
    family-names: von Werra
  - given-names: Younes
    family-names: Belkada
  - given-names: Lewis
    family-names: Tunstall
  - given-names: Edward
    family-names: Beeching
  - given-names: Tristan
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
  - given-names: Shengyi
    family-names: Huang
  - given-names: Kashif
    family-names: Rasul
  - given-names: Quentin
    family-names: Gallouédec
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
  - rlhf
  - deep-learning
  - pytorch
  - transformers
license: Apache-2.0
version: 0.18

GitHub Events

Total
  • Push event: 2
  • Create event: 2
Last Year
  • Push event: 2
  • Create event: 2

Dependencies

docker/trl-latest-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
docker/trl-source-gpu/Dockerfile docker
  • continuumio/miniconda3 latest build
  • nvidia/cuda 12.2.2-devel-ubuntu22.04 build
examples/research_projects/stack_llama_2/scripts/requirements.txt pypi
  • accelerate *
  • bitsandbytes *
  • datasets *
  • peft *
  • transformers *
  • trl *
  • wandb *
pyproject.toml pypi
requirements.txt pypi
  • accelerate >=1.4.0
  • datasets >=3.0.0
  • transformers >=4.51.0
setup.py pypi