
Train transformer language models with reinforcement learning.


TRL - Transformer Reinforcement Learning

Full stack library to fine-tune and align large language models.


What is it?

The trl library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).

The library is built on top of the transformers library, so you can use any model architecture available there.

Highlights

  • Efficient and scalable:
    • accelerate is the backbone of trl, which lets you scale model training from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
    • PEFT is fully integrated, which lets you train even the largest models on modest hardware using quantisation and methods such as LoRA or QLoRA.
    • unsloth is also integrated and can significantly speed up training with dedicated kernels.
  • CLI: With the CLI you can fine-tune and chat with LLMs without writing any code, using a single command and a flexible config system.
  • Trainers: The Trainer classes are an abstraction for applying many fine-tuning methods with ease, such as the SFTTrainer, FPOTrainer, RewardTrainer, PPOTrainer, CPOTrainer, and ORPOTrainer.
  • AutoModels: The AutoModelForCausalLMWithValueHead and AutoModelForSeq2SeqLMWithValueHead classes add a value head to the model so that it can be trained with RL algorithms such as PPO.
  • Examples: The examples show how to train GPT2 to generate positive movie reviews with a BERT sentiment classifier, run full RLHF using adapters only, train GPT-J to be less toxic, reproduce the StackLlama example, and more.
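
To give a rough sense of why LoRA makes large models trainable on modest hardware, the sketch below compares the number of trainable weights for one linear layer under full fine-tuning and under a low-rank update. The layer size (4096 x 4096) and rank (8) are illustrative assumptions, and the helper names are hypothetical; they are not part of trl or PEFT.

```python
def full_params(d_out: int, d_in: int) -> int:
    """Trainable weights when the whole d_out x d_in matrix is updated."""
    return d_out * d_in


def lora_trainable_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable weights in the two low-rank factors:
    B has shape (d_out, rank) and A has shape (rank, d_in)."""
    return rank * (d_out + d_in)


# Illustrative (assumed) layer size and LoRA rank
d_out, d_in, rank = 4096, 4096, 8

full = full_params(d_out, d_in)
lora = lora_trainable_params(d_out, d_in, rank)

print(f"full fine-tuning: {full:,} trainable weights")
print(f"LoRA (rank {rank}): {lora:,} trainable weights ({lora / full:.2%} of full)")
```

Under these assumptions the low-rank factors hold well under 1% of the weights of the full matrix, which is where most of the memory saving comes from.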

Installation

Python package

Install the library with pip:

pip install trl

From source

If you want to use the latest features before an official release, you can install from source:

pip install git+https://github.com/huggingface/trl.git

Repository

If you want to use the examples, you can clone the repository with the following command:

git clone https://github.com/huggingface/trl.git

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), and to test your aligned model with the chat CLI:

SFT:

trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb

DPO:

trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf

Chat:

trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat

Read more about the CLI in the relevant documentation section, or use --help for more details.
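
If you launch these commands from a script, one option is to assemble the argument list in Python first. The sketch below is only an illustration: build_trl_command is a hypothetical helper, not part of trl, and it uses only the flags shown in the SFT command above. It prints the command rather than running it.

```python
import shlex


def build_trl_command(subcommand: str, options: dict) -> list:
    """Turn a dict of options into an argv list for the trl CLI.

    Hypothetical helper: each key becomes a --flag followed by its value.
    """
    argv = ["trl", subcommand]
    for name, value in options.items():
        argv += [f"--{name}", str(value)]
    return argv


sft_argv = build_trl_command(
    "sft",
    {
        "model_name_or_path": "facebook/opt-125m",
        "dataset_name": "imdb",
        "output_dir": "opt-sft-imdb",
    },
)

# Show the command as it would be typed in a shell
print(shlex.join(sft_argv))

# To actually launch it, pass the list to subprocess, for example:
# subprocess.run(sft_argv, check=True)
```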

How to use

For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.

SFTTrainer

This is a basic example of how to use the SFTTrainer from the library. The SFTTrainer is a light wrapper around the transformers Trainer to easily fine-tune language models or adapters on a custom dataset.

# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()
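
For intuition about what supervised fine-tuning optimizes, the sketch below computes the usual next-token objective for a causal language model: the average negative log-likelihood of the observed tokens. It is a plain-Python illustration with made-up probabilities, not the SFTTrainer's own code; next_token_nll is a hypothetical name.

```python
import math


def next_token_nll(token_probs):
    """Average negative log-likelihood of the observed next tokens.

    token_probs[i] is the probability the model assigned to the token
    that actually came next at position i.
    """
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


# Made-up probabilities for a three-token continuation
probs = [0.5, 0.25, 0.125]

loss = next_token_nll(probs)
perplexity = math.exp(loss)

print(f"loss = {loss:.4f}, perplexity = {perplexity:.2f}")
```

Training lowers this loss by raising the probability of each observed token in the dataset.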

RewardTrainer

This is a basic example of how to use the RewardTrainer from the library. The RewardTrainer is a wrapper around the transformers Trainer to easily fine-tune reward models or adapters on a custom preference dataset.

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
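
A preference dataset pairs a chosen response with a rejected one, and reward models are commonly trained with a pairwise (Bradley-Terry style) loss that pushes the chosen reward above the rejected reward. The sketch below shows that loss for a single pair in plain Python. Treat it as an assumption about the general technique rather than a description of this package's RewardTrainer internals; pairwise_reward_loss is a hypothetical name.

```python
import math


def pairwise_reward_loss(reward_chosen: float, reward_rejected: float) -> float:
    """-log(sigmoid(r_chosen - r_rejected)) for one preference pair."""
    margin = reward_chosen - reward_rejected
    # -log(sigmoid(m)) equals softplus(-m); this form avoids overflow
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# The loss is small when the chosen response scores higher...
print(pairwise_reward_loss(2.0, 0.0))
# ...and large when the ranking is the wrong way round
print(pairwise_reward_loss(0.0, 2.0))
```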

PPOTrainer

This is a basic example of how to use the PPOTrainer from the library. Given a query, the language model generates a response, which is then evaluated. The evaluation can come from a human in the loop or from another model's output.

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# configure PPO (a batch of one example for this demo)
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
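
The core of a PPO update is the clipped surrogate objective, which limits how far a single step can move the policy away from the one that generated the response. The sketch below evaluates that objective for one token in plain Python. It illustrates the general PPO formulation only: the function name and the clip_range default of 0.2 are assumptions, and the library's step additionally handles value estimates and a KL penalty against the reference model.

```python
import math


def ppo_clipped_loss(logprob_new, logprob_old, advantage, clip_range=0.2):
    """Clipped PPO policy loss for a single action (token)."""
    # Probability ratio between the updated policy and the sampling policy
    ratio = math.exp(logprob_new - logprob_old)
    clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
    # Taking the larger loss is the same as taking the smaller (pessimistic) objective
    return max(-advantage * ratio, -advantage * clipped_ratio)


# Unchanged policy: ratio is 1, so the loss is simply -advantage
print(ppo_clipped_loss(0.0, 0.0, advantage=1.0))

# The policy already moved far toward a good action: the gain is capped at 1 + clip_range
print(ppo_clipped_loss(1.0, 0.0, advantage=1.0))
```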

FPOTrainer

FPOTrainer is a trainer that uses the Direct Preference Optimization algorithm. This is a basic example of how to use the FPOTrainer from the library. The FPOTrainer is a wrapper around the transformers Trainer to easily fine-tune language models or adapters on a custom preference dataset.

# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import FPOTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = FPOTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
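
As a reminder of what Direct Preference Optimization optimizes, the sketch below computes the DPO loss from the paper cited in the references for a single preference pair, given summed log-probabilities of the chosen and rejected responses under the policy and a frozen reference model. It is a plain-Python illustration, not this trainer's implementation; dpo_loss and the beta default of 0.1 are assumptions.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """-log(sigmoid(beta * (policy log-ratio - reference log-ratio)))."""
    policy_logratio = policy_chosen_logp - policy_rejected_logp
    ref_logratio = ref_chosen_logp - ref_rejected_logp
    x = beta * (policy_logratio - ref_logratio)
    # -log(sigmoid(x)) written as softplus(-x) for numerical stability
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


# Policy identical to the reference: the loss is log(2)
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))

# Policy prefers the chosen response more strongly than the reference: lower loss
print(dpo_loss(-8.0, -14.0, -10.0, -12.0))
```

No separate reward model appears in the computation; the preference signal enters only through the log-probability ratios.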

Development

If you want to contribute to trl or customize it to your needs, read the contribution guide and make a dev install:

git clone https://github.com/huggingface/trl.git
cd trl/
make dev

References

Proximal Policy Optimisation

The PPO implementation largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [paper, code].

Direct Preference Optimization

DPO is based on the original implementation of "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" by E. Mitchell et al. [paper, code].

Citation

@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}

