High-performance, reliable Transformer Reinforcement Learning library
# FastTRL 🚀

A complete rewrite of TRL: faster and more reliable.

## Features
| Feature | Description |
|---|---|
| ⚡ 30%+ faster | Sequence packing, fused optimizer, mixed precision |
| 🛡️ Zero-tolerance error handling | Every input is validated, with clear error messages |
| 🎯 5 trainers | SFT, PPO, DPO, Reward, GRPO |
| 🔌 Familiar API | Drop-in compatible with TRL |
| 📦 PEFT support | LoRA and QLoRA fully integrated |
| 🤗 HuggingFace compatible | Works with any transformers model |
## Installation
```bash
pip install -e .

# With PEFT + bitsandbytes:
pip install -e ".[all]"
```
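A quick sanity check after installing (whether the package exposes `__version__` is an assumption; a bare import works regardless):

```python
# Verify that fasttrl imports correctly after installation.
import fasttrl

print(getattr(fasttrl, "__version__", "installed"))  # falls back if __version__ is absent
```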
## Quick Start
### 1. SFT (Supervised Fine-Tuning)
```python
from fasttrl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")
dataset = load_dataset("trl-lib/ultrachat_200k", split="train_sft")

config = SFTConfig(
    output_dir="./sft_output",
    model_name="meta-llama/Llama-3-8B",
    max_seq_length=2048,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    packing=True,           # sequence packing → faster training
    bf16=True,              # bfloat16 mixed precision
    neftune_noise_alpha=5,  # NEFTune noise augmentation
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
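The DPO example below loads the fine-tuned model from `./sft_output/final_model`. Since the API is advertised as drop-in compatible with TRL, saving presumably follows the `transformers` `Trainer` convention (an assumption, not something this README documents):

```python
# Assumption: SFTTrainer inherits transformers.Trainer's save_model(),
# as TRL's trainers do. Writes model weights and config to the given path.
trainer.save_model("./sft_output/final_model")
tokenizer.save_pretrained("./sft_output/final_model")
```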
### 2. DPO (Direct Preference Optimization)
```python
from fasttrl import DPOTrainer, DPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("./sft_output/final_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_output/final_model")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train_prefs")

config = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,                       # KL coefficient (smaller = weaker constraint)
    loss_type="sigmoid",            # original DPO loss
    max_length=1024,
    max_prompt_length=512,
    learning_rate=5e-7,
    bf16=True,
    precompute_ref_log_probs=True,  # precompute reference log-probs for speed
)

trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
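`trl-lib/ultrafeedback_binarized` supplies preference pairs: each row carries a prompt plus a preferred and a dispreferred response. A sketch of one such row (column names reflect that dataset's schema, not a FastTRL requirement; values are illustrative):

```python
# One row of a typical DPO preference dataset.
example = {
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",     # preferred response
    "rejected": "France does not have a capital.",   # dispreferred response
}
```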
### 3. PPO (RLHF)
```python
import torch

from fasttrl import PPOTrainer, PPOConfig
from fasttrl.models import AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer, pipeline

# Policy model (with a value head)
model = AutoModelForCausalLMWithValueHead.from_pretrained("./sft_output/final_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_output/final_model")

# Reward model
reward_pipeline = pipeline(
    "text-classification",
    model="./reward_output/final_model",
    device=0,
)

def reward_fn(texts):
    results = reward_pipeline(texts)
    return [r["score"] for r in results]

config = PPOConfig(
    output_dir="./ppo_output",
    batch_size=64,
    mini_batch_size=16,
    ppo_epochs=4,
    learning_rate=1e-5,
    init_kl_coef=0.2,
    target_kl=6.0,
    adap_kl_ctrl=True,
)

trainer = PPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
)

# Training loop
for batch in your_dataloader:
    # 1) Generate responses
    responses = trainer.generate(batch["input_ids"])
    # 2) Compute rewards
    texts = [tokenizer.decode(r) for r in responses]
    scores = [torch.tensor(s) for s in reward_fn(texts)]
    # 3) PPO step
    stats = trainer.step(batch["input_ids"], responses, scores)
    print(stats)
```
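`your_dataloader` is a placeholder. One way to build it, assuming the loop above only needs tokenized prompts under an `input_ids` key (the prompts and collation below are illustrative, not part of FastTRL's API):

```python
from torch.utils.data import DataLoader

prompts = ["Explain PPO in one sentence.", "Write a haiku about gradients."]

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers often lack a pad token

def collate(batch_prompts):
    tokenizer.padding_side = "left"  # left-pad so generation continues from the prompt end
    enc = tokenizer(batch_prompts, padding=True, return_tensors="pt")
    return {"input_ids": enc["input_ids"]}

your_dataloader = DataLoader(prompts, batch_size=2, collate_fn=collate)
```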
### 4. GRPO (Group Relative Policy Optimization)
```python
from fasttrl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Rule-based reward function (e.g., mathematical correctness)
def math_reward(completions, **kwargs):
    rewards = []
    for c in completions:
        # Toy reward: +1 if the completion contains a digit
        rewards.append(1.0 if any(ch.isdigit() for ch in c) else 0.0)
    return rewards

config = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,  # 8 responses per prompt
    max_new_tokens=256,
    beta=0.04,
    epsilon=0.2,
    learning_rate=1e-6,
)

trainer = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    reward_funcs=[math_reward],
    train_dataset=math_dataset,  # your prompt dataset
)
trainer.train()
```
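`reward_funcs` takes a list, so several reward signals can be supplied at once; a sketch adding a hypothetical length penalty alongside `math_reward` (how FastTRL aggregates the per-function scores, e.g. by summing, is an assumption):

```python
# Hypothetical second reward: discourage overly long completions.
def length_reward(completions, **kwargs):
    return [-0.001 * len(c) for c in completions]

trainer = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    reward_funcs=[math_reward, length_reward],  # multiple reward signals
    train_dataset=math_dataset,
)
```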
### 5. Reward Model Training
```python
from fasttrl import RewardTrainer, RewardConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train_prefs")

config = RewardConfig(
    output_dir="./reward_output",
    max_length=512,
    learning_rate=1e-5,
    margin=0.5,  # minimum gap between chosen and rejected scores
)

trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
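The `margin` setting presumably enters the standard pairwise ranking loss, which pushes the chosen reward above the rejected reward by at least the margin. A minimal sketch of that loss (an assumption about FastTRL's internals, not its documented API):

```python
import torch
import torch.nn.functional as F

def pairwise_margin_loss(chosen_rewards, rejected_rewards, margin=0.5):
    # -log sigmoid(r_chosen - r_rejected - margin): near zero once the
    # chosen score beats the rejected score by at least `margin`.
    return -F.logsigmoid(chosen_rewards - rejected_rewards - margin).mean()

loss = pairwise_margin_loss(torch.tensor([1.2]), torch.tensor([0.3]))
```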
## Supported DPO Loss Functions
| Type | Paper |
|---|---|
| `sigmoid` | Original DPO (Rafailov et al., 2023) |
| `ipo` | IPO (Azar et al., 2023) |
| `hinge` | Hinge DPO |
| `robust` | Robust DPO (Chowdhury et al., 2024) |
| `kto_pair` | KTO (Ethayarajh et al., 2024) |
| `bco_pair` | BCO |
| `apo_zero` | APO (Zhu et al., 2024) |
| `apo_down` | APO (Zhu et al., 2024) |
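Switching variants is just a config change, since `loss_type` is a `DPOConfig` field shown in the DPO example above (the `ipo` choice here is illustrative):

```python
# Same DPO setup as above, but with the IPO loss variant.
config = DPOConfig(
    output_dir="./dpo_output",
    loss_type="ipo",  # any value from the table above
    beta=0.1,
)
```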
## Speed Comparison
| Method | TRL | FastTRL | Speedup |
|---|---|---|---|
| Sequence packing | Optional | Default | +30% |
| Fused AdamW | ✗ | ✓ | +18% |
| Ref logprob cache | ✗ | ✓ | +40% (DPO) |
| Grouped generation | ✗ | ✓ | +25% (GRPO) |
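Two of these optimizations surface directly in the configs shown earlier. Packing is on by default, so opting out is the explicit choice; a sketch using only fields from the examples above:

```python
# Opt out of sequence packing for an SFT run:
sft_config = SFTConfig(output_dir="./sft_output", packing=False)

# Enable the reference log-prob cache for DPO:
dpo_config = DPOConfig(output_dir="./dpo_output", precompute_ref_log_probs=True)
```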
## Saving/Loading Configurations
```python
# Save
config.save("./my_config.json")

# Load
config = SFTConfig.load("./my_config.json")
```
## License
MIT License