Self-Distilled Policy Optimization (SDPO) for TRL — faithful reimplementation of arxiv:2601.20802
Self-Distilled Policy Optimization (SDPO) for Hugging Face TRL
A faithful reimplementation of SDPO (arXiv:2601.20802) from the lasgroup/SDPO verl fork, ported to the Hugging Face TRL ecosystem as a drop-in GRPOTrainer subclass, with support for Unsloth.
What is SDPO?
The Core Idea
- When you give a model a good example in its prompt, it performs better.
- Give it two good examples and it performs even better.
- SDPO distills that improvement back into the base model, so it behaves as if it always had those examples, without needing them in the prompt.
- This process repeats, compounding the model's capability over successive rounds.
Where Do High-Quality Examples Come From?
- Environment feedback: Feedback from some environment (e.g., a Python interpreter)
- Human-authored work: Handwritten worked examples, reasoning traces, or notes
- Best-of-N sampling: Prompt the model multiple times for the same task, then keep whichever answer passes your evaluation criteria
- Distillation from a stronger model: Reasoning traces produced by a stronger model
- Deep research pipelines: Answers or relevant context fetched from reputable sources
The higher the quality of your examples, the more powerful your model. Think creatively about where you can get high-quality examples, and note that multiple examples drive up quality. SDPO distills whatever effect the examples have: if adding an example to the context improves the answer, training pushes the model towards higher quality; if it degrades the answer, training pushes the model towards lower quality. Vet the examples you train with.
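As an illustration of the best-of-N source listed above, here is a minimal sketch in plain Python. The names `best_of_n`, `generate`, and `passes` are hypothetical and not part of sdpo-rl; `generate` stands in for whatever sampling call you use, and `passes` for your evaluation criteria.

```python
import itertools
from typing import Callable, Optional


def best_of_n(
    generate: Callable[[str], str],
    passes: Callable[[str, str], bool],
    prompt: str,
    n: int = 8,
) -> Optional[str]:
    """Sample up to n completions; return the first that passes, else None."""
    for _ in range(n):
        candidate = generate(prompt)
        if passes(prompt, candidate):
            return candidate
    return None


# Toy usage: a fake "model" that cycles through canned answers.
canned = itertools.cycle(["6", "4", "5"])
demo = best_of_n(
    generate=lambda prompt: next(canned),
    passes=lambda prompt, completion: completion.strip() == "5",
    prompt="What is 2 + 3?",
    n=4,
)
print(demo)  # the first passing sample, or None if all n samples fail
```

Returning `None` when nothing passes keeps low-quality answers out of the example pool, which matters given how directly SDPO inherits the quality of its demonstrations.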
How Do You Measure Example Quality?
- Pass-rate scoring: Generate hundreds of outputs per problem, measure what fraction are correct. Higher pass rate = stronger example.
- Reward models / LLM-as-judge: A separate model scores each output. A/B comparison scoring tends to work better than absolute ratings.
- Manual review: Read the output and assess whether it's higher quality.
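Pass-rate scoring from the list above can be sketched in a few lines. This is an illustrative helper, not library code: `pass_rate`, `strongest_example`, and the toy outputs are hypothetical, and the sketch assumes you have already generated outputs with each candidate example in the prompt.

```python
from typing import Callable, Dict, Sequence


def pass_rate(outputs: Sequence[str], is_correct: Callable[[str], bool]) -> float:
    """Fraction of outputs judged correct."""
    if not outputs:
        raise ValueError("pass_rate needs at least one output")
    return sum(1 for out in outputs if is_correct(out)) / len(outputs)


def strongest_example(rates: Dict[str, float]) -> str:
    """Name of the candidate example with the highest pass rate."""
    return max(rates, key=rates.get)


# Toy data: outputs sampled with two different candidate examples in context.
def check(out: str) -> bool:
    return out == "5"


rates = {
    "example_a": pass_rate(["5", "5", "4", "5"], check),
    "example_b": pass_rate(["5", "6", "4", "7"], check),
}
best = strongest_example(rates)
```

In practice you would use far more than four outputs per candidate, as the bullet above suggests, so that the pass-rate estimate is not dominated by sampling noise.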
Additional Benefits
SDPO offers advantages such as:
- Faster training (denser reward signal)
- Shorter reasoning traces from better credit assignment
- Reduced forgetting when jumping between datasets
The Algorithm
SDPO replaces GRPO's scalar reward loss with token-level self-distillation from a teacher model. On each batch:
- Generate multiple completions per prompt
- Evaluate with your reward function (scores + optional feedback)
- Select successful peer rollouts as demonstrations
- Reprompt the teacher with the task + peer demo + error feedback
- Distill the student towards the teacher via top-K KL divergence
- Update the EMA teacher weights
The teacher sees "here's a working solution and the error from your last attempt" while the student learns to match that improved distribution. No third model in memory -- the teacher is the ref_model updated via EMA.
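The peer-selection step above can be sketched as follows. This is a hedged illustration, not the library's code: the function name `select_demo_index` is hypothetical, the two keyword arguments mirror the SDPOConfig fields of the same name, and picking the highest-reward candidate when several qualify is an assumption.

```python
from typing import Optional, Sequence


def select_demo_index(
    rewards: Sequence[float],
    own_index: int,
    success_reward_threshold: float = 1.0,
    dont_reprompt_on_self_success: bool = True,
) -> Optional[int]:
    """Pick a successful rollout from the same prompt group to use as a demo."""
    candidates = [i for i, r in enumerate(rewards) if r >= success_reward_threshold]
    if dont_reprompt_on_self_success:
        # A rollout should not be shown its own completion as the demonstration.
        candidates = [i for i in candidates if i != own_index]
    if not candidates:
        return None
    # Assumption: prefer the highest reward (earliest index wins ties).
    return max(candidates, key=lambda i: rewards[i])


rewards = [1.0, 0.0, 0.0, 0.5]
demo_for_failed_rollout = select_demo_index(rewards, own_index=1)
demo_for_successful_rollout = select_demo_index(rewards, own_index=0)
```

With these rewards, the failed rollout at index 1 gets the completion at index 0 as its demonstration, while the successful rollout at index 0 gets none, because its only successful peer is itself.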
Quick Start
pip install sdpo-rl
from sdpo_rl import SDPOTrainer, SDPOConfig
from trl import GRPOConfig

# model, tokenizer, reward_fn, and dataset are assumed to be defined already.
grpo_config = GRPOConfig(
    output_dir="./output",
    num_generations=4,
    max_completion_length=128,
    learning_rate=1e-5,
    bf16=True,
    remove_unused_columns=False,
)
sdpo_config = SDPOConfig()  # sensible defaults from the paper

trainer = SDPOTrainer(
    model=model,
    args=grpo_config,
    sdpo_config=sdpo_config,
    processing_class=tokenizer,
    reward_funcs=[reward_fn],
    train_dataset=dataset,
)
trainer.train()
With Unsloth (2x faster, 60% less memory)
from unsloth import FastLanguageModel, PatchFastRL

# CRITICAL: Patch BEFORE importing SDPOTrainer
PatchFastRL("GRPO", FastLanguageModel)

from sdpo_rl import SDPOTrainer, SDPOConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    load_in_4bit=True,
)
# "..." stands for the remaining LoRA arguments; fill them in for your setup.
model = FastLanguageModel.get_peft_model(model, r=16, ...)

# grpo_config, sdpo_config, reward_fn, and dataset are defined as in Quick Start.
trainer = SDPOTrainer(
    model=model,
    args=grpo_config,
    sdpo_config=sdpo_config,
    processing_class=tokenizer,
    reward_funcs=[reward_fn],
    train_dataset=dataset,
)
trainer.train()
Configuration
SDPOConfig Reference
Parameters match the lasgroup/SDPO reference (see Known Limitations for gaps). Defaults are from the paper's experiment scripts. See docs/configuration.md for the full reference.
SDPOConfig(
    # --- Loss mode ---
    enabled=True,  # True: SDPO replaces GRPO loss. False: vanilla GRPO.
    # --- KL divergence ---
    alpha=0.5,  # 0.0=forward KL, 0.5=JSD (paper default), 1.0=reverse KL
    full_logit_distillation=True,  # Use top-K logits (True) or token-level KL (False)
    distillation_topk=100,  # K for top-K approximation
    distillation_add_tail=True,  # Append tail bucket for residual probability mass
    # --- Importance sampling ---
    is_clip=2.0,  # Clamp IS ratio. None disables correction.
    # --- Teacher ---
    teacher_mode="ema",  # "ema", "frozen", or "lora_ema" (trust_region not yet implemented)
    teacher_update_rate=0.05,  # EMA rate: teacher = (1-rate)*teacher + rate*student
    # --- Chat template ---
    apply_chat_template_kwargs={},  # Extra kwargs for tokenizer.apply_chat_template()
    # --- Demonstration selection ---
    success_reward_threshold=1.0,  # Min reward for a rollout to be a "successful" demo
    dont_reprompt_on_self_success=True,  # Exclude own response as demonstration
    remove_thinking_from_demonstration=True,  # Strip <think>...</think> from demos
    # --- Feedback ---
    include_environment_feedback=True,  # Include env feedback (test errors, etc.) in teacher prompt
    environment_feedback_only_without_solution=True,  # Only include feedback when no peer demo exists
    # --- Reprompting ---
    max_reprompt_length=10240,  # Max tokens for teacher prompt
    reprompt_truncation="right",  # "left", "right", or "error"
    # --- Templates (customizable) ---
    reprompt_template="{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n",
    solution_template="\nCorrect solution:\n\n{successful_previous_attempt}\n\n",
    feedback_template="\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n",
)
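To make the template and feedback options concrete, the sketch below assembles a teacher prompt from the three default templates using `str.format`. The helper names `strip_thinking` and `build_teacher_prompt` are hypothetical, and the way the flags interact here is an assumption based on the comments in the config above, not a copy of the library's logic.

```python
import re
from typing import Optional

REPROMPT_TEMPLATE = "{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n"
SOLUTION_TEMPLATE = "\nCorrect solution:\n\n{successful_previous_attempt}\n\n"
FEEDBACK_TEMPLATE = (
    "\nThe following is feedback from your unsuccessful earlier attempt:"
    "\n\n{feedback_raw}\n\n"
)


def strip_thinking(text: str) -> str:
    """Remove <think>...</think> spans from a demonstration."""
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()


def build_teacher_prompt(
    prompt: str,
    demo: Optional[str] = None,
    feedback: str = "",
    include_environment_feedback: bool = True,
    environment_feedback_only_without_solution: bool = True,
    remove_thinking_from_demonstration: bool = True,
) -> str:
    solution_part = ""
    if demo is not None:
        if remove_thinking_from_demonstration:
            demo = strip_thinking(demo)
        solution_part = SOLUTION_TEMPLATE.format(successful_previous_attempt=demo)

    use_feedback = include_environment_feedback and bool(feedback)
    if environment_feedback_only_without_solution and demo is not None:
        use_feedback = False  # a peer demo exists, so feedback is skipped
    feedback_part = FEEDBACK_TEMPLATE.format(feedback_raw=feedback) if use_feedback else ""

    return REPROMPT_TEMPLATE.format(
        prompt=prompt, solution=solution_part, feedback=feedback_part
    )


with_demo = build_teacher_prompt("What is 2 + 3?", demo="<think>2+3</think>5", feedback="got 6")
without_demo = build_teacher_prompt("What is 2 + 3?", feedback="got 6")
```

Under these assumptions `with_demo` contains the stripped solution and no feedback, while `without_demo` contains only the feedback section. Token-level truncation (`max_reprompt_length`, `reprompt_truncation`) is left out of the sketch.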
Reward Functions
TRL requires reward functions to return list[float]. To provide feedback for SDPO's teacher prompts, use a callable class with a last_feedback attribute:
class MyReward:
    """Reward function that provides scores AND feedback for SDPO."""

    def __init__(self):
        self.last_feedback: list[str] = []

    def __call__(self, prompts, completions, **kwargs) -> list[float]:
        # is_correct and answer are placeholders: define them for your task.
        scores = []
        self.last_feedback = []
        for prompt, completion in zip(prompts, completions):
            if is_correct(completion):
                scores.append(1.0)
                self.last_feedback.append("")
            else:
                scores.append(0.0)
                self.last_feedback.append(f"Wrong: expected {answer}, got {completion}")
        return scores  # TRL requires list[float]


reward_fn = MyReward()
trainer = SDPOTrainer(..., reward_funcs=[reward_fn])
The trainer checks hasattr(rf, "last_feedback") on each reward function after scoring. Feedback strings are injected into teacher prompts for samples that lack a successful peer demonstration.
If you don't need feedback (simpler tasks), a plain function returning list[float] works fine -- SDPO still uses peer demonstrations from successful rollouts.
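The `hasattr` check described above can be pictured with a small sketch. `collect_feedback`, `LengthReward`, and `plain_reward` are hypothetical names, and keeping the first non-empty feedback string when several reward functions provide one is an assumption rather than documented behaviour.

```python
from typing import Callable, List, Sequence


def collect_feedback(reward_funcs: Sequence[Callable], num_samples: int) -> List[str]:
    """Gather per-sample feedback from reward functions that expose last_feedback."""
    feedback = [""] * num_samples
    for rf in reward_funcs:
        if not hasattr(rf, "last_feedback"):
            continue  # plain functions carry no feedback
        for i, text in enumerate(list(rf.last_feedback)[:num_samples]):
            if text and not feedback[i]:
                feedback[i] = text
    return feedback


class LengthReward:
    """Toy reward: completions of at most 5 characters count as correct."""

    def __init__(self):
        self.last_feedback: List[str] = []

    def __call__(self, prompts, completions, **kwargs) -> List[float]:
        self.last_feedback = ["" if len(c) <= 5 else "Too long" for c in completions]
        return [1.0 if len(c) <= 5 else 0.0 for c in completions]


def plain_reward(prompts, completions, **kwargs) -> List[float]:
    return [0.0 for _ in completions]


length_reward = LengthReward()
completions = ["short", "far too long"]
scores = length_reward(["p", "p"], completions)
feedback = collect_feedback([length_reward, plain_reward], len(completions))
```

Here `plain_reward` is skipped because it has no `last_feedback` attribute, which is why mixing callable classes and plain functions in `reward_funcs` works.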
Training Modes
SDPO (default): Self-distillation replaces GRPO loss entirely.
sdpo_config = SDPOConfig(enabled=True) # the default
GRPO fallback: Disable SDPO, use standard TRL GRPO.
sdpo_config = SDPOConfig(enabled=False)
Frozen teacher: Use initial weights as teacher (no EMA updates).
sdpo_config = SDPOConfig(teacher_mode="frozen")
LoRA EMA teacher: Memory-efficient mode for PEFT/LoRA models. Keeps student and teacher as two LoRA adapters on a shared base model instead of deepcopying the entire model. Saves ~3-4 GB for 7B QLoRA.
from peft import LoraConfig, get_peft_model
model = get_peft_model(base_model, LoraConfig(r=16, ...))
sdpo_config = SDPOConfig(teacher_mode="lora_ema")
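The EMA rule behind the "ema" and "lora_ema" modes, teacher = (1-rate)*teacher + rate*student, can be sketched on plain Python lists. This is only an illustration of the arithmetic: `ema_update` is a hypothetical helper, and a real implementation would apply the same rule to model tensors.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def ema_update(teacher: Params, student: Params, rate: float = 0.05) -> Params:
    """Return new teacher weights: (1 - rate) * teacher + rate * student."""
    return {
        name: [(1.0 - rate) * t + rate * s for t, s in zip(teacher[name], student[name])]
        for name in teacher
    }


teacher = {"w": [1.0, -2.0]}
student = {"w": [0.0, 2.0]}

updated = ema_update(teacher, student, rate=0.05)  # approximately {"w": [0.95, -1.8]}
frozen = ema_update(teacher, student, rate=0.0)    # rate 0 leaves the teacher unchanged
```

A rate of 0 reproduces the frozen-teacher behaviour, while the default of 0.05 lets the teacher trail the student slowly.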
KL Divergence Variants
SDPOConfig(alpha=0.0) # Forward KL: KL(teacher || student) -- mode-covering
SDPOConfig(alpha=0.5) # JSD: symmetric Jensen-Shannon (paper default)
SDPOConfig(alpha=1.0) # Reverse KL: KL(student || teacher) -- mode-seeking
The paper uses alpha=0.5 (JSD) across its experiments. The reference experiment scripts use alpha=1.0 with distillation_topk=20 for LiveCodeBench.
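The three alpha settings can be compared on a toy distribution. The sketch below is an assumption-laden illustration: the endpoints and alpha=0.5 follow the definitions above, but the mixture used for other alpha values is one common generalized-JSD form and may differ from the library's exact formula. `kl` and `distillation_divergence` are hypothetical names, and both distributions are assumed to have full support.

```python
import math
from typing import Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_divergence(
    teacher: Sequence[float], student: Sequence[float], alpha: float = 0.5
) -> float:
    if alpha == 0.0:
        return kl(teacher, student)  # forward KL: mode-covering
    if alpha == 1.0:
        return kl(student, teacher)  # reverse KL: mode-seeking
    # Assumed interpolation: divergence of each side to an alpha-weighted mixture.
    mixture = [(1.0 - alpha) * t + alpha * s for t, s in zip(teacher, student)]
    return (1.0 - alpha) * kl(teacher, mixture) + alpha * kl(student, mixture)


teacher = [0.5, 0.5]
student = [0.9, 0.1]
forward = distillation_divergence(teacher, student, alpha=0.0)
jsd = distillation_divergence(teacher, student, alpha=0.5)
reverse = distillation_divergence(teacher, student, alpha=1.0)
```

At alpha=0.5 the mixture is the plain average, so the result is the standard Jensen-Shannon divergence, which is symmetric in its arguments and bounded by ln 2.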
How It Works
See docs/how-it-works.md for a deeper dive into the algorithm, architecture decisions, and numerical details.
┌────────────────────────────────────────────────────────────┐
│ For each batch:                                            │
│                                                            │
│ Prompt ──> Generate 4 completions ──> Reward function      │
│                      │                       │             │
│        [comp1, comp2, comp3, comp4]          │             │
│                      │             [1.0, 0.0, 0.0, 0.5]    │
│                      │                       │             │
│          Select best peer (comp1) ───────────┘             │
│                      │                                     │
│          Build teacher prompt:                             │
│            "Task: {prompt}                                 │
│             Correct solution: {comp1}                      │
│             Feedback: {error from comp2}                   │
│             Correctly solve the original question."        │
│                      │                                     │
│  Teacher forward pass (with reprompted input)              │
│  Student forward pass (with original input)                │
│                      │                                     │
│  KL(student top-K || teacher top-K) ──> loss               │
│                      │                                     │
│  EMA update: teacher <- 0.95*teacher + 0.05*student        │
└────────────────────────────────────────────────────────────┘
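The top-K step in the diagram, together with the tail bucket controlled by distillation_add_tail, can be sketched on plain probability lists. Several details here are assumptions: choosing the K indices from the teacher's distribution, clamping the tail at zero, and the helper names `topk_with_tail` and `kl` are all illustrative rather than taken from the library.

```python
import math
from typing import List, Sequence, Tuple


def topk_with_tail(
    teacher_probs: Sequence[float], student_probs: Sequence[float], k: int
) -> Tuple[List[float], List[float]]:
    """Keep the k most likely tokens (by teacher) and lump the rest into one bucket."""
    top = sorted(range(len(teacher_probs)), key=lambda i: teacher_probs[i], reverse=True)[:k]
    t = [teacher_probs[i] for i in top]
    s = [student_probs[i] for i in top]
    # Tail bucket: residual probability mass outside the top-k tokens.
    t.append(max(1.0 - sum(t), 0.0))
    s.append(max(1.0 - sum(s), 0.0))
    return t, s


def kl(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) with a small floor on q to avoid division by zero."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


teacher_probs = [0.6, 0.3, 0.05, 0.05]
student_probs = [0.25, 0.25, 0.25, 0.25]
t, s = topk_with_tail(teacher_probs, student_probs, k=2)
loss = kl(s, t)  # KL(student top-K || teacher top-K), as in the diagram
```

The tail bucket keeps both truncated vectors summing to 1, so the divergence is still computed between proper distributions even though only K+1 entries are compared per token.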
Examples
See examples/ for complete, runnable scripts:
| Example | Task | Key Feature |
|---|---|---|
| basic_sdpo.py | Math (addition) | Core SDPO loop with feedback |
| sdpo_lora_ema.py | Math (multiplication) | LoRA EMA teacher mode (memory-efficient) |
| sdpo_with_unsloth.py | Reasoning | Unsloth + QLoRA + 4-bit |
| sdpo_rich_feedback.py | Code generation | Test execution with error messages |
python examples/basic_sdpo.py
Benchmark
The benchmark/ directory contains a full MBPP code generation benchmark (Qwen2.5-0.5B, 4-bit QLoRA, RTX 3080):
- Correctness: Max |our_loss - ref_loss| = 1.4e-08 across 200 training steps (verified against verl reference)
- Performance: SDPO 1.95% pass@1 vs GRPO 1.17% at step 200
See benchmark/README.md for full results and replication instructions.
Documentation
| Document | What it covers |
|---|---|
| docs/configuration.md | Full SDPOConfig reference, training modes, reward functions |
| docs/how-it-works.md | Algorithm deep-dive, architecture decisions, numerical details |
| docs/unsloth.md | Unsloth compatibility, import order, what works |
| docs/verification.md | Line-by-line verification against verl reference |
| docs/deviations.md | Intentional differences from verl (TRL adaptation) |
| examples/README.md | Example walkthrough and customization guide |
| benchmark/README.md | MBPP benchmark methodology and results |
GPU Requirements
| Model Size | Standard EMA | LoRA EMA | Unsloth (4-bit) |
|---|---|---|---|
| 0.5B | ~6 GB | ~4 GB | ~3.5 GB |
| 7B | ~28 GB | ~14 GB | ~10 GB |
| 14B | ~56 GB | ~30 GB | ~18 GB |
Known Limitations
Two features from the paper / reference are not yet implemented:
- Hybrid SDPO+GRPO blending (Section 4.5): The paper defines a combined advantage A = λ·A_GRPO + (1-λ)·A_SDPO that interpolates between GRPO and SDPO. Our library supports enabled=True (pure SDPO) or enabled=False (pure GRPO) but not lambda blending. The paper shows this hybrid helps weaker models (e.g., Qwen3-0.6B).
- trust_region teacher mode: Declared in config validation but raises NotImplementedError. All paper experiments use EMA, so this is low priority.
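For reference, the blending formula A = λ·A_GRPO + (1-λ)·A_SDPO is simple to state in code. The sketch below only illustrates the arithmetic and is not a feature of this library: `blend_advantages` is a hypothetical helper, and treating both advantages as same-length lists of floats is an assumption about their shape.

```python
from typing import List, Sequence


def blend_advantages(
    a_grpo: Sequence[float], a_sdpo: Sequence[float], lam: float
) -> List[float]:
    """Elementwise lam * A_GRPO + (1 - lam) * A_SDPO."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must be in [0, 1]")
    if len(a_grpo) != len(a_sdpo):
        raise ValueError("advantage lists must have the same length")
    return [lam * g + (1.0 - lam) * s for g, s in zip(a_grpo, a_sdpo)]


a_grpo = [1.0, -1.0]
a_sdpo = [0.0, 2.0]
pure_grpo = blend_advantages(a_grpo, a_sdpo, lam=1.0)
pure_sdpo = blend_advantages(a_grpo, a_sdpo, lam=0.0)
halfway = blend_advantages(a_grpo, a_sdpo, lam=0.5)
```

The two endpoints correspond to the modes the library does support: λ=1 recovers pure GRPO and λ=0 recovers pure SDPO.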
Citation
@article{hubotter2026sdpo,
  title={Reinforcement Learning via Self-Distillation},
  author={H{\"u}botter, Jonas and L{\"u}beck, Frederike and Behric, Lejs and Baumann, Anton and Bagatella, Marco and Marta, Daniel and Hakimi, Ido and Shenfeld, Idan and Kleine Buening, Thomas and Guestrin, Carlos and Krause, Andreas},
  journal={arXiv preprint arXiv:2601.20802},
  year={2026}
}
License
Apache 2.0
Acknowledgments
- lasgroup/SDPO -- Reference implementation
- Hugging Face TRL -- Base GRPO trainer
- Unsloth -- Training optimizations