
Kondo Gate: Selective backward-pass gating for policy gradient training


Kondo Gate

Selective backward-pass gating for policy gradient training. A standalone PyTorch implementation compatible with HuggingFace Transformers.

Based on arXiv:2603.20526, Delightful Policy Gradients with Kondo Gating.

What it does

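The Kondo gate computes delight for each training sample, defined as its advantage times its surprisal (-log pi(a), i.e. how unexpected the taken action was under the policy). It then skips backward passes for low-delight samples. This preserves learning quality while cutting the number of backward passes to roughly a rho fraction of the batch.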

Method | Weights gradient by | Backward passes per batch
PG | Advantage only | All B
DG | Advantage x sigmoid(delight / eta) | All B
DG-K | Advantage, on the top-k samples by delight only | ~rho x B

At a 3% gate rate (rho = 0.03), that means about 3 backward passes per 100 samples, and it still matches or beats full DG.
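To make the table concrete, here is a minimal sketch in plain PyTorch (not the kondo_gate API) that computes delight for a synthetic batch of 100 samples and counts how many samples a 3% gate keeps. The log-probabilities and advantages are random placeholders, and eta = 1.0 is an arbitrary illustrative choice.

```python
import torch

torch.manual_seed(0)

B, rho = 100, 0.03                      # batch size and gate rate (rho)
log_probs = -torch.rand(B) * 5          # placeholder log pi(a) of taken actions, in (-5, 0]
advantages = torch.randn(B)             # placeholder per-sample advantages

surprisal = -log_probs                  # surprisal = -log pi(a) >= 0
delight = advantages * surprisal        # delight = advantage x surprisal

# PG: each sample's gradient is weighted by its advantage (B backward passes).
# DG: each sample's gradient is weighted by advantage x sigmoid(delight / eta) (B backward passes).
eta = 1.0
dg_weights = advantages * torch.sigmoid(delight / eta)

# DG-K: only the top k ~ rho x B samples by delight get a backward pass.
k = max(1, round(rho * B))
kept = torch.topk(delight, k).indices
print(f"DG-K keeps {k} of {B} samples: {sorted(kept.tolist())}")
```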

Install

pip install kondo-gate

# From source with dev dependencies:
pip install -e ".[dev]"

Quick start

High-level: with HuggingFace model logits

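import torch
from kondo_gate import KondoGate, KondoGateConfig

gate = KondoGate(KondoGateConfig(gate_rate=0.03))  # keep top 3%

# logits from any model (B, T, V), actions (B, T), advantages (B, T).
# Random placeholders here; in practice these come from your model and rollouts.
B, T, V = 4, 16, 100
logits = torch.randn(B, T, V, requires_grad=True)
actions = torch.randint(0, V, (B, T))
advantages = torch.randn(B, T)

result = gate(logits=logits, actions=actions, advantages=advantages)
result.gated_policy_loss.backward()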
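The gate's first step on model outputs can be sketched in plain PyTorch: gather the log-probability of each taken action from (B, T, V) logits and form per-token delight. token_delight_sketch is a hypothetical helper, not part of kondo_gate. It assumes actions are already aligned with the logits (for a causal LM, shifted by one position by the caller) and that padding is marked by zeros in an attention mask; giving padded tokens -inf delight is one possible way to make sure they are never selected.

```python
import torch
import torch.nn.functional as F


def token_delight_sketch(logits, actions, advantages, attention_mask=None):
    """Hypothetical helper: per-token log pi(a) and delight for HF-style tensors.

    logits: (B, T, V); actions, advantages, attention_mask: (B, T).
    Assumes actions are already aligned with logits (shifted by the caller).
    """
    # log pi(a_t) for each taken action, shape (B, T).
    log_pi = F.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    # delight = advantage x surprisal; detached so the gate decision gets no gradient.
    delight = (advantages * -log_pi).detach()
    if attention_mask is not None:
        # Padding positions get -inf delight so a top-k gate never selects them.
        delight = delight.masked_fill(attention_mask == 0, float("-inf"))
    return log_pi, delight
```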

KondoTrainer: drop-in training wrapper

from transformers import AutoModelForCausalLM
from kondo_gate import KondoTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
trainer = KondoTrainer(model, gate_rate=0.03, lr=3e-4)

# input_ids, target_ids and advantages come from your own rollout and reward pipeline.
stats = trainer.step(
    input_ids=input_ids,
    actions=target_ids,
    advantages=advantages,
)
# stats = {"loss": ..., "gate_rate": ..., "price": ..., "mean_delight": ...}

Standalone loss functions (PG, DG)

from kondo_gate import pg_loss, dg_loss, expected_confidence_baseline

# Standard REINFORCE
loss = pg_loss(logits, actions, advantages)

# Delightful Gradient (sigmoid-weighted, all backward passes)
loss, gate_weights = dg_loss(logits, actions, advantages, eta=1.0)

# Expected confidence baseline (used in the reference implementation)
probs = logits.softmax(dim=-1)                  # policy probabilities pi(a)
baseline = expected_confidence_baseline(probs)  # b = sum_a pi(a)^2
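For readers who want to see what these functions compute, here is a minimal pure-PyTorch sketch under stated assumptions: single-step (bandit-style) inputs with logits of shape (B, V), and DG per-sample weights of sigmoid(delight / eta). The `_sketch` names are hypothetical and this is not the package's implementation. The baseline follows the formula in the comment above: sum_a pi(a)^2 equals E_{a ~ pi}[pi(a)], the expected probability the policy assigns to its own sampled action.

```python
import torch
import torch.nn.functional as F


def log_prob_of_actions(logits: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """log pi(a) for the taken actions; logits (B, V), actions (B,)."""
    return F.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)


def pg_loss_sketch(logits, actions, advantages):
    # REINFORCE: -mean(A * log pi(a)); advantages are treated as constants.
    logp = log_prob_of_actions(logits, actions)
    return -(advantages.detach() * logp).mean()


def dg_loss_sketch(logits, actions, advantages, eta=1.0):
    # Delightful Gradient: weight A by sigmoid(delight / eta), delight = A * (-log pi(a)).
    logp = log_prob_of_actions(logits, actions)
    delight = (advantages * -logp).detach()      # no gradient flows through the gate
    gate_weights = torch.sigmoid(delight / eta)
    loss = -(gate_weights * advantages.detach() * logp).mean()
    return loss, gate_weights


def expected_confidence_baseline_sketch(probs):
    # b = sum_a pi(a)^2 = E_{a ~ pi}[pi(a)], one value per row of probs (B, V).
    return (probs ** 2).sum(dim=-1)
```

A uniform policy over 4 actions gives b = 0.25, and a one-hot (fully confident) policy gives b = 1.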

Configuration

Parameter | Default | Description
gate_rate | 0.3 | Target fraction of backward passes to keep (rho). Mutually exclusive with price.
price | None | Fixed compute price threshold (lambda). Mutually exclusive with gate_rate.
temperature | 0.1 | Gate softness (eta). Used in the stochastic and soft modes.
hard | True | Binary gating (True) vs soft sigmoid weights (False).
deterministic | True | Deterministic top-k selection (True, reference implementation) vs Bernoulli sampling (False, Algorithm 1). Only applies when hard=True.

Three gating modes

  1. Deterministic top-k (hard=True, deterministic=True; the default): matches the reference Colab implementation. Keeps the top rho fraction of samples ranked by delight. Binary, no randomness.

  2. Stochastic Bernoulli (hard=True, deterministic=False): matches Algorithm 1 in the paper. Samples G ~ Bernoulli(sigma((chi - lambda) / eta)), where chi is the delight, lambda the price and eta the temperature.

  3. Soft sigmoid (hard=False): weights each sample by sigma((chi - lambda) / eta). All backward passes are computed, and each sample's gradient is weighted by its gate probability.
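The three modes can be sketched as one function over a 1-D tensor of delight scores. This is an illustration, not the library's code: kondo_gate_sketch is a hypothetical name, the price lambda is passed in directly for the stochastic and soft modes (the package may instead derive it from gate_rate), and rounding rho x B to an integer k is an assumption. The defaults rho = 0.3 and eta = 0.1 mirror the configuration table.

```python
import torch


def kondo_gate_sketch(delight, *, price=0.0, rho=0.3, eta=0.1, hard=True,
                      deterministic=True, generator=None):
    """Hypothetical per-sample gate values for a 1-D tensor of delight scores (chi)."""
    delight = delight.detach()
    if hard and deterministic:
        # Mode 1, deterministic top-k: keep the top rho fraction by delight.
        n = delight.numel()
        k = min(n, max(1, round(rho * n)))
        gate = torch.zeros_like(delight)
        gate[torch.topk(delight, k).indices] = 1.0
        return gate
    # Gate probability sigma((chi - lambda) / eta), shared by modes 2 and 3.
    p = torch.sigmoid((delight - price) / eta)
    if hard:
        # Mode 2, stochastic Bernoulli: G ~ Bernoulli(p).
        return torch.bernoulli(p, generator=generator)
    # Mode 3, soft sigmoid: no sample is dropped; each is weighted by p.
    return p
```

For example, with delight [2.0, -1.0, 0.5, 0.05, 3.0] and rho = 0.4, the deterministic mode keeps the two highest-delight samples (indices 0 and 4).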

Tests

pip install -e ".[dev]"
pytest

60 tests across 10 categories:

  • Config validation (bounds, mutual exclusivity, defaults)
  • Delight computation (formula correctness, detachment, edge cases)
  • Gate mechanism (output shapes, hard/soft modes, adaptive rate targeting)
  • Full forward pass (2D/3D logits, attention masking, loss finiteness)
  • Mathematical properties (sigmoid formula, temperature limits, price monotonicity)
  • Gradient verification (flow through hard/soft gates, zero-grad for gated-out samples)
  • Integration (multi-step training loops, parameter updates)
  • Edge cases (batch=1, zero advantages, empty masks, reproducibility)
  • Deterministic mode (top-k selection, reference impl match, reproducibility)
  • Loss functions (PG, DG, DG-K structure, baseline computation)

Examples

MNIST contextual bandit (PG vs DG vs DG-K)

Replicates the paper's MNIST experiment. Requires torchvision.

pip install torchvision
python examples/mnist_bandit.py

Token reversal

Trains a small causal transformer to reverse sequences at different gate rates.

python examples/token_reversal.py

How it works

  1. Forward pass: compute log-probabilities log pi(a) for the taken actions, then delight = advantage x surprisal, where surprisal = -log pi(a)
  2. Gate decision: set the price as the (1 - rho)-quantile of delight; keep samples with delight >= price
  3. Gated backward: loss = -mean(log_pi * stop_grad(gate * advantage)); only gated-in samples contribute gradients
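The three steps can be sketched end to end for a contextual bandit, where each sample is one (context, action) pair. kondo_step_sketch is a hypothetical helper in plain PyTorch, assuming logits of shape (B, V); it uses torch.quantile with its default linear interpolation to set the price, which is one reasonable reading of "(1 - rho)-quantile".

```python
import torch
import torch.nn.functional as F


def kondo_step_sketch(logits, actions, advantages, rho=0.3):
    """Hypothetical gated policy-gradient loss; logits (B, V), actions (B,), advantages (B,)."""
    # 1. Forward pass: log pi(a) of the taken actions, then delight = A x surprisal.
    log_pi = F.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    delight = (advantages * -log_pi).detach()
    # 2. Gate decision: price = (1 - rho)-quantile of delight; keep delight >= price.
    price = torch.quantile(delight, 1.0 - rho)
    gate = (delight >= price).float()
    # 3. Gated backward: loss = -mean(log_pi * stop_grad(gate * A)).
    loss = -(log_pi * (gate * advantages).detach()).mean()
    return loss, gate, price
```

In this sketch the gated-out samples still flow through autograd, so it reproduces the loss and the zero gradients for those samples but not the compute saving; realizing the saving would require backpropagating only the kept samples (for example, by recomputing the forward pass on just that subset).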

The gate filters out gradient noise from uninformative samples (low surprisal) and from samples with small or negative advantage, keeping only the samples that teach the most per unit of compute.

Why delight, not something simpler? Neither advantage nor surprisal alone tells the right story. High advantage with low surprisal = the model already knew. High surprisal with zero advantage = unusual but unremarkable. The multiplicative product targets the intersection: something surprising and valuable. Unlike additive combinations, the product is sign-consistent across all problem parameters (Proposition 2 in the paper).
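A small worked example of the cases above, using delight = advantage x (-log pi(a)) with illustrative values: pi(a) = 0.99 for an expected action and 0.01 for a surprising one. The fourth case, a surprising failure, is added here only to show that the product turns negative and therefore ranks below the others.

```python
import math


def delight(advantage: float, prob: float) -> float:
    """delight = advantage x surprisal, with surprisal = -log pi(a)."""
    return advantage * -math.log(prob)


cases = {
    "high advantage, low surprisal (model already knew)": (1.0, 0.99),
    "zero advantage, high surprisal (unusual but unremarkable)": (0.0, 0.01),
    "high advantage, high surprisal (surprising and valuable)": (1.0, 0.01),
    "negative advantage, high surprisal (surprising failure)": (-1.0, 0.01),
}
for name, (adv, p) in cases.items():
    print(f"{name:58s} delight = {delight(adv, p):+.3f}")
```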

Citation

@article{kondogate2026,
  title={Delightful Policy Gradients with Kondo Gating},
  year={2026},
  eprint={2603.20526},
  archivePrefix={arXiv},
}

License

MIT



Download files


Source Distribution

kondo_gate-0.1.0.tar.gz (16.6 kB)

Uploaded Source

Built Distribution


kondo_gate-0.1.0-py3-none-any.whl (9.5 kB)

Uploaded Python 3

File details

Details for the file kondo_gate-0.1.0.tar.gz.

File metadata

  • Download URL: kondo_gate-0.1.0.tar.gz
  • Size: 16.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.2

File hashes

Hashes for kondo_gate-0.1.0.tar.gz
Algorithm | Hash digest
SHA256 | f31adf31c194f6c1b88d3c89f2bdc2a627138e02f410fb0e3fb222943ad95eec
MD5 | 72e3fb9281820c282b468be91989dff1
BLAKE2b-256 | d6065752de96d717803dfbdfecfd3bcc69d0bfa6311e5a20577cf40a843c6041


File details

Details for the file kondo_gate-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: kondo_gate-0.1.0-py3-none-any.whl
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.2

File hashes

Hashes for kondo_gate-0.1.0-py3-none-any.whl
Algorithm | Hash digest
SHA256 | fb7d863b66b26b821316e19c4fd2e21298b2c84d878683a7cf46b398039db98c
MD5 | 82a5592008c591304d9a42c6bc7a4740
BLAKE2b-256 | 327aa016e0b1151c13f93c0142305565f1c944a57908310bc9033b50327d4913

