Train transformer language models with reinforcement learning.

Project description

TRL - Transformer Reinforcement Learning

A comprehensive library to post-train foundation models


Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.

Highlights

  • Efficient and scalable:

    • Leverages 🤗 Accelerate to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    • Full integration with PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA.
    • Integrates Unsloth for accelerating training using optimized kernels.
  • Command Line Interface (CLI): A simple interface lets you fine-tune and interact with models without needing to write code.

  • Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, DPOTrainer, RewardTrainer, ORPOTrainer and more.

  • AutoModels: Use pre-defined model classes like AutoModelForCausalLMWithValueHead to simplify reinforcement learning (RL) with LLMs.
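For example, the value-head model class mentioned above can be loaded in a single line. This is a minimal sketch rather than an excerpt from the README; the checkpoint name is only a placeholder:

from trl import AutoModelForCausalLMWithValueHead

# Wraps a causal LM and adds a scalar value head, as used by PPO-style trainers
model = AutoModelForCausalLMWithValueHead.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")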

Installation

Python Package

Install the library using pip:

pip install trl

From source

If you want to use the latest features before an official release, you can install TRL from source:

pip install git+https://github.com/huggingface/trl.git

Repository

If you want to use the examples, you can clone the repository with the following command:

git clone https://github.com/huggingface/trl.git

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:

SFT:

trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT

DPO:

trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO 

Chat:

trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct

Read more about the CLI in the relevant documentation section, or use --help for more details.

How to use

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
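
The same trainer can also train a PEFT adapter instead of the full model by passing a peft_config. The snippet below is a minimal sketch, assuming peft is installed; the LoRA hyperparameters are illustrative only:

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT-LoRA")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    # Train a LoRA adapter instead of updating all model weights
    peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
)
trainer.train()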

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
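
After training, the reward model scores a conversation like any sequence-classification model. The snippet below is an illustrative sketch rather than part of the original example; it assumes the model and tokenizer were saved to the directory shown, e.g. via trainer.save_model("Qwen2.5-0.5B-Reward"):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Reward")
reward_model = AutoModelForSequenceClassification.from_pretrained("Qwen2.5-0.5B-Reward")

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    score = reward_model(**inputs).logits[0, 0].item()  # scalar reward for this conversation
print(score)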

RLOOTrainer

RLOOTrainer implements REINFORCE Leave-One-Out (RLOO), a REINFORCE-style optimization for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the RLOOTrainer:

from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")

training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
    config=training_args,
    processing_class=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()
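
For intuition, the "leave-one-out" part of RLOO means that when k completions are sampled per prompt, each completion's baseline is the mean reward of the other k - 1 completions. The snippet below is illustrative only and is not TRL's internal implementation:

import torch

rewards = torch.tensor([0.2, 0.9, 0.5, 0.4])  # rewards for k = 4 sampled completions of one prompt
k = rewards.numel()
baseline = (rewards.sum() - rewards) / (k - 1)  # leave-one-out mean reward for each sample
advantages = rewards - baseline                 # REINFORCE advantage used to weight the log-probs
print(advantages)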

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
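
After training, the result is a regular Transformers checkpoint. Continuing the example above, the following sketch (not part of the original example) saves the model and generates with it; the prompt and directory name are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer

trainer.save_model("Qwen2.5-0.5B-DPO")  # writes the trained model (and typically the tokenizer) to this directory

model = AutoModelForCausalLM.from_pretrained("Qwen2.5-0.5B-DPO")
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-DPO")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about reinforcement learning."}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))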

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and do a dev install:

git clone https://github.com/huggingface/trl.git
cd trl/
make dev

Citation

@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}

License

This repository's source code is available under the Apache-2.0 License.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

trl-0.12.2.tar.gz (488.8 kB, source)

Built Distribution

trl-0.12.2-py3-none-any.whl (365.7 kB, Python 3 wheel)

File details

Details for the file trl-0.12.2.tar.gz.

File metadata

  • Download URL: trl-0.12.2.tar.gz
  • Upload date:
  • Size: 488.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.11

File hashes

Hashes for trl-0.12.2.tar.gz

  • SHA256: bfe8c7d16c8bc0a82a81cc92439286a7857c1a55bb51d1f5b2f210e1897fc9fa
  • MD5: c4f942c7a7afba481761ba7c6aa86b3f
  • BLAKE2b-256: b6591bffafb8c7dc1c9191a2db90372cf7769112119222296b4ea070c4fc1766

File details

Details for the file trl-0.12.2-py3-none-any.whl.

File metadata

  • Download URL: trl-0.12.2-py3-none-any.whl
  • Upload date:
  • Size: 365.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.11

File hashes

Hashes for trl-0.12.2-py3-none-any.whl

  • SHA256: 2673838f199cc8b13f58f5e2f4bfdc95aa959e0e14a8c827dce71275954afaee
  • MD5: 22dca1109b8d6d0780e97cfddd4d46c6
  • BLAKE2b-256: 13e563dcbaa86d1336aeda0a831bbe657aaadfe59f7a906aa5eb8279be4b285e

