TextRL - reinforcement learning for text generation, built on HuggingFace TRL.

TextRL: Reinforcement Learning for Text Generation

TextRL is a thin, opinionated layer on top of HuggingFace TRL that makes modern text-generation RL ergonomic: one dataclass for configuration, one trainer class per algorithm family, callable reward functions, and first-class PEFT / accelerate / vLLM support.

v1.0 breaking change. The legacy PFRL/gym API (TextRLEnv, TextRLActor, train_agent_with_evaluation) is gone. See docs/migration.md.

Why TextRL vs. raw TRL?

TextRL is a thin wrapper, not a replacement. Use it when the ergonomics are worth more than the indirection; drop down to raw TRL when they aren't.

What TextRL adds on top of TRL

  • One config, one trainer per family. Pick an algorithm by string — algo="ipo" — and TextRLConfig dispatches to the right one of TRL's 15+ config classes. No need to remember that IPO lives inside DPOTrainer with loss_type="ipo", or that REINFORCE++ is RLOOTrainer with rloo_k=1.
  • load_model one-liner. PEFT + 4/8-bit quantization + reference model + tokenizer padding defaults, handled together.
  • Reward composition. @reward_fn decorator, compose(f1, f2, weights=...) for weighted sums, ClassifierReward to wrap any HuggingFace pipeline as a reward.
  • Schema validation up front. Dataset shape is checked before training starts, not 500 steps in.
  • YAML-driven CLI. textrl-train --config cfg.yaml — configuration-as-experiment, good for sweeps and reproducibility.
  • Migration hints for removed algos. Ask for ppo / orpo / simpo / cpo and you get a pointer to the modern replacement instead of a cryptic ImportError.
  • PEFT merge utility. textrl-merge produces a standalone HF checkpoint from a LoRA adapter.
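
The string-based dispatch in the first bullet can be pictured as a lookup table. The sketch below is not TextRL's implementation: `resolve_algo`, `_ALGO_TABLE`, `_REMOVED`, and the exact key spellings are hypothetical. Only the mappings stated in this README (IPO as DPOTrainer with loss_type="ipo", REINFORCE++ as RLOOTrainer with rloo_k=1, and the removed algorithms) are taken from the text; "sigmoid" is TRL's default DPO loss type.

```python
# Hypothetical sketch of string-based algorithm dispatch; not TextRL's code.
# Each entry maps an algo name to (TRL trainer class name, config overrides).
_ALGO_TABLE = {
    "grpo": ("GRPOTrainer", {}),
    "rloo": ("RLOOTrainer", {}),
    "reinforce++": ("RLOOTrainer", {"rloo_k": 1}),   # key spelling is assumed
    "dpo": ("DPOTrainer", {"loss_type": "sigmoid"}),
    "ipo": ("DPOTrainer", {"loss_type": "ipo"}),
    "kto": ("KTOTrainer", {}),
    "reward_model": ("RewardTrainer", {}),
}

# Algorithms the README lists as removed upstream (key spellings assumed).
_REMOVED = {"ppo", "orpo", "simpo", "cpo"}


def resolve_algo(algo: str) -> tuple[str, dict]:
    """Return the trainer name and config overrides for an algo string."""
    key = algo.lower()
    if key in _REMOVED:
        raise ValueError(
            f"algo={algo!r} was removed upstream; see docs/migration.md for a replacement"
        )
    if key not in _ALGO_TABLE:
        raise ValueError(f"unknown algo {algo!r}; expected one of {sorted(_ALGO_TABLE)}")
    trainer, overrides = _ALGO_TABLE[key]
    return trainer, dict(overrides)  # copy so callers cannot mutate the table
```

Failing on a removed name with a pointer to the migration guide, rather than with an import error, is the behavior the "Migration hints" bullet describes.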

Honest trade-offs

  • It's a thin layer. Brand-new TRL features land upstream first; TextRL tracks them.
  • Advanced customization means reaching through the .trl_trainer escape hatch anyway.
  • Smaller community, less battle-testing than TRL itself.

When to use TextRL

  • Comparing multiple algorithms with minimal boilerplate changes.
  • YAML-driven experiments, reward composition, or wrapping classifiers as rewards.
  • You want sensible PEFT/QLoRA/ref-model defaults without reading three docs pages.

When to use raw TRL

  • Single algorithm, heavy customization, or you need a feature that hasn't landed in TextRL yet.
  • You're already fluent in the TRL API and the wrapper would just be indirection.

Supported algorithms

| Family | Algorithms | TRL trainer |
|---|---|---|
| Online | GRPO, RLOO, REINFORCE++ | GRPOTrainer, RLOOTrainer |
| Preference (pairwise) | DPO, IPO, Hinge, APO (zero/down), BCO-pair, NCA-pair, Robust-DPO, AOT, DiscoPOP, SPPO-hard, EXO-pair | DPOTrainer (unified loss_type) |
| Preference (binary) | KTO | KTOTrainer |
| Reward model | Pairwise reward training | RewardTrainer |

Removed in TRL 0.29+ and therefore not supported: PPO, OnlineDPO, ORPO, CPO, SimPO, BCO (binary). TextRL raises with a migration hint if you ask for them.

Install

pip install textrl                         # core
pip install 'textrl[quant]'                # + bitsandbytes (QLoRA)
pip install 'textrl[vllm]'                 # + vLLM rollout
pip install 'textrl[quant,vllm,rewards]'   # kitchen sink

Quickstart

GRPO with a callable reward

from textrl import OnlineTrainer, TextRLConfig, load_model, reward_fn
from textrl.data import from_list

@reward_fn
def length_reward(prompts, completions, **_):
    return [-abs(len(c) - 64) / 64 for c in completions]

model, tok, _ = load_model("Qwen/Qwen2.5-0.5B", peft={"type": "lora", "r": 16})

cfg = TextRLConfig(
    algo="grpo",
    output_dir="out/grpo",
    num_generations=8,
    beta=0.04,
    learning_rate=5e-6,
    bf16=True,
)

trainer = OnlineTrainer(
    model=model,
    tokenizer=tok,
    reward=length_reward,
    train_dataset=from_list(["Write a short poem.", "Explain gradient descent."] * 32),
    config=cfg,
)
trainer.train()
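
For intuition about num_generations: GRPO samples that many completions per prompt and scores each one relative to its own group. The sketch below shows a group-normalized advantage in plain Python. It is independent of TextRL and TRL, `group_advantages` is a hypothetical helper name, and real implementations differ in details such as the epsilon and whether sample or population standard deviation is used.

```python
from statistics import mean, pstdev


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Center and scale the rewards of one prompt's completions.

    Completions that beat their group's mean get a positive advantage,
    the rest a negative one. eps avoids division by zero when every
    completion in the group received the same reward.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption of this sketch
    return [(r - mu) / (sigma + eps) for r in rewards]


# One prompt, four sampled completions, scored by some reward function.
advantages = group_advantages([1.0, 2.0, 3.0, 4.0])
```

A group in which all completions score the same yields all-zero advantages, so that prompt contributes no learning signal for that step.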

DPO with a preference dataset

from textrl import PreferenceTrainer, TextRLConfig, load_model
from textrl.data import from_hub

model, tok, ref = load_model("meta-llama/Llama-3.2-1B", peft={"type": "lora", "r": 16}, quantization="4bit")

cfg = TextRLConfig(algo="dpo", output_dir="out/dpo", beta=0.1, bf16=True)

trainer = PreferenceTrainer(
    model=model,
    ref_model=ref,
    tokenizer=tok,
    train_dataset=from_hub("trl-lib/ultrafeedback_binarized"),
    config=cfg,
)
trainer.train()
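
To make the role of beta concrete, here is the per-pair DPO loss written out in plain Python, next to the IPO variant that algo="ipo" selects. This is a sketch of the published loss formulas on summed sequence log-probabilities, not of TRL's batched implementation; the function names and arguments are illustrative.

```python
import math


def _margin(pi_chosen, pi_rejected, ref_chosen, ref_rejected):
    # How much more strongly the policy prefers "chosen" over "rejected"
    # than the reference model does (all inputs are log-probabilities).
    return (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)


def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log(sigmoid(beta * margin)), written as log(1 + exp(-beta * margin))."""
    h = _margin(pi_chosen, pi_rejected, ref_chosen, ref_rejected)
    return math.log1p(math.exp(-beta * h))


def ipo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """(margin - 1 / (2 * beta)) ** 2: regress the margin to a fixed target."""
    h = _margin(pi_chosen, pi_rejected, ref_chosen, ref_rejected)
    return (h - 1.0 / (2.0 * beta)) ** 2
```

With a policy identical to the reference the margin is zero and the DPO loss is log 2; a larger beta makes the loss react more sharply to the same margin.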

KTO with binary feedback

from textrl import PreferenceTrainer, TextRLConfig, load_model

cfg = TextRLConfig(algo="kto", output_dir="out/kto", beta=0.1, bf16=True)
model, tok, ref = load_model("Qwen/Qwen2.5-0.5B")
trainer = PreferenceTrainer(
    model=model, ref_model=ref, tokenizer=tok,
    train_dataset=my_kto_dataset,   # needs prompt/completion/label
    config=cfg,
)
trainer.train()

RLOO with a trained reward model

from textrl import OnlineTrainer, RewardModelTrainer, TextRLConfig, load_model

# Step 1: train the reward model.
# rm_ds is your own dataset with chosen/rejected columns (see "Data formats").
rm_cfg = TextRLConfig(algo="reward_model", output_dir="out/rm", bf16=True)
rm_model, tok, _ = load_model("distilbert/distilbert-base-uncased", load_ref=False)
RewardModelTrainer(model=rm_model, tokenizer=tok, train_dataset=rm_ds, config=rm_cfg).train()

# Step 2: use the trained reward model to score RLOO rollouts.
# prompts is your own prompt-only dataset (a prompt column).
model, tok, ref = load_model("Qwen/Qwen2.5-0.5B")
cfg = TextRLConfig(algo="rloo", output_dir="out/rloo", bf16=True)
OnlineTrainer(model=model, ref_model=ref, tokenizer=tok,
              reward=rm_model, train_dataset=prompts, config=cfg).train()
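
RLOO differs from GRPO mainly in its baseline: each completion is compared against the mean reward of the other completions sampled for the same prompt. A minimal plain-Python sketch of that leave-one-out baseline, with a hypothetical helper name and no dependency on TextRL or TRL:

```python
def rloo_advantages(rewards: list[float]) -> list[float]:
    """Leave-one-out advantages for the k completions of a single prompt."""
    k = len(rewards)
    if k < 2:
        # With a single sample there are no "other" completions to average.
        raise ValueError("leave-one-out needs at least two completions per prompt")
    total = sum(rewards)
    # Baseline for sample i is the mean of the remaining k - 1 rewards.
    return [r - (total - r) / (k - 1) for r in rewards]


advantages = rloo_advantages([1.0, 2.0, 3.0])
```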

Reward functions

Rewards are plain callables with the signature TRL expects:

def reward(prompts: list[str], completions: list[str], **columns) -> list[float]: ...

Decorate a function with @reward_fn (which coerces it into a RewardFn protocol object), or subclass BaseReward for stateful rewards such as a loaded classifier. Combine several rewards with compose(*fns, weights=...):

from textrl.rewards import compose, length_penalty, reward_fn

@reward_fn
def semantic_match(prompts, completions, **_):
    return [...]

reward = compose(semantic_match, length_penalty, weights=[1.0, 0.1])
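
Conceptually, a weighted composition evaluates every reward on the same batch and sums the results per completion. The sketch below shows that idea in plain Python. `weighted_sum`, `brevity`, and `ends_with_period` are hypothetical stand-ins, and nothing here claims to match how textrl.rewards.compose is implemented.

```python
from typing import Callable, Sequence

RewardCallable = Callable[..., list[float]]


def weighted_sum(*fns: RewardCallable, weights: Sequence[float]) -> RewardCallable:
    """Build a reward that returns sum_j weights[j] * fns[j](...) per completion."""
    if len(fns) != len(weights):
        raise ValueError("need exactly one weight per reward function")

    def combined(prompts, completions, **columns):
        # Score the whole batch once with each reward function.
        per_fn = [fn(prompts, completions, **columns) for fn in fns]
        return [
            sum(w * scores[i] for w, scores in zip(weights, per_fn))
            for i in range(len(completions))
        ]

    return combined


def brevity(prompts, completions, **_):
    return [-float(len(c)) for c in completions]


def ends_with_period(prompts, completions, **_):
    return [1.0 if c.endswith(".") else 0.0 for c in completions]


reward = weighted_sum(ends_with_period, brevity, weights=[1.0, 0.1])
scores = reward(["p1", "p2"], ["Done.", "no period"])
```

Because the rewards are summed on a common scale, the weights also have to absorb differences in magnitude between the individual reward functions.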

ClassifierReward wraps any HuggingFace pipeline:

from transformers import pipeline
from textrl.rewards import ClassifierReward

sentiment = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment")
reward = ClassifierReward(sentiment, target_label="LABEL_2")  # positive
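
The idea behind wrapping a classifier is to read off the score of one target label as the reward. The sketch below uses a stub in place of a real pipeline so that it runs without downloading a model. It assumes the classifier returns, for each text, a list of {"label", "score"} dicts covering all labels, which is what a transformers text-classification pipeline returns when called with top_k=None. `label_score_reward` and `stub_classifier` are hypothetical and are not ClassifierReward's actual code.

```python
def label_score_reward(classifier, target_label: str):
    """Turn a per-label classifier into a TRL-style reward callable."""

    def reward(prompts, completions, **_):
        outputs = classifier(list(completions))
        scores = []
        for per_text in outputs:
            by_label = {d["label"]: d["score"] for d in per_text}
            # A label missing from the output counts as zero reward.
            scores.append(float(by_label.get(target_label, 0.0)))
        return scores

    return reward


def stub_classifier(texts):
    # Stand-in for a real pipeline: "positive" if the text contains "good".
    out = []
    for t in texts:
        p = 0.9 if "good" in t else 0.2
        out.append([
            {"label": "LABEL_2", "score": p},
            {"label": "LABEL_0", "score": 1.0 - p},
        ])
    return out


sentiment_reward = label_score_reward(stub_classifier, target_label="LABEL_2")
```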

Data formats

| Mode | Required columns | Used by |
|---|---|---|
| Prompt-only | prompt (or messages) | GRPO, RLOO, REINFORCE++ |
| Pairwise preference | prompt, chosen, rejected | DPO, IPO, Hinge, APO, BCO-pair, etc. |
| Binary feedback | prompt, completion, label: bool | KTO |
| Reward model | chosen, rejected | RewardModelTrainer |

Use textrl.data.from_list, from_jsonl, or from_hub to construct datasets, or pass any datasets.Dataset directly.
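
The required columns listed above translate directly into a pre-flight check. Below is a sketch of such a check over plain dict rows. `REQUIRED_COLUMNS`, the mode names, and `validate_rows` are hypothetical, and TextRL's own validator may behave differently (for example in how it treats the messages alternative).

```python
# Each mode lists groups of alternatives; one column from every group must exist.
REQUIRED_COLUMNS = {
    "prompt_only": [("prompt", "messages")],
    "pairwise": [("prompt",), ("chosen",), ("rejected",)],
    "binary": [("prompt",), ("completion",), ("label",)],
    "reward_model": [("chosen",), ("rejected",)],
}


def validate_rows(rows, mode: str) -> None:
    """Raise ValueError on the first row that does not fit the mode's schema."""
    if mode not in REQUIRED_COLUMNS:
        raise ValueError(f"unknown mode {mode!r}")
    for idx, row in enumerate(rows):
        for alternatives in REQUIRED_COLUMNS[mode]:
            if not any(col in row for col in alternatives):
                wanted = " or ".join(alternatives)
                raise ValueError(f"row {idx}: missing column {wanted!r} for mode {mode!r}")
        if mode == "binary" and not isinstance(row["label"], bool):
            raise ValueError(
                f"row {idx}: label must be a bool, got {type(row['label']).__name__}"
            )


validate_rows(
    [{"prompt": "2+2?", "chosen": "4", "rejected": "5"}],
    mode="pairwise",
)
```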

Model loading

load_model returns (policy, tokenizer, ref_model_or_None):

from textrl import load_model

model, tok, ref = load_model(
    "meta-llama/Llama-3.2-1B",
    peft={"type": "lora", "r": 16, "alpha": 32, "target_modules": "all-linear"},
    quantization="4bit",          # nf4 QLoRA
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
    load_ref=True,                 # False for GRPO/RLOO to save memory
)

When peft is set, the returned ref_model is None: TRL runs the reference forward pass on the same model with the adapters disabled, so no separate reference copy is needed.

Distributed training

Launch via accelerate. TextRL adds no scaffolding of its own:

accelerate launch -m textrl.cli train --config configs/grpo.yaml

Set distributed={"strategy": "deepspeed", "zero_stage": 3} on TextRLConfig; it is forwarded to TRL via the extra field.

vLLM rollout (GRPO only)

cfg = TextRLConfig(
    algo="grpo", output_dir="out",
    extra={"use_vllm": True, "vllm_gpu_memory_utilization": 0.6},
)

Or use the helper textrl.rollout.vllm.vllm_config(...) to build the extras dict.

CLI

| Command | Purpose |
|---|---|
| textrl-train --config cfg.yaml | YAML-driven training |
| textrl-merge --adapter DIR --output DIR | Merge a PEFT adapter into a standalone HF checkpoint |
| textrl-eval --model PATH --dataset SPEC --reward module:fn | Rollout + reward stats (no training) |
| textrl-dump | Deprecated alias for textrl-merge |

Example YAML:

algo: grpo
output_dir: out/grpo
learning_rate: 5e-6
num_train_epochs: 1
num_generations: 8
beta: 0.04
bf16: true

model:
  name: Qwen/Qwen2.5-0.5B

dataset:
  hub: trl-lib/tldr
  split: train[:1%]

reward: my_rewards:length_reward
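
The reward: my_rewards:length_reward entry follows the common module:attribute convention. A loader for that convention could look like the sketch below. `resolve_callable` is a hypothetical name, and TextRL's CLI may resolve the string differently.

```python
import importlib


def resolve_callable(spec: str):
    """Import 'package.module:attr' and return the callable it names."""
    module_name, sep, attr = spec.partition(":")
    if not sep or not module_name or not attr:
        raise ValueError(f"expected 'module:function', got {spec!r}")
    obj = importlib.import_module(module_name)
    # Allow dotted attributes such as "module:Class.method".
    for part in attr.split("."):
        obj = getattr(obj, part)
    if not callable(obj):
        raise TypeError(f"{spec!r} does not refer to a callable")
    return obj


# Demonstrated on the standard library so the example runs anywhere.
sqrt = resolve_callable("math:sqrt")
```

For my_rewards:length_reward to resolve this way, my_rewards.py has to be importable, for example by living in the working directory or on PYTHONPATH.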

Development

pip install -e '.[dev,quant,rewards]'
PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 pytest tests/unit
pytest -m smoke tests/smoke   # needs a small model to be downloadable

License

Apache 2.0.
