
A high-performance, reliable Transformer Reinforcement Learning library

Project description

FastTRL 🚀

A fully rewritten, faster, and more reliable version of TRL.


Features

| Feature | Description |
|---|---|
| 30%+ Speed | Sequence packing, fused optimizer, mixed precision |
| 🛡️ Zero Error Tolerance | All inputs are validated, with clear error messages |
| 🎯 5 Trainers | SFT, PPO, DPO, Reward, GRPO |
| 🔌 Easy API | Drop-in compatible with TRL |
| 📦 PEFT Support | LoRA and QLoRA fully integrated |
| 🤗 HuggingFace Compatible | All transformers models work |
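The "zero error tolerance" row refers to fail-fast input validation. As a minimal illustration of the idea (the helper below is hypothetical, not FastTRL's actual internals), a config check might reject bad values up front with an actionable message:

```python
def check_positive(name, value):
    # Fail fast with a clear message instead of a cryptic downstream crash.
    # Illustrative only; this is not FastTRL's real validation code.
    if isinstance(value, bool) or not isinstance(value, (int, float)) or value <= 0:
        raise ValueError(f"`{name}` must be a positive number, got {value!r}")
    return value

check_positive("learning_rate", 2e-5)  # valid: returns the value unchanged
```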

Installation

pip install -e .
# With PEFT + bitsandbytes:
pip install -e ".[all]"

Quick Start

1. SFT (Supervised Fine-Tuning)

from fasttrl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")
dataset = load_dataset("trl-lib/ultrachat_200k", split="train_sft")

config = SFTConfig(
    output_dir="./sft_output",
    model_name="meta-llama/Llama-3-8B",
    max_seq_length=2048,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    packing=True,           # Sequence packing → faster training
    bf16=True,              # bfloat16 mixed precision
    neftune_noise_alpha=5,  # NEFTune noise augmentation
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()

2. DPO (Direct Preference Optimization)

from fasttrl import DPOTrainer, DPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("./sft_output/final_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_output/final_model")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train_prefs")

config = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,               # KL coefficient (smaller = weaker constraint)
    loss_type="sigmoid",    # Original DPO
    max_length=1024,
    max_prompt_length=512,
    learning_rate=5e-7,
    bf16=True,
    precompute_ref_log_probs=True,  # Precompute for speed
)

trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()

3. PPO (RLHF)

from fasttrl import PPOTrainer, PPOConfig
from fasttrl.models import AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer, pipeline
import torch

# Policy model (with value head)
model = AutoModelForCausalLMWithValueHead.from_pretrained("./sft_output/final_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_output/final_model")

# Reward model
reward_pipeline = pipeline(
    "text-classification",
    model="./reward_output/final_model",
    device=0,
)

def reward_fn(texts):
    results = reward_pipeline(texts)
    return [r["score"] for r in results]

config = PPOConfig(
    output_dir="./ppo_output",
    batch_size=64,
    mini_batch_size=16,
    ppo_epochs=4,
    learning_rate=1e-5,
    init_kl_coef=0.2,
    target_kl=6.0,
    adap_kl_ctrl=True,
)

trainer = PPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
)

# Training loop
for batch in your_dataloader:
    # 1) Generate responses
    responses = trainer.generate(batch["input_ids"])

    # 2) Compute rewards
    texts = [tokenizer.decode(r) for r in responses]
    scores = [torch.tensor(s) for s in reward_fn(texts)]

    # 3) PPO step
    stats = trainer.step(batch["input_ids"], responses, scores)
    print(stats)
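The `init_kl_coef`, `target_kl`, and `adap_kl_ctrl` settings follow the standard adaptive KL penalty scheme from Ziegler et al. (2019), also used by TRL. A minimal sketch of that controller (an assumption about how FastTRL implements it, not its actual class):

```python
class AdaptiveKLController:
    """Proportional controller: raise the KL penalty coefficient when the
    measured KL exceeds the target, lower it when KL falls below the target."""

    def __init__(self, init_kl_coef=0.2, target_kl=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Clip the proportional error to keep updates stable.
        err = max(-0.2, min(0.2, current_kl / self.target_kl - 1.0))
        self.kl_coef *= 1.0 + err * n_steps / self.horizon
        return self.kl_coef
```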

4. GRPO (Group Relative Policy Optimization)

from fasttrl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Rule-based reward function (e.g., mathematical correctness)
def math_reward(completions, **kwargs):
    rewards = []
    for c in completions:
        # Simple reward: +1 if the completion contains a digit
        rewards.append(1.0 if any(ch.isdigit() for ch in c) else 0.0)
    return rewards

config = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,      # 8 completions per prompt
    max_new_tokens=256,
    beta=0.04,
    epsilon=0.2,
    learning_rate=1e-6,
)

trainer = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    reward_funcs=[math_reward],
    train_dataset=math_dataset,  # your own dataset of math prompts
)
trainer.train()
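GRPO's core idea: instead of a learned value baseline, each completion's reward is normalized against the other `num_generations` completions of the same prompt. A minimal sketch of that group-relative advantage (a hypothetical helper, not FastTRL's API):

```python
def group_relative_advantages(rewards, eps=1e-8):
    # rewards: scalar rewards for the G completions sampled from one prompt
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    # Standardize within the group; eps guards against a zero-variance group.
    return [(r - mean) / (std + eps) for r in rewards]
```

With a binary reward like `math_reward`, completions that beat their group's average get positive advantages and the rest get negative ones, which is what drives the policy update.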

5. Reward Model Training

from fasttrl import RewardTrainer, RewardConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train_prefs")

config = RewardConfig(
    output_dir="./reward_output",
    max_length=512,
    learning_rate=1e-5,
    margin=0.5,  # Minimum gap between chosen and rejected scores
)

trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()
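The `margin` parameter suggests a margin-augmented pairwise ranking loss: the chosen completion's score must beat the rejected one's by at least the margin before the loss flattens out. A scalar sketch of that idea (an assumption about the exact formula FastTRL uses):

```python
import math

def pairwise_reward_loss(chosen_score, rejected_score, margin=0.5):
    # -log sigmoid(chosen - rejected - margin): approaches zero once the
    # chosen score exceeds the rejected score by more than the margin.
    return math.log1p(math.exp(-(chosen_score - rejected_score - margin)))
```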

Supported DPO Loss Functions

| Type | Paper |
|---|---|
| sigmoid | Original DPO (Rafailov et al., 2023) |
| ipo | IPO (Azar et al., 2023) |
| hinge | Hinge DPO |
| robust | Robust DPO (Chowdhury et al., 2024) |
| kto_pair | KTO (Ethayarajh et al., 2024) |
| bco_pair | BCO |
| apo_zero | APO (Zhu et al., 2024) |
| apo_down | APO (Zhu et al., 2024) |
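For reference, the `sigmoid` row is the original DPO objective, -log σ(β·[(log π(y_c|x) − log π_ref(y_c|x)) − (log π(y_r|x) − log π_ref(y_r|x))]). A scalar sketch with per-sequence log-probabilities assumed precomputed (illustrative, not the trainer's internal code):

```python
import math

def dpo_sigmoid_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    # Sequence log-probs under the policy (pi_*) and the frozen reference (ref_*).
    logits = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    # -log sigmoid(logits), written stably as log(1 + e^(-logits))
    return math.log1p(math.exp(-logits))
```

When the policy matches the reference exactly, the loss is log 2; it drops below that as the policy starts preferring the chosen response more than the reference does.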

Speed Comparison

| Method | TRL | FastTRL | Speedup |
|---|---|---|---|
| Sequence packing | Optional | Default | +30% |
| Fused AdamW | | | +18% |
| Ref logprob cache | | in DPO | +40% |
| Grouped generation | | in GRPO | +25% |
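Sequence packing concatenates several short tokenized examples into a single `max_seq_length` buffer so padding tokens don't waste compute. A greedy sketch of the idea (illustrative only, not FastTRL's implementation):

```python
def pack_sequences(token_lists, max_seq_length):
    # Greedily concatenate tokenized examples into bins holding at most
    # max_seq_length tokens; overlong examples are truncated first.
    packs, current = [], []
    for tokens in token_lists:
        tokens = tokens[:max_seq_length]
        if len(current) + len(tokens) > max_seq_length:
            packs.append(current)
            current = []
        current = current + tokens
    if current:
        packs.append(current)
    return packs
```

A real packing implementation would also build per-example position ids and an attention mask so tokens cannot attend across example boundaries within a pack.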

Saving/Loading Configuration

# Save
config.save("./my_config.json")

# Load
config = SFTConfig.load("./my_config.json")

License

MIT License



Download files

Download the file for your platform.

Source Distribution

fasttrl-1.0.0.tar.gz (39.1 kB view details)

Uploaded Source

Built Distribution


fasttrl-1.0.0-py3-none-any.whl (45.4 kB view details)

Uploaded Python 3

File details

Details for the file fasttrl-1.0.0.tar.gz.

File metadata

  • Download URL: fasttrl-1.0.0.tar.gz
  • Upload date:
  • Size: 39.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for fasttrl-1.0.0.tar.gz
Algorithm Hash digest
SHA256 7ed7a08ccc831bf1d0fa2457c9def8efdde7327190e4dae82a92d17ce1c4228f
MD5 43f1ff89445825c3c9fd30c7044b43b2
BLAKE2b-256 6a76144f79df45a92f4fbe679d4b60c1cbba0acb6c1f2bb86ce57b457cfa0df8


File details

Details for the file fasttrl-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: fasttrl-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 45.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for fasttrl-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6d74e4e37db9aa80c76b6c013c51f0b3a1393f9f047561e38ac657e0da3aebd8
MD5 c10ad4467d9bacf36d7eadfb601d7416
BLAKE2b-256 955b662c881098254ab3fee46e9786322df304914fc0b911dd61dede2da07073

