Train transformer language models with reinforcement learning.

TRL - Transformer Reinforcement Learning

A comprehensive library to post-train foundation models

Overview

TRL is a library for post-training foundation models with techniques such as Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities and can be scaled up across various hardware setups.

Highlights

  • Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, GRPOTrainer, DPOTrainer, RewardTrainer and more.

  • Efficient and scalable:

    • Leverages 🤗 Accelerate to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    • Full integration with 🤗 PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA.
    • Integrates 🦥 Unsloth for accelerating training using optimized kernels.
  • Command Line Interface (CLI): A simple interface lets you fine-tune models without writing any code.
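
To make the LoRA point above concrete, here is a back-of-the-envelope sketch of how many parameters a rank-r adapter trains for a single weight matrix. It is plain arithmetic, not part of TRL or PEFT; the helper name `lora_param_counts` and the 4096 x 4096 layer size are illustrative assumptions.

```python
def lora_param_counts(d_out: int, d_in: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one linear layer.

    LoRA freezes the d_out x d_in weight W and trains two low-rank factors,
    B (d_out x rank) and A (rank x d_in), so that the update is B @ A.
    """
    full = d_out * d_in
    lora = rank * (d_out + d_in)
    return full, lora


# Hypothetical layer size, chosen only for illustration.
full, lora = lora_param_counts(d_out=4096, d_in=4096, rank=8)
print(f"full fine-tuning: {full:,} parameters")
print(f"LoRA (rank 8):    {lora:,} parameters ({lora / full:.2%} of full)")
```

The adapter grows linearly with the rank and with the layer dimensions, while the full matrix grows with their product, which is why the gap widens for larger models.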

Installation

Python Package

Install the library using pip:

pip install trl

From source

If you want to use the latest features before an official release, you can install TRL from source:

pip install git+https://github.com/huggingface/trl.git

Repository

If you want to use the examples, you can clone the repository with the following command:

git clone https://github.com/huggingface/trl.git

Quick Start

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
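
The example above passes the dataset to SFTTrainer as is. As a rough illustration of what one training example might look like, the sketch below assumes the conversational format, where each record holds a `messages` list of `role`/`content` dictionaries. The sample record, the set of allowed roles, and the `is_conversational` helper are hypothetical and only illustrate that shape.

```python
# Roles assumed for this sketch; real chat templates may allow others.
ALLOWED_ROLES = {"system", "user", "assistant"}


def is_conversational(record: dict) -> bool:
    """Check that a record looks like {"messages": [{"role": ..., "content": ...}, ...]}."""
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    return all(
        isinstance(m, dict)
        and m.get("role") in ALLOWED_ROLES
        and isinstance(m.get("content"), str)
        for m in messages
    )


# Hypothetical record, written by hand for illustration.
example = {
    "messages": [
        {"role": "user", "content": "What is TRL?"},
        {"role": "assistant", "content": "A library for post-training language models."},
    ]
}
print(is_conversational(example))
```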

GRPOTrainer

GRPOTrainer implements the Group Relative Policy Optimization (GRPO) algorithm, which is more memory-efficient than PPO and was used to train DeepSeek AI's R1.

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
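
The "group relative" part of GRPO can be illustrated without any training: several completions are sampled for the same prompt, each is scored by the reward function, and each reward is compared with the rest of its group. The sketch below repeats the dummy reward function from the example and normalizes rewards by the group mean and standard deviation. It is a simplified illustration, not TRL's implementation; `group_relative_advantages`, the use of the population standard deviation, the `eps` constant, and the sample completions are all assumptions.

```python
import statistics


# Same dummy reward as in the example: number of unique characters per completion.
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group: (r - mean) / (std + eps).

    Assumption: population standard deviation; eps avoids division by zero
    when every completion in the group receives the same reward.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical group of completions sampled for a single prompt.
group = ["aaaa", "abab", "abcd", "abcdefgh"]
rewards = reward_num_unique_chars(group)
advantages = group_relative_advantages(rewards)

for text, r, a in zip(group, rewards, advantages):
    print(f"{text!r}: reward={r}, advantage={a:+.2f}")
```

Completions that score above their group's average get a positive advantage and are reinforced; those below average get a negative one.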

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
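
For intuition about what DPOTrainer optimizes, the DPO objective for a single preference pair fits in a few lines. This sketch works on plain floats rather than model outputs; the log-probabilities and `beta=0.1` are made-up values, and `dpo_loss` is a hypothetical helper, not a TRL function.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one (chosen, rejected) pair.

    Each argument is the summed log-probability of a completion under
    the policy being trained or under the frozen reference model.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)), computed stably as softplus(-margin)
    return math.log1p(math.exp(-abs(margin))) + max(-margin, 0.0)


# Made-up log-probabilities for illustration.
before = dpo_loss(-12.0, -12.0, -12.0, -12.0)  # policy equals reference
after = dpo_loss(-10.0, -14.0, -12.0, -12.0)   # policy now prefers "chosen"
print(f"loss before: {before:.4f}, after: {after:.4f}")
```

The loss falls as the policy raises the chosen completion's likelihood relative to the reference and lowers the rejected one's, which is the behavior preference training is after.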

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
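
The model in this example has a single output (`num_labels=1`), that is, one scalar score per sequence. A reward model trained on chosen/rejected pairs is typically fit with a pairwise (Bradley-Terry style) objective, in which the chosen response should score higher than the rejected one. The sketch below assumes that objective and uses made-up scores; `pairwise_reward_loss` and `pairwise_accuracy` are hypothetical helpers, not part of TRL.

```python
import math


def pairwise_reward_loss(score_chosen: float, score_rejected: float) -> float:
    """-log(sigmoid(score_chosen - score_rejected)) for one pair."""
    diff = score_chosen - score_rejected
    # Stable form of log(1 + exp(-diff))
    return math.log1p(math.exp(-abs(diff))) + max(-diff, 0.0)


def pairwise_accuracy(pairs) -> float:
    """Fraction of pairs where the chosen response gets the higher score."""
    return sum(c > r for c, r in pairs) / len(pairs)


# Made-up (chosen, rejected) scores from a hypothetical reward model.
pairs = [(1.2, -0.3), (0.1, 0.4), (2.0, 1.5)]
losses = [pairwise_reward_loss(c, r) for c, r in pairs]
print(f"mean loss: {sum(losses) / len(losses):.4f}")
print(f"accuracy:  {pairwise_accuracy(pairs):.2f}")
```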

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):

SFT:

trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT

DPO:

trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO

Read more about the CLI in the relevant documentation section, or use --help for more details.
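
If you would rather launch the CLI from a Python script (for example, to loop over several models), the same command can be assembled programmatically. The sketch below only builds and prints the command from the flags shown in the SFT example; `build_trl_command` is a hypothetical helper, and the call that would actually start training is left commented out because it assumes `trl` is installed and on the PATH.

```python
import shlex
import subprocess


def build_trl_command(method: str, model: str, dataset: str, output_dir: str) -> list[str]:
    """Assemble a `trl <method>` invocation from the flags used in the examples."""
    return [
        "trl", method,
        "--model_name_or_path", model,
        "--dataset_name", dataset,
        "--output_dir", output_dir,
    ]


cmd = build_trl_command(
    method="sft",
    model="Qwen/Qwen2.5-0.5B",
    dataset="trl-lib/Capybara",
    output_dir="Qwen2.5-0.5B-SFT",
)
print(shlex.join(cmd))

# Uncomment to start training (assumes `trl` is installed and on PATH):
# subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting issues; `shlex.join` is used only to display the command.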

Development

If you want to contribute to TRL or customize it to your needs, read the contribution guide and make a development install:

git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e ".[dev]"

Citation

@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}

License

This repository's source code is available under the Apache-2.0 License.
