Self-Distilled Policy Optimization (SDPO) for TRL — faithful reimplementation of arxiv:2601.20802
> [!WARNING]
> Currently under development; not guaranteed to be at perfect parity with the SDPO paper.
sdpo-rl
Self-Distilled Policy Optimization (SDPO) for Hugging Face TRL
A faithful reimplementation of SDPO (arxiv:2601.20802) from the lasgroup/SDPO verl fork, ported to the Hugging Face TRL ecosystem as a drop-in GRPOTrainer subclass.
What is SDPO?
SDPO (Self-Distilled Policy Optimization) replaces GRPO's scalar reward loss with token-level self-distillation from a teacher model. On each batch:
- Generate multiple completions per prompt
- Evaluate with your reward function (scores + optional feedback)
- Select successful peer rollouts as demonstrations
- Reprompt the teacher with the task + peer demo + error feedback
- Distill the student towards the teacher via top-K KL divergence
- Update the EMA teacher weights
The teacher sees "here's a working solution and the error from your last attempt", while the student, which sees only the original prompt, learns to match that improved distribution. No third model is kept in memory: the teacher is the ref_model, updated via EMA.
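The last step above (updating the EMA teacher) can be illustrated with a minimal sketch over a plain dict of parameter values, using the default teacher_update_rate=0.05. The helper name `ema_update` and the dict-of-floats representation are illustrative assumptions, not part of the sdpo_rl API; a real trainer would apply the same rule tensor by tensor.

```python
from typing import Dict


def ema_update(
    teacher: Dict[str, float], student: Dict[str, float], rate: float = 0.05
) -> Dict[str, float]:
    """Hypothetical helper: teacher <- (1 - rate) * teacher + rate * student, per parameter."""
    return {name: (1.0 - rate) * value + rate * student[name] for name, value in teacher.items()}


# Against a fixed student, the gap shrinks geometrically: gap_n = gap_0 * (1 - rate) ** n.
teacher = {"w": 1.0}
student = {"w": 0.0}
for _ in range(3):
    teacher = ema_update(teacher, student)
print(teacher["w"])  # equals 0.95 ** 3, roughly 0.857
```

With rate=0.05 the teacher moves 5% of the way toward the student on each update, so it lags the student and smooths out step-to-step noise.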
Quick Start
```bash
pip install sdpo-rl
```

```python
from sdpo_rl import SDPOTrainer, SDPOConfig
from trl import GRPOConfig

# model, tokenizer, reward_fn and dataset are your own objects
# (see Reward Functions below for how to write reward_fn).
grpo_config = GRPOConfig(
    output_dir="./output",
    num_generations=4,
    max_completion_length=128,
    learning_rate=1e-5,
    bf16=True,
    remove_unused_columns=False,  # keep extra dataset columns available to reward functions
)

sdpo_config = SDPOConfig()  # sensible defaults from the paper

trainer = SDPOTrainer(
    model=model,
    args=grpo_config,
    sdpo_config=sdpo_config,
    processing_class=tokenizer,
    reward_funcs=[reward_fn],
    train_dataset=dataset,
)
trainer.train()
```
With Unsloth (2x faster, 60% less memory)
```python
from unsloth import FastLanguageModel, PatchFastRL

# CRITICAL: Patch BEFORE importing SDPOTrainer
PatchFastRL("GRPO", FastLanguageModel)

from sdpo_rl import SDPOTrainer, SDPOConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16, ...)

# grpo_config, sdpo_config, reward_fn and dataset as in Quick Start
trainer = SDPOTrainer(
    model=model,
    args=grpo_config,
    sdpo_config=sdpo_config,
    processing_class=tokenizer,
    reward_funcs=[reward_fn],
    train_dataset=dataset,
)
trainer.train()
```
Configuration
SDPOConfig Reference
Parameters match the lasgroup/SDPO reference (see Known Limitations for gaps). Defaults are from the paper's experiment scripts.
```python
SDPOConfig(
    # --- Loss mode ---
    enabled=True,  # True: SDPO replaces GRPO loss. False: vanilla GRPO.

    # --- KL divergence ---
    alpha=0.5,  # 0.0=forward KL, 0.5=JSD (paper default), 1.0=reverse KL
    full_logit_distillation=True,  # Use top-K logits (True) or token-level KL (False)
    distillation_topk=100,  # K for top-K approximation
    distillation_add_tail=True,  # Append tail bucket for residual probability mass

    # --- Importance sampling ---
    is_clip=2.0,  # Clamp IS ratio. None disables correction.

    # --- Teacher ---
    teacher_mode="ema",  # "ema" or "frozen" (trust_region declared but not yet implemented)
    teacher_update_rate=0.05,  # EMA rate: teacher = (1-rate)*teacher + rate*student

    # --- Demonstration selection ---
    success_reward_threshold=1.0,  # Min reward for a rollout to be a "successful" demo
    dont_reprompt_on_self_success=True,  # Exclude own response as demonstration
    remove_thinking_from_demonstration=True,  # Strip <think>...</think> from demos

    # --- Feedback ---
    include_environment_feedback=True,  # Include env feedback (test errors, etc.) in teacher prompt
    environment_feedback_only_without_solution=True,  # Only include feedback when no peer demo exists

    # --- Reprompting ---
    max_reprompt_length=10240,  # Max tokens for teacher prompt
    reprompt_truncation="right",  # "left", "right", or "error"

    # --- Templates (customizable) ---
    reprompt_template="{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n",
    solution_template="\nCorrect solution:\n\n{successful_previous_attempt}\n\n",
    feedback_template="\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n",
)
```
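To show how the three default templates fit together, here is a minimal string-level sketch. `build_teacher_prompt` is a hypothetical helper, not part of the sdpo_rl API. It assumes the prompt is already a plain string (the trainer may apply a chat template instead) and mirrors the include_environment_feedback and environment_feedback_only_without_solution flags as described in their comments above; truncation to max_reprompt_length is omitted.

```python
from typing import Optional

# Default template strings copied from the SDPOConfig reference above.
REPROMPT_TEMPLATE = "{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n"
SOLUTION_TEMPLATE = "\nCorrect solution:\n\n{successful_previous_attempt}\n\n"
FEEDBACK_TEMPLATE = (
    "\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n"
)


def build_teacher_prompt(
    prompt: str,
    demonstration: Optional[str],
    feedback: Optional[str],
    include_environment_feedback: bool = True,
    environment_feedback_only_without_solution: bool = True,
) -> str:
    """Hypothetical helper: fill the reprompt templates for a single sample."""
    solution = ""
    if demonstration is not None:
        solution = SOLUTION_TEMPLATE.format(successful_previous_attempt=demonstration)

    use_feedback = include_environment_feedback and bool(feedback)
    if environment_feedback_only_without_solution and demonstration is not None:
        use_feedback = False  # a peer demo exists, so feedback is left out
    feedback_text = FEEDBACK_TEMPLATE.format(feedback_raw=feedback) if use_feedback else ""

    return REPROMPT_TEMPLATE.format(prompt=prompt, solution=solution, feedback=feedback_text)


# With a peer demo, the default settings drop the feedback.
print(build_teacher_prompt("What is 2+3?", "5", "Wrong: expected 5, got 6"))
# Without a peer demo, the feedback is included instead.
print(build_teacher_prompt("What is 2+3?", None, "Wrong: expected 5, got 6"))
```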
Reward Functions
TRL requires reward functions to return list[float]. To provide feedback for SDPO's teacher prompts, use a callable class with a last_feedback attribute:
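```python
class MyReward:
    """Reward function that provides scores AND feedback for SDPO."""

    def __init__(self):
        self.last_feedback: list[str] = []

    def __call__(self, prompts, completions, answer, **kwargs) -> list[float]:
        # Assumes plain-text completions and an `answer` column in the dataset;
        # TRL passes extra dataset columns to reward functions as keyword arguments.
        scores = []
        self.last_feedback = []  # one entry per completion, aligned with scores
        for completion, expected in zip(completions, answer):
            if completion.strip() == str(expected).strip():  # replace with your own check
                scores.append(1.0)
                self.last_feedback.append("")  # no feedback needed for a correct answer
            else:
                scores.append(0.0)
                self.last_feedback.append(f"Wrong: expected {expected}, got {completion}")
        return scores  # TRL requires list[float]


reward_fn = MyReward()
trainer = SDPOTrainer(..., reward_funcs=[reward_fn])
```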
The trainer checks hasattr(rf, "last_feedback") on each reward function after scoring. Feedback strings are injected into teacher prompts for samples that lack a successful peer demonstration.
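The duck-typed check described above can be sketched as follows. `collect_feedback` is a hypothetical helper, not the trainer's actual code, and joining feedback from several reward functions with newlines is an assumption about how multiple sources might be merged. It only illustrates the `last_feedback` protocol: functions without the attribute contribute scores but no feedback.

```python
from typing import Callable, List, Optional, Sequence


def collect_feedback(reward_funcs: Sequence[Callable], num_samples: int) -> List[Optional[str]]:
    """Hypothetical helper: gather per-sample feedback from reward functions exposing last_feedback."""
    feedback: List[Optional[str]] = [None] * num_samples
    for rf in reward_funcs:
        if not hasattr(rf, "last_feedback"):
            continue  # plain functions provide scores only
        for i, text in enumerate(rf.last_feedback[:num_samples]):
            if text:  # empty strings mean "no feedback" (e.g. a correct answer)
                feedback[i] = text if feedback[i] is None else feedback[i] + "\n" + text
    return feedback


class FeedbackReward:
    """Stand-in for a reward object that has already scored two completions."""

    last_feedback = ["", "Wrong: expected 5, got 6"]


def plain_reward(prompts, completions, **kwargs):
    return [0.0] * len(completions)


print(collect_feedback([FeedbackReward(), plain_reward], num_samples=2))
# [None, 'Wrong: expected 5, got 6']
```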
If you don't need feedback (simpler tasks), a plain function returning list[float] works fine -- SDPO still uses peer demonstrations from successful rollouts.
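For example, a feedback-free exact-match reward could look like the sketch below. It assumes plain-text completions and a hypothetical `answer` column in the dataset, which TRL forwards to reward functions as a keyword argument.

```python
def exact_match_reward(prompts, completions, answer, **kwargs) -> list[float]:
    """Score 1.0 when the stripped completion equals the expected answer, else 0.0."""
    return [
        1.0 if completion.strip() == str(expected).strip() else 0.0
        for completion, expected in zip(completions, answer)
    ]


print(exact_match_reward(prompts=["1+4?", "3+4?"], completions=["5", " 8 "], answer=["5", "7"]))
# [1.0, 0.0]
```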
Training Modes
SDPO (default): Self-distillation replaces the GRPO loss entirely.
```python
sdpo_config = SDPOConfig(enabled=True)  # the default
```
GRPO fallback: Disable SDPO and use standard TRL GRPO.
```python
sdpo_config = SDPOConfig(enabled=False)
```
Frozen teacher: Use the initial weights as the teacher (no EMA updates).
```python
sdpo_config = SDPOConfig(teacher_mode="frozen")
```
KL Divergence Variants
```python
SDPOConfig(alpha=0.0)  # Forward KL: KL(teacher || student) -- mode-covering
SDPOConfig(alpha=0.5)  # JSD: symmetric Jensen-Shannon (paper default)
SDPOConfig(alpha=1.0)  # Reverse KL: KL(student || teacher) -- mode-seeking
```
The paper uses alpha=0.5 (JSD) across its experiments. The reference experiment scripts use alpha=1.0 with distillation_topk=20 for LiveCodeBench.
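The alpha convention and the top-K tail bucket can be sketched in pure Python. This sketch follows the generalized Jensen-Shannon form used by TRL's GKD trainer (mixture M = (1-alpha)·student + alpha·teacher, with alpha=0 and alpha=1 special-cased to the two plain KLs). Whether sdpo-rl computes it identically, and taking the top-K indices from the student distribution, are assumptions; `topk_with_tail` is a hypothetical helper mirroring distillation_topk and distillation_add_tail.

```python
import math
from typing import List, Sequence


def topk_with_tail(probs: Sequence[float], indices: Sequence[int]) -> List[float]:
    """Keep the probabilities at the top-K indices and append one tail bucket for the rest."""
    head = [probs[i] for i in indices]
    return head + [max(1.0 - sum(head), 0.0)]


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd(student: Sequence[float], teacher: Sequence[float], alpha: float) -> float:
    if alpha == 0.0:
        return kl(teacher, student)  # forward KL: mode-covering
    if alpha == 1.0:
        return kl(student, teacher)  # reverse KL: mode-seeking
    mixture = [(1.0 - alpha) * s + alpha * t for s, t in zip(student, teacher)]
    return alpha * kl(teacher, mixture) + (1.0 - alpha) * kl(student, mixture)


student = [0.6, 0.3, 0.05, 0.05]
teacher = [0.2, 0.7, 0.05, 0.05]
k = 2
top = sorted(range(len(student)), key=lambda i: student[i], reverse=True)[:k]  # student's top-K
s_k, t_k = topk_with_tail(student, top), topk_with_tail(teacher, top)
for alpha in (0.0, 0.5, 1.0):
    print(alpha, round(generalized_jsd(s_k, t_k, alpha), 4))
```

The tail bucket keeps both truncated vectors normalized, so the divergence still accounts for probability mass outside the top-K tokens.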
How It Works
```
For each batch:

Prompt ──> Generate 4 completions ──> Reward function
           [comp1, comp2, comp3, comp4]   [1.0, 0.0, 0.0, 0.5]
                         │
                         v
           Select successful peer (comp1)
                         │
                         v
           Build teacher prompt:
             "{prompt}
              Correct solution: {comp1}
              Feedback: {error from comp2}   (by default, only when no peer demo exists)
              Correctly solve the original question."
                         │
                         v
           Teacher forward pass (with reprompted input)
           Student forward pass (with original input)
                         │
                         v
           KL/JSD(student top-K, teacher top-K) ──> loss
                         │
                         v
           EMA update: teacher <- 0.95*teacher + 0.05*student
```
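The peer-selection step in the diagram can be sketched as below for one prompt's group of rollouts. `select_demonstrations` and `strip_thinking` are hypothetical names that mirror success_reward_threshold, dont_reprompt_on_self_success and remove_thinking_from_demonstration. Picking the first successful peer (rather than, say, a random one) and the regex used to strip `<think>...</think>` blocks are assumptions.

```python
import re
from typing import List, Optional

_THINK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def strip_thinking(text: str) -> str:
    """Remove <think>...</think> spans from a demonstration."""
    return _THINK_RE.sub("", text).strip()


def select_demonstrations(
    completions: List[str],
    rewards: List[float],
    success_reward_threshold: float = 1.0,
    dont_reprompt_on_self_success: bool = True,
    remove_thinking: bool = True,
) -> List[Optional[str]]:
    """For each rollout in a group, return a successful peer's completion, or None."""
    successful = [i for i, r in enumerate(rewards) if r >= success_reward_threshold]
    demos: List[Optional[str]] = []
    for i in range(len(completions)):
        peers = [j for j in successful if not (dont_reprompt_on_self_success and j == i)]
        if not peers:
            demos.append(None)  # no demo: the teacher prompt can fall back to feedback
            continue
        demo = completions[peers[0]]
        demos.append(strip_thinking(demo) if remove_thinking else demo)
    return demos


# The group from the diagram: only comp1 clears the 1.0 threshold (comp4's 0.5 does not).
completions = ["<think>2+3 is 5</think>5", "6", "4", "5.0"]
rewards = [1.0, 0.0, 0.0, 0.5]
print(select_demonstrations(completions, rewards))
# [None, '5', '5', '5']
```

Note that comp1 gets no demonstration of its own, because the only successful rollout is itself and dont_reprompt_on_self_success excludes it.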
Examples
See examples/ for complete, runnable scripts:
| Example | Task | Key Feature |
|---|---|---|
| basic_sdpo.py | Math (addition) | Core SDPO loop with feedback |
| sdpo_with_unsloth.py | Reasoning | Unsloth + QLoRA + 4-bit |
| sdpo_rich_feedback.py | Code generation | Test execution with error messages |

```bash
python examples/basic_sdpo.py
```
Benchmark
The benchmark/ directory contains a full MBPP code generation benchmark (Qwen2.5-0.5B, 4-bit QLoRA, RTX 3080):
- Correctness: Max |our_loss - ref_loss| = 1.4e-08 across 200 training steps (verified against verl reference)
- Performance: SDPO 1.95% pass@1 vs GRPO 1.17% at step 200
See benchmark/README.md for full results and replication instructions.
Documentation
| Document | What it covers |
|---|---|
| VERIFICATION.md | Line-by-line verification against verl reference |
| DEVIATIONS.md | Intentional differences from verl (TRL adaptation) |
| HANDOVER.md | Architecture decisions, gotchas, implementation guide |
| UNSLOTH_INTEGRATION.md | Unsloth compatibility, import order, what works |
| examples/README.md | Example walkthrough and customization guide |
| benchmark/README.md | MBPP benchmark methodology and results |
GPU Requirements
| Model Size | Without Unsloth | With Unsloth (4-bit) |
|---|---|---|
| 0.5B | ~6 GB | ~3.5 GB |
| 7B | ~28 GB | ~10 GB |
| 14B | ~56 GB | ~18 GB |
Known Limitations
Two features from the paper / reference are not yet implemented:
- Hybrid SDPO+GRPO blending (Section 4.5): The paper defines a combined advantage A = λ·A_GRPO + (1-λ)·A_SDPO that interpolates between GRPO and SDPO. Our library supports enabled=True (pure SDPO) or enabled=False (pure GRPO), but not λ blending. The paper shows this hybrid helps weaker models (e.g., Qwen3-0.6B).
- trust_region teacher mode: Declared in config validation but raises NotImplementedError. All paper experiments use EMA, so this is low priority.
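For reference, the hybrid blending described above could be computed roughly as sketched here; this is not part of sdpo-rl. The group-normalized GRPO advantage (reward minus group mean, divided by the group's sample standard deviation plus 1e-4) follows TRL's GRPO defaults. Broadcasting that per-sequence value across tokens, and treating A_SDPO as a given per-token list, are assumptions; both helper names are hypothetical.

```python
import statistics
from typing import List


def grpo_advantages(group_rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Group-normalized advantages: (r - mean) / (std + eps), using the sample std."""
    mean = statistics.fmean(group_rewards)
    std = statistics.stdev(group_rewards)
    return [(r - mean) / (std + eps) for r in group_rewards]


def blended_token_advantages(a_grpo: float, a_sdpo_tokens: List[float], lam: float) -> List[float]:
    """A = lam * A_GRPO + (1 - lam) * A_SDPO, with the sequence-level A_GRPO broadcast to every token."""
    return [lam * a_grpo + (1.0 - lam) * a for a in a_sdpo_tokens]


a_grpo = grpo_advantages([1.0, 0.0, 0.0, 0.5])
print([round(a, 3) for a in a_grpo])
print(blended_token_advantages(a_grpo[0], [0.2, -0.1, 0.4], lam=0.5))
```

Setting lam=1.0 recovers pure GRPO and lam=0.0 pure SDPO, which is what the enabled=False / enabled=True switch covers today.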
Citation
```bibtex
@article{hubotter2026sdpo,
  title={Reinforcement Learning via Self-Distillation},
  author={H{\"u}botter, Jonas and L{\"u}beck, Frederike and Behric, Lejs and Baumann, Anton and Bagatella, Marco and Marta, Daniel and Hakimi, Ido and Shenfeld, Idan and Kleine Buening, Thomas and Guestrin, Carlos and Krause, Andreas},
  journal={arXiv preprint arXiv:2601.20802},
  year={2026}
}
```
License
Apache 2.0
Acknowledgments
- lasgroup/SDPO -- Reference implementation
- Hugging Face TRL -- Base GRPO trainer
- Unsloth -- Training optimizations
File details
Details for the file sdpo_rl-0.0.2.tar.gz.
File metadata
- Download URL: sdpo_rl-0.0.2.tar.gz
- Upload date:
- Size: 1.1 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | e7efa7db20792e34dcab3c9f7449e7a4c8b90a5391b1bf3a40e9c0bf7c83febb |
| MD5 | 95c1086dd918f6d7c60afad36119c111 |
| BLAKE2b-256 | 17930be3b60471322fc2c7ee2dae390890ffec4a471354f3428b1eb8e2f00ab3 |

Provenance
The following attestation bundles were made for sdpo_rl-0.0.2.tar.gz:
Publisher: publish.yml on SethBurkart123/sdpo
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: sdpo_rl-0.0.2.tar.gz
- Subject digest: e7efa7db20792e34dcab3c9f7449e7a4c8b90a5391b1bf3a40e9c0bf7c83febb
- Sigstore transparency entry: 938636830
- Sigstore integration time:
- Permalink: SethBurkart123/sdpo@2e16fe94bedfd837b143b1ee090dc2005cf55cf6
- Branch / Tag: refs/tags/v0.0.2
- Owner: https://github.com/SethBurkart123
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@2e16fe94bedfd837b143b1ee090dc2005cf55cf6
- Trigger Event: push
File details
Details for the file sdpo_rl-0.0.2-py3-none-any.whl.
File metadata
- Download URL: sdpo_rl-0.0.2-py3-none-any.whl
- Upload date:
- Size: 24.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | a5b51024e136eb7669dd82c35c5a93b85a9298294c7b98f6d026be63a5cf24f6 |
| MD5 | e7b6def3e68c7df37a4dd10dc24c1c73 |
| BLAKE2b-256 | 83015b6493407edff606ce46cd8660d05271dac1b483ee3f4642981688aede49 |

Provenance
The following attestation bundles were made for sdpo_rl-0.0.2-py3-none-any.whl:
Publisher: publish.yml on SethBurkart123/sdpo
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: sdpo_rl-0.0.2-py3-none-any.whl
- Subject digest: a5b51024e136eb7669dd82c35c5a93b85a9298294c7b98f6d026be63a5cf24f6
- Sigstore transparency entry: 938636850
- Sigstore integration time:
- Permalink: SethBurkart123/sdpo@2e16fe94bedfd837b143b1ee090dc2005cf55cf6
- Branch / Tag: refs/tags/v0.0.2
- Owner: https://github.com/SethBurkart123
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@2e16fe94bedfd837b143b1ee090dc2005cf55cf6
- Trigger Event: push