Skip to main content

Yüksek performanslı, güvenilir Transformer Reinforcement Learning kütüphanesi

Project description

FastTRL 🚀

TRL'nin tamamen yeniden yazılmış, daha hızlı ve güvenilir versiyonu.


Özellikler

Özellik Açıklama
%30+ Hız Sequence packing, fused optimizer, mixed precision
🛡️ Sıfır Hata Toleransı Tüm girişler doğrulanır, anlaşılır hata mesajları
🎯 5 Trainer SFT, PPO, DPO, Reward, GRPO
🔌 Kolay API TRL ile birebir uyumlu
📦 PEFT Desteği LoRA, QLoRA tam entegre
🤗 HuggingFace Uyumlu Tüm transformers modelleri çalışır

Kurulum

pip install -e .
# PEFT + bitsandbytes ile:
pip install -e ".[all]"

Hızlı Başlangıç

1. SFT (Supervised Fine-Tuning)

from fasttrl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")
dataset = load_dataset("trl-lib/ultrachat_200k", split="train_sft")

config = SFTConfig(
    output_dir="./sft_output",
    model_name="meta-llama/Llama-3-8B",
    max_seq_length=2048,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    packing=True,           # Sequence packing → daha hızlı eğitim
    bf16=True,              # bfloat16 mixed precision
    neftune_noise_alpha=5,  # NEFTune gürültü artırımı
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()

2. DPO (Direct Preference Optimization)

from fasttrl import DPOTrainer, DPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("./sft_output/final_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_output/final_model")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train_prefs")

config = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,               # KL katsayısı (küçük = daha az kısıtlama)
    loss_type="sigmoid",    # Orijinal DPO
    max_length=1024,
    max_prompt_length=512,
    learning_rate=5e-7,
    bf16=True,
    precompute_ref_log_probs=True,  # Hız için önhesaplama
)

trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()

3. PPO (RLHF)

from fasttrl import PPOTrainer, PPOConfig
from fasttrl.models import AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer, pipeline

# Policy model (value head ile)
model = AutoModelForCausalLMWithValueHead.from_pretrained("./sft_output/final_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_output/final_model")

# Ödül modeli
reward_pipeline = pipeline(
    "text-classification",
    model="./reward_output/final_model",
    device=0,
)

def reward_fn(texts):
    results = reward_pipeline(texts)
    return [r["score"] for r in results]

config = PPOConfig(
    output_dir="./ppo_output",
    batch_size=64,
    mini_batch_size=16,
    ppo_epochs=4,
    learning_rate=1e-5,
    init_kl_coef=0.2,
    target_kl=6.0,
    adap_kl_ctrl=True,
)

trainer = PPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
)

# Eğitim döngüsü
for batch in your_dataloader:
    # 1) Yanıt üret
    responses = trainer.generate(batch["input_ids"])

    # 2) Ödülleri hesapla
    texts = [tokenizer.decode(r) for r in responses]
    scores = [torch.tensor(s) for s in reward_fn(texts)]

    # 3) PPO adımı
    stats = trainer.step(batch["input_ids"], responses, scores)
    print(stats)

4. GRPO (Group Relative Policy Optimization)

from fasttrl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Kural tabanlı ödül fonksiyonu (örn: matematiksel doğruluk)
def math_reward(completions, **kwargs):
    rewards = []
    for c in completions:
        # Basit ödül: sayı içeriyorsa +1
        rewards.append(1.0 if any(ch.isdigit() for ch in c) else 0.0)
    return rewards

config = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,      # Her prompt için 8 yanıt
    max_new_tokens=256,
    beta=0.04,
    epsilon=0.2,
    learning_rate=1e-6,
)

trainer = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    reward_funcs=[math_reward],
    train_dataset=math_dataset,
)
trainer.train()

5. Reward Model Eğitimi

from fasttrl import RewardTrainer, RewardConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train_prefs")

config = RewardConfig(
    output_dir="./reward_output",
    max_length=512,
    learning_rate=1e-5,
    margin=0.5,  # Chosen-rejected arasında minimum fark
)

trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()

Desteklenen DPO Kayıp Fonksiyonları

Tip Kağıt
sigmoid Orijinal DPO (Rafailov et al., 2023)
ipo IPO (Azar et al., 2023)
hinge Hinge DPO
robust Robust DPO (Chowdhury et al., 2024)
kto_pair KTO (Ethayarajh et al., 2024)
bco_pair BCO
apo_zero APO (Zhu et al., 2024)
apo_down APO (Zhu et al., 2024)

Hız Karşılaştırması

Yöntem TRL FastTRL Hız Artışı
Sequence Packing Opsiyonel Varsayılan +30%
Fused AdamW +18%
Ref logprob cache DPO'da +40%
Grouped generation GRPO'da +25%

Konfigürasyon Kaydetme/Yükleme

# Kaydet
config.save("./my_config.json")

# Yükle
config = SFTConfig.load("./my_config.json")

Lisans

MIT License

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

turklr-1.0.0.tar.gz (39.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

turklr-1.0.0-py3-none-any.whl (45.4 kB view details)

Uploaded Python 3

File details

Details for the file turklr-1.0.0.tar.gz.

File metadata

  • Download URL: turklr-1.0.0.tar.gz
  • Upload date:
  • Size: 39.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for turklr-1.0.0.tar.gz
Algorithm Hash digest
SHA256 8bfe74b381c2b2642b1ec2ddf185a9fabbf26cebe1c910c438fcfae2ed99cc61
MD5 42499cbe7333bdd6425a81ecede1cf2b
BLAKE2b-256 835ec9d85dd861cef1d8ec0eb5a1cb0816613ec1bfd8ee60e90d5210c24fc569

See more details on using hashes here.

File details

Details for the file turklr-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: turklr-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 45.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for turklr-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 59ef0b4a6e10c166e4150deaba4e4befbce9b256e38c8286523aa6e2f4cae66f
MD5 8e397a841047c7ae852df446f7794934
BLAKE2b-256 896116b9bd70e8ad3ae2f38b100e37156cb5cb0fc6e3e3b22fd8bc624731e432

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page