Self-Distilled Policy Optimization (SDPO) for TRL — faithful reimplementation of arxiv:2601.20802

Warning: This project is under active development and is not guaranteed to be at full parity with the SDPO paper.

sdpo-rl

Self-Distilled Policy Optimization (SDPO) for Hugging Face TRL

A faithful reimplementation of SDPO (arxiv:2601.20802) from the lasgroup/SDPO verl fork, ported to the Hugging Face TRL ecosystem as a drop-in GRPOTrainer subclass.

What is SDPO?

SDPO (Self-Distilled Policy Optimization) replaces GRPO's scalar reward loss with token-level self-distillation from a teacher model. On each batch:

  1. Generate multiple completions per prompt
  2. Evaluate with your reward function (scores + optional feedback)
  3. Select successful peer rollouts as demonstrations
  4. Reprompt the teacher with the task + peer demo + error feedback
  5. Distill the student towards the teacher via top-K KL divergence
  6. Update the EMA teacher weights

The teacher sees "here's a working solution and the error from your last attempt" while the student learns to match that improved distribution. No third model is kept in memory: the teacher is the ref_model, updated via EMA.
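Step 6 (the EMA teacher update) can be illustrated with a minimal sketch. Plain Python lists stand in for parameter tensors, and the helper name ema_update is hypothetical rather than part of sdpo-rl. It applies teacher = (1 - rate) * teacher + rate * student with the default teacher_update_rate of 0.05.

```python
def ema_update(teacher: dict[str, list[float]],
               student: dict[str, list[float]],
               rate: float = 0.05) -> None:
    """In-place EMA: teacher <- (1 - rate) * teacher + rate * student.

    Lists of floats stand in for parameter tensors (illustrative only).
    """
    for name, t_vals in teacher.items():
        s_vals = student[name]
        # Slice assignment updates the teacher's list in place.
        t_vals[:] = [(1.0 - rate) * t + rate * s for t, s in zip(t_vals, s_vals)]


teacher = {"w": [1.0, 0.0]}
student = {"w": [0.0, 1.0]}
ema_update(teacher, student, rate=0.05)
print(teacher["w"])  # [0.95, 0.05]
```

In PyTorch the same in-place update is commonly written as t.lerp_(s, rate) inside torch.no_grad(). With teacher_mode="frozen" this step is skipped and the teacher keeps its initial weights.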

Quick Start

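pip install sdpo-rl

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig

from sdpo_rl import SDPOTrainer, SDPOConfig

# Illustrative model choice; any causal LM supported by TRL's GRPOTrainer should work.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Toy dataset: GRPOTrainer reads the "prompt" column; extra columns such as
# "answer" are forwarded to reward functions as keyword arguments.
dataset = Dataset.from_list([{"prompt": "What is 2 + 3? Answer with a number.", "answer": "5"}] * 64)

def reward_fn(completions, answer, **kwargs) -> list[float]:
    # 1.0 if the expected answer appears in the completion, else 0.0.
    return [1.0 if a in c else 0.0 for c, a in zip(completions, answer)]

grpo_config = GRPOConfig(
    output_dir="./output",
    num_generations=4,
    max_completion_length=128,
    learning_rate=1e-5,
    bf16=True,
    remove_unused_columns=False,
)

sdpo_config = SDPOConfig()  # sensible defaults from the paper

trainer = SDPOTrainer(
    model=model,
    args=grpo_config,
    sdpo_config=sdpo_config,
    processing_class=tokenizer,
    reward_funcs=[reward_fn],
    train_dataset=dataset,
)

trainer.train()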

With Unsloth (2x faster, 60% less memory)

from unsloth import FastLanguageModel, PatchFastRL

# CRITICAL: Patch BEFORE importing SDPOTrainer
PatchFastRL("GRPO", FastLanguageModel)

from sdpo_rl import SDPOTrainer, SDPOConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16, ...)

# grpo_config, sdpo_config, reward_fn and dataset as defined in Quick Start
trainer = SDPOTrainer(
    model=model,
    args=grpo_config,
    sdpo_config=sdpo_config,
    processing_class=tokenizer,
    reward_funcs=[reward_fn],
    train_dataset=dataset,
)
trainer.train()

Configuration

SDPOConfig Reference

Parameters match the lasgroup/SDPO reference (see Known Limitations for gaps). Defaults are from the paper's experiment scripts.

SDPOConfig(
    # --- Loss mode ---
    enabled=True,                              # True: SDPO replaces GRPO loss. False: vanilla GRPO.

    # --- KL divergence ---
    alpha=0.5,                                 # 0.0=forward KL, 0.5=JSD (paper default), 1.0=reverse KL
    full_logit_distillation=True,              # Use top-K logits (True) or token-level KL (False)
    distillation_topk=100,                     # K for top-K approximation
    distillation_add_tail=True,                # Append tail bucket for residual probability mass

    # --- Importance sampling ---
    is_clip=2.0,                               # Clamp IS ratio. None disables correction.

    # --- Teacher ---
    teacher_mode="ema",                        # "ema" or "frozen" (trust_region declared but not yet implemented)
    teacher_update_rate=0.05,                  # EMA rate: teacher = (1-rate)*teacher + rate*student

    # --- Demonstration selection ---
    success_reward_threshold=1.0,              # Min reward for a rollout to be a "successful" demo
    dont_reprompt_on_self_success=True,        # Exclude own response as demonstration
    remove_thinking_from_demonstration=True,   # Strip <think>...</think> from demos

    # --- Feedback ---
    include_environment_feedback=True,         # Include env feedback (test errors, etc.) in teacher prompt
    environment_feedback_only_without_solution=True,  # Only include feedback when no peer demo exists

    # --- Reprompting ---
    max_reprompt_length=10240,                 # Max tokens for teacher prompt
    reprompt_truncation="right",               # "left", "right", or "error"

    # --- Templates (customizable) ---
    reprompt_template="{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n",
    solution_template="\nCorrect solution:\n\n{successful_previous_attempt}\n\n",
    feedback_template="\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n",
)
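The is_clip setting caps an importance-sampling ratio. The sketch below assumes the ratio is the per-token probability ratio between the current policy and the policy that generated the rollouts, truncated from above at is_clip, and that is_clip=None disables the correction as the config comment states. The helper truncated_is_weights is hypothetical, and sdpo-rl's exact ratio definition may differ.

```python
import math
from typing import Optional


def truncated_is_weights(logp_current: list[float],
                         logp_rollout: list[float],
                         is_clip: Optional[float] = 2.0) -> list[float]:
    """Per-token weights min(pi_current / pi_rollout, is_clip) (assumed semantics).

    is_clip=None disables the correction, so every weight is 1.0.
    """
    if is_clip is None:
        return [1.0] * len(logp_current)
    return [min(math.exp(lc - lr), is_clip)
            for lc, lr in zip(logp_current, logp_rollout)]


# Token 2 is 3x more likely under the current policy, so its weight is capped at 2.0.
print(truncated_is_weights([math.log(0.5), math.log(0.9)],
                           [math.log(0.5), math.log(0.3)]))  # [1.0, 2.0]
```

Capping only from above bounds the variance that very off-policy tokens can add to the loss while leaving down-weighting untouched.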

Reward Functions

TRL requires reward functions to return list[float]. To provide feedback for SDPO's teacher prompts, use a callable class with a last_feedback attribute:

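def is_correct(completion: str, expected: str) -> bool:
    # Placeholder check; replace with your task-specific verifier.
    return completion.strip() == str(expected).strip()


class MyReward:
    """Reward function that provides scores AND feedback for SDPO."""

    def __init__(self):
        # TRL labels reward metrics with reward_func.__name__, which instances lack by default.
        self.__name__ = "my_reward"
        self.last_feedback: list[str] = []

    def __call__(self, prompts, completions, answer, **kwargs) -> list[float]:
        # "answer" is an example dataset column; TRL forwards extra dataset columns
        # to reward functions as keyword arguments. Assumes plain-text completions.
        scores = []
        self.last_feedback = []

        for completion, expected in zip(completions, answer):
            if is_correct(completion, expected):
                scores.append(1.0)
                self.last_feedback.append("")
            else:
                scores.append(0.0)
                self.last_feedback.append(f"Wrong: expected {expected}, got {completion}")

        return scores  # TRL requires list[float]


reward_fn = MyReward()
trainer = SDPOTrainer(..., reward_funcs=[reward_fn])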

The trainer checks hasattr(rf, "last_feedback") on each reward function after scoring. Feedback strings are injected into teacher prompts for samples that lack a successful peer demonstration.

If you don't need feedback (simpler tasks), a plain function returning list[float] works fine -- SDPO still uses peer demonstrations from successful rollouts.
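How a teacher prompt could be assembled from the default templates in the SDPOConfig reference is sketched below. The helper build_teacher_prompt is hypothetical, and several details are assumptions rather than sdpo-rl behavior: an absent demonstration renders as an empty solution section, <think> blocks are removed with a regex, and a sample with neither a demonstration nor feedback gets no teacher prompt (None).

```python
import re
from typing import Optional

# Default templates copied from the SDPOConfig reference.
REPROMPT_TEMPLATE = "{prompt}{solution}{feedback}\n\nCorrectly solve the original question.\n"
SOLUTION_TEMPLATE = "\nCorrect solution:\n\n{successful_previous_attempt}\n\n"
FEEDBACK_TEMPLATE = "\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n"

THINK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def build_teacher_prompt(prompt: str,
                         demonstration: Optional[str],
                         feedback: str,
                         include_environment_feedback: bool = True,
                         environment_feedback_only_without_solution: bool = True,
                         remove_thinking_from_demonstration: bool = True) -> Optional[str]:
    """Render the reprompt for one sample (hypothetical helper)."""
    solution = ""
    if demonstration is not None:
        if remove_thinking_from_demonstration:
            demonstration = THINK_RE.sub("", demonstration).strip()
        solution = SOLUTION_TEMPLATE.format(successful_previous_attempt=demonstration)

    # With the defaults, feedback is only used when no peer demonstration exists.
    use_feedback = (
        include_environment_feedback
        and bool(feedback)
        and not (environment_feedback_only_without_solution and demonstration is not None)
    )
    feedback_text = FEEDBACK_TEMPLATE.format(feedback_raw=feedback) if use_feedback else ""

    if not solution and not feedback_text:
        return None  # assumption: no extra context, so no teacher signal
    return REPROMPT_TEMPLATE.format(prompt=prompt, solution=solution, feedback=feedback_text)


print(build_teacher_prompt("What is 2 + 3?", None, "Wrong: expected 5, got 6"))
```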

Training Modes

SDPO (default): Self-distillation replaces GRPO loss entirely.

sdpo_config = SDPOConfig(enabled=True)  # the default

GRPO fallback: Disable SDPO, use standard TRL GRPO.

sdpo_config = SDPOConfig(enabled=False)

Frozen teacher: Use initial weights as teacher (no EMA updates).

sdpo_config = SDPOConfig(teacher_mode="frozen")

KL Divergence Variants

SDPOConfig(alpha=0.0)   # Forward KL: KL(teacher || student) -- mode-covering
SDPOConfig(alpha=0.5)   # JSD: symmetric Jensen-Shannon (paper default)
SDPOConfig(alpha=1.0)   # Reverse KL: KL(student || teacher) -- mode-seeking

The paper uses alpha=0.5 (JSD) across its experiments. The reference experiment scripts use alpha=1.0 with distillation_topk=20 for LiveCodeBench.
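One way to compute the alpha-parameterized divergence over a top-K support with a tail bucket is sketched below for a single token position, in pure Python. It follows a common generalized-JSD form (the one used by TRL's GKDTrainer), with alpha=0 and alpha=1 special-cased to the forward and reverse KL listed above. Choosing the top-K indices from the student's distribution is an assumption, and sdpo-rl's exact formula may differ.

```python
import math

TAIL_EPS = 1e-12  # floor for the tail bucket so log() never sees 0


def topk_with_tail(probs: list[float], idx: list[int], add_tail: bool) -> list[float]:
    """Restrict a full distribution to the indices in idx."""
    sub = [probs[i] for i in idx]
    if add_tail:
        # Residual mass outside the top-K becomes one extra "tail" outcome.
        return sub + [max(1.0 - sum(sub), TAIL_EPS)]
    total = sum(sub)
    return [p / total for p in sub]  # no tail: renormalize over the top-K


def kl(p: list[float], q: list[float]) -> float:
    """KL(p || q) over a shared discrete support; terms with p == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def topk_divergence(student: list[float], teacher: list[float], alpha: float = 0.5,
                    k: int = 100, add_tail: bool = True) -> float:
    # Assumption: the top-K support is taken from the student's distribution.
    idx = sorted(range(len(student)), key=lambda i: student[i], reverse=True)[:k]
    s = topk_with_tail(student, idx, add_tail)
    t = topk_with_tail(teacher, idx, add_tail)
    if alpha == 0.0:
        return kl(t, s)  # forward KL(teacher || student)
    if alpha == 1.0:
        return kl(s, t)  # reverse KL(student || teacher)
    # Generalized JSD against the mixture alpha * teacher + (1 - alpha) * student.
    m = [alpha * ti + (1.0 - alpha) * si for si, ti in zip(s, t)]
    return alpha * kl(t, m) + (1.0 - alpha) * kl(s, m)


print(topk_divergence([0.7, 0.2, 0.1], [0.1, 0.2, 0.7], alpha=0.5, k=2))
```

At alpha=0.5 both KL terms are measured against the equal-weight mixture, which gives the symmetric Jensen-Shannon divergence bounded by ln 2.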

How It Works

 For each batch:

  Prompt ──> Generate 4 completions ──────────> Reward function
                  │                                   │
       [comp1, comp2, comp3, comp4]          [1.0, 0.0, 0.0, 0.5]
                  │                                   │
             Select best peer (comp1) <───────────────┘
                  │
             Build teacher prompt:
               "Task: {prompt}
                Correct solution: {comp1}
                Feedback: {error from comp2}
                Correctly solve the original question."
                  │
             Teacher forward pass (with reprompted input)
             Student forward pass (with original input)
                  │
             KL(student top-K || teacher top-K) ──> loss
                  │
             EMA update: teacher <- 0.95*teacher + 0.05*student
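The "Select best peer" step in the diagram can be sketched for one prompt's group of rollouts. The helper select_demonstrations is hypothetical; it assumes "best" means the highest-reward rollout that meets success_reward_threshold (ties go to the earliest rollout) and applies dont_reprompt_on_self_success by excluding each rollout's own response.

```python
from typing import Optional


def select_demonstrations(completions: list[str],
                          rewards: list[float],
                          success_reward_threshold: float = 1.0,
                          dont_reprompt_on_self_success: bool = True) -> list[Optional[str]]:
    """Pick a successful peer demonstration for each rollout (None if none exists)."""
    successful = [i for i, r in enumerate(rewards) if r >= success_reward_threshold]
    demos: list[Optional[str]] = []
    for i in range(len(completions)):
        # Optionally exclude the rollout's own response from its candidates.
        candidates = [j for j in successful if not (dont_reprompt_on_self_success and j == i)]
        if candidates:
            best = max(candidates, key=lambda j: rewards[j])  # first maximum wins ties
            demos.append(completions[best])
        else:
            demos.append(None)
    return demos


# Rewards from the diagram: only comp1 reaches the default 1.0 threshold.
print(select_demonstrations(["comp1", "comp2", "comp3", "comp4"],
                            [1.0, 0.0, 0.0, 0.5]))  # [None, 'comp1', 'comp1', 'comp1']
```

comp4's reward of 0.5 is below the threshold, so it never serves as a demonstration. comp1 receives no demonstration because its own response is excluded; with dont_reprompt_on_self_success=False it would be paired with itself.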

Examples

See examples/ for complete, runnable scripts:

Example                  Task              Key Feature
basic_sdpo.py            Math (addition)   Core SDPO loop with feedback
sdpo_with_unsloth.py     Reasoning         Unsloth + QLoRA + 4-bit
sdpo_rich_feedback.py    Code generation   Test execution with error messages

python examples/basic_sdpo.py

Benchmark

The benchmark/ directory contains a full MBPP code generation benchmark (Qwen2.5-0.5B, 4-bit QLoRA, RTX 3080):

  • Correctness: Max |our_loss - ref_loss| = 1.4e-08 across 200 training steps (verified against verl reference)
  • Performance: SDPO 1.95% pass@1 vs GRPO 1.17% at step 200

See benchmark/README.md for full results and replication instructions.

Documentation

Document                 What it covers
VERIFICATION.md          Line-by-line verification against verl reference
DEVIATIONS.md            Intentional differences from verl (TRL adaptation)
HANDOVER.md              Architecture decisions, gotchas, implementation guide
UNSLOTH_INTEGRATION.md   Unsloth compatibility, import order, what works
examples/README.md       Example walkthrough and customization guide
benchmark/README.md      MBPP benchmark methodology and results

GPU Requirements

Model size   Without Unsloth   With Unsloth (4-bit)
0.5B         ~6 GB             ~3.5 GB
7B           ~28 GB            ~10 GB
14B          ~56 GB            ~18 GB

Known Limitations

Two features from the paper and the reference implementation are not yet implemented:

  1. Hybrid SDPO+GRPO blending (Section 4.5): The paper defines a combined advantage A = λ·A_GRPO + (1-λ)·A_SDPO that interpolates between GRPO and SDPO. Our library supports enabled=True (pure SDPO) or enabled=False (pure GRPO) but not lambda blending. The paper shows this hybrid helps weaker models (e.g., Qwen3-0.6B).

  2. trust_region teacher mode: Declared in config validation but raises NotImplementedError. All paper experiments use EMA, so this is low priority.
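sdpo-rl does not implement the λ blending from limitation 1. Purely to illustrate the formula A = λ·A_GRPO + (1-λ)·A_SDPO, the sketch below assumes per-token A_SDPO values are already available (how SDPO derives them is not shown) and broadcasts a group-normalized GRPO advantage across a completion's tokens. Both helper names are hypothetical.

```python
import statistics


def grpo_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Group-normalized GRPO advantages (r - mean) / (std + eps) for one prompt's rollouts.

    Uses the population standard deviation; implementations differ on this detail.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


def blend_token_advantages(a_grpo: float, a_sdpo_tokens: list[float], lam: float) -> list[float]:
    """A = lam * A_GRPO + (1 - lam) * A_SDPO, broadcasting the sequence-level GRPO term."""
    return [lam * a_grpo + (1.0 - lam) * a for a in a_sdpo_tokens]


# Hypothetical per-token SDPO advantages for the first rollout of a two-rollout group.
a_grpo = grpo_advantages([1.0, 0.0])[0]
print(blend_token_advantages(a_grpo, [0.2, -0.1, 0.4], lam=0.5))
```

lam=1.0 recovers pure GRPO and lam=0.0 pure SDPO, the two modes the library currently supports through enabled=False and enabled=True.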

Citation

@article{hubotter2026sdpo,
  title={Reinforcement Learning via Self-Distillation},
  author={H{\"u}botter, Jonas and L{\"u}beck, Frederike and Behric, Lejs and Baumann, Anton and Bagatella, Marco and Marta, Daniel and Hakimi, Ido and Shenfeld, Idan and Kleine Buening, Thomas and Guestrin, Carlos and Krause, Andreas},
  journal={arXiv preprint arXiv:2601.20802},
  year={2026}
}

License

Apache 2.0

Acknowledgments

Download files

Source Distribution

sdpo_rl-0.0.2.tar.gz (1.1 MB)

Uploaded Source

Built Distribution

sdpo_rl-0.0.2-py3-none-any.whl (24.8 kB)

Uploaded Python 3

File details

Details for the file sdpo_rl-0.0.2.tar.gz.

File metadata

  • Download URL: sdpo_rl-0.0.2.tar.gz
  • Size: 1.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sdpo_rl-0.0.2.tar.gz
Algorithm Hash digest
SHA256 e7efa7db20792e34dcab3c9f7449e7a4c8b90a5391b1bf3a40e9c0bf7c83febb
MD5 95c1086dd918f6d7c60afad36119c111
BLAKE2b-256 17930be3b60471322fc2c7ee2dae390890ffec4a471354f3428b1eb8e2f00ab3

Provenance

The following attestation bundles were made for sdpo_rl-0.0.2.tar.gz:

Publisher: publish.yml on SethBurkart123/sdpo

File details

Details for the file sdpo_rl-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: sdpo_rl-0.0.2-py3-none-any.whl
  • Size: 24.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sdpo_rl-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a5b51024e136eb7669dd82c35c5a93b85a9298294c7b98f6d026be63a5cf24f6
MD5 e7b6def3e68c7df37a4dd10dc24c1c73
BLAKE2b-256 83015b6493407edff606ce46cd8660d05271dac1b483ee3f4642981688aede49

Provenance

The following attestation bundles were made for sdpo_rl-0.0.2-py3-none-any.whl:

Publisher: publish.yml on SethBurkart123/sdpo
