PyTorch-based trainer for Agent trajectory datasets — SFT, DPO, GRPO
knowlyr-trainer
Pure-PyTorch trainer for Agent trajectories — SFT / DPO / GRPO, with seamless support for datasets exported by knowlyr-hub.
Agent training enhancements: multi-turn dialogue format, observation token masking, step-level reward weighting, long-trajectory chunking, and curriculum learning.
Installation
pip install knowlyr-trainer
# Optional extras
pip install knowlyr-trainer[peft]  # LoRA fine-tuning
pip install knowlyr-trainer[wandb] # wandb logging
pip install knowlyr-trainer[all]   # everything
Quick start
CLI
# SFT training
knowlyr-trainer sft --train-file sft.jsonl --model Qwen/Qwen2.5-Coder-7B
# DPO preference learning
knowlyr-trainer dpo --train-file dpo.jsonl --model ./output/sft/final --beta 0.1
# GRPO group-relative policy optimization
knowlyr-trainer grpo --train-file grpo.jsonl --model ./output/sft/final
# Model evaluation
knowlyr-trainer eval --model ./output/sft/final --eval-file eval.jsonl
YAML configuration
knowlyr-trainer sft --config train_config.yaml
# train_config.yaml
model_name_or_path: Qwen/Qwen2.5-Coder-7B
train_file: sft_data.jsonl
output_dir: ./output/sft
num_epochs: 3
batch_size: 4
learning_rate: 2e-5
max_length: 4096
bf16: true
use_lora: true
agent_format: true        # enable Agent multi-turn format
mask_observations: true   # mask observation tokens
step_weighted_loss: true  # step-level reward weighting
curriculum: true          # curriculum learning
Python API
from agenttrainer import SFTConfig
from agenttrainer.trainers.sft import SFTTrainer
config = SFTConfig(
model_name_or_path="Qwen/Qwen2.5-Coder-7B",
train_file="sft_data.jsonl",
output_dir="./output",
agent_format=True,
mask_observations=True,
)
trainer = SFTTrainer(config)
trainer.train()
Data format
Works directly with JSONL files produced by knowlyr-hub export:
knowlyr-hub export --format sft -t trajectories.jsonl -o sft_train.jsonl
knowlyr-hub export --format dpo -t trajectories.jsonl -p preferences.jsonl -o dpo_train.jsonl
knowlyr-hub export --format grpo -t trajectories.jsonl -o grpo_train.jsonl
Agent-enhanced data format
With agent_format=True enabled, structured step data is supported:
{
"instruction": "Fix the off-by-one bug in sort function",
"input": "{\"repo\": \"owner/repo\"}",
"steps": [
{"thought": "Read the file", "action": "read_file /sort.py", "observation": "def sort(arr): ...", "reward": 0.7},
{"thought": "Fix the bug", "action": "edit_file /sort.py", "observation": "File edited", "reward": 0.9}
],
"task_id": "task-001",
"reward": 0.85
}
Plain-text response fields are also supported (the Step N: / Thought: / Action: / Observation: layout is parsed automatically).
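As an illustration, a plain-text response in that layout could be parsed into structured steps roughly like this (the helper name and regexes are assumptions, not the library's actual parser):

```python
import re


def parse_steps(response: str) -> list[dict]:
    """Split a plain-text response into structured steps.

    Assumes the "Step N:" / "Thought:" / "Action:" / "Observation:"
    layout described above; field names mirror the JSON step format.
    """
    steps = []
    # Split on "Step N:" markers; drop any preamble before the first step.
    for block in re.split(r"Step \d+:", response)[1:]:
        step = {}
        for field in ("thought", "action", "observation"):
            m = re.search(rf"{field.capitalize()}:\s*(.*)", block)
            if m:
                step[field] = m.group(1).strip()
        steps.append(step)
    return steps
```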
Agent training enhancements
On top of standard SFT/DPO/GRPO, six enhancements for long-horizon Agent tasks:
1. Multi-turn dialogue format (agent_format)
Converts trajectories from plain text into a structured multi-turn dialogue:
user: Fix the bug in sort.py            ← task description (excluded from loss)
assistant: Thought: Read the file       ← model output (included in loss ✓)
           Action: read_file /sort.py
user: Observation: def sort(arr)...     ← environment feedback (excluded from loss)
assistant: Thought: Fix the comparison  ← model output (included in loss ✓)
           Action: edit_file /sort.py
user: Observation: File edited          ← environment feedback (excluded from loss)
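The conversion above can be sketched as follows (role names follow the common chat convention and are an assumption about the internal format):

```python
def steps_to_messages(instruction: str, steps: list[dict]) -> list[dict]:
    """Turn structured steps into a multi-turn chat.

    The task goes into the first user turn; each thought + action pair
    becomes an assistant turn (trained on), and each observation a user
    turn (later masked out of the loss).
    """
    messages = [{"role": "user", "content": instruction}]
    for step in steps:
        messages.append({
            "role": "assistant",
            "content": f"Thought: {step['thought']}\nAction: {step['action']}",
        })
        if step.get("observation"):
            messages.append({
                "role": "user",
                "content": f"Observation: {step['observation']}",
            })
    return messages
```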
2. Observation masking (mask_observations)
Loss is computed only on the model-generated thought + action tokens; environment-returned observation tokens get labels=-100. This keeps the model from learning to "predict environment behavior" and focuses it on "learning to decide".
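A minimal sketch of the labels=-100 trick, assuming turns are tokenized one at a time (the `tokenize` callable and function name are illustrative):

```python
IGNORE_INDEX = -100  # ignored by PyTorch's CrossEntropyLoss


def build_labels(messages: list[dict], tokenize) -> tuple[list, list]:
    """Concatenate per-turn token ids; mask non-assistant turns.

    `tokenize` maps a string to a list of token ids. Only assistant
    tokens (thought + action) keep their ids as labels and thus
    contribute to the loss; user/observation tokens become -100.
    """
    input_ids, labels = [], []
    for msg in messages:
        ids = tokenize(msg["content"])
        input_ids.extend(ids)
        if msg["role"] == "assistant":
            labels.extend(ids)                         # learn to produce these
        else:
            labels.extend([IGNORE_INDEX] * len(ids))   # masked out of the loss
    return input_ids, labels
```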
3. Step-level reward weighting (step_weighted_loss)
Uses knowlyr-reward's step-level process rewards to weight each token's CE loss:
loss_token = CE(token) × (step_reward / mean_reward)
Good steps get higher weight; bad steps are down-weighted.
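The weighting rule can be expressed directly over per-token losses (a sketch of the formula above, not the library's implementation; masked tokens are assumed to be removed already):

```python
def step_weighted_losses(token_ce: list[float],
                         token_step_rewards: list[float]) -> list[float]:
    """Apply loss_token = CE(token) * (step_reward / mean_reward).

    `token_ce[t]` is the cross-entropy of token t, and
    `token_step_rewards[t]` is the reward of the step token t belongs
    to. Tokens from high-reward steps are scaled up, low-reward ones
    scaled down, with the mean reward as the neutral point.
    """
    mean_r = sum(token_step_rewards) / len(token_step_rewards)
    return [ce * (r / mean_r)
            for ce, r in zip(token_ce, token_step_rewards)]
```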
4. Long-trajectory chunking (chunk_long_trajectories)
Trajectories exceeding max_length are split at step boundaries into multiple training samples. Each chunk contains the task description + context steps + the current segment, and never breaks in the middle of a step.
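The boundary-respecting split could look like this (function and parameter names are illustrative; `step_lengths[i]` is the token count of step i, and every chunk reserves room for the task prompt):

```python
def chunk_trajectory(steps: list, step_lengths: list[int],
                     max_length: int, prompt_length: int) -> list[list]:
    """Split steps into chunks whose token budget fits max_length.

    Steps are never split internally: a step either fits in the
    current chunk or starts a new one.
    """
    budget = max_length - prompt_length
    chunks, current, used = [], [], 0
    for step, n in zip(steps, step_lengths):
        if current and used + n > budget:
            chunks.append(current)          # close the full chunk
            current, used = [], 0
        current.append(step)                # step stays whole
        used += n
    if current:
        chunks.append(current)
    return chunks
```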
5. Curriculum learning (curriculum)
Trains progressively from easy (short trajectories / high reward) to hard (long trajectories / low reward):
curriculum: true
curriculum_start_ratio: 0.3   # start with the easiest 30% of samples
curriculum_warmup_epochs: 2   # use the full dataset after 2 epochs
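One plausible schedule matching those config keys: the visible fraction of the easy-to-hard sorted dataset grows from the start ratio to 1.0 over the warmup epochs (the linear ramp is an assumption):

```python
def curriculum_fraction(epoch: int, start_ratio: float = 0.3,
                        warmup_epochs: int = 2) -> float:
    """Fraction of the (easy-to-hard sorted) dataset visible at `epoch`.

    Ramps linearly from start_ratio at epoch 0 to 1.0 once
    warmup_epochs have passed, mirroring curriculum_start_ratio and
    curriculum_warmup_epochs above.
    """
    if epoch >= warmup_epochs:
        return 1.0
    return start_ratio + (1.0 - start_ratio) * epoch / warmup_epochs
```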
6. Step-level GRPO (step_level_advantage)
On top of GRPO's trajectory-level advantage, further weights by step rewards:
A_{i,j} = A_trajectory_i × (r_{step_j} / mean(r_steps))
Good steps in good trajectories get a larger positive gradient; bad steps in bad trajectories receive a larger penalty.
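The formula above in code form (a sketch; names are illustrative):

```python
def step_level_advantages(traj_advantage: float,
                          step_rewards: list[float]) -> list[float]:
    """Scale a trajectory-level advantage by per-step rewards.

    Implements A_{i,j} = A_i * (r_{step_j} / mean(r_steps)): steps
    rewarded above the trajectory's mean amplify the advantage,
    below-mean steps attenuate it.
    """
    mean_r = sum(step_rewards) / len(step_rewards)
    return [traj_advantage * (r / mean_r) for r in step_rewards]
```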
Training methods
| Method | Purpose | Data format | CLI |
|---|---|---|---|
| SFT | Supervised fine-tuning | instruction/response JSONL | knowlyr-trainer sft |
| DPO | Preference alignment | prompt/chosen/rejected JSONL | knowlyr-trainer dpo |
| GRPO | Group-relative policy optimization | prompt + multiple trajectories JSONL | knowlyr-trainer grpo |
Feature matrix
| Feature | SFT | DPO | GRPO |
|---|---|---|---|
| Multi-turn dialogue format | ✅ | — | — |
| Observation masking | ✅ | — | — |
| Step-weighted loss | ✅ | — | — |
| Long-trajectory chunking | ✅ | — | — |
| Curriculum learning | ✅ | — | — |
| Step-level advantage | — | — | ✅ |
| LoRA | ✅ | ✅ | ✅ |
| bf16 mixed precision | ✅ | ✅ | ✅ |
| Checkpoint saving | ✅ | ✅ | ✅ |
| wandb logging | ✅ | ✅ | ✅ |
License
MIT