
disco-torch

A PyTorch port of DeepMind's Disco103 — the meta-learned reinforcement learning update rule from Discovering State-Of-The-Art Reinforcement Learning Algorithms (Nature, 2025).

Disco103 is a small neural network (754K params) that replaces hand-crafted RL loss functions. Instead of PPO or GRPO, you feed it agent experience and it outputs loss targets. The agent trains by minimizing KL divergence against these targets. It was meta-trained across thousands of environments and generalizes to new tasks as a drop-in update rule.
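As a conceptual sketch of what "minimizing KL divergence against these targets" means (numpy, illustrative names — not the disco-torch API, and the KL direction shown is one convention that may differ from the actual rule):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kl_to_targets(agent_logits, target_logits):
    # The meta-network emits target distributions for the agent's outputs;
    # the agent's loss is the KL divergence from those targets to its own
    # predictions, in place of a hand-crafted loss like PPO's clipped objective.
    p = softmax(target_logits)   # meta-learned target
    q = softmax(agent_logits)    # agent's current prediction
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)
```

The same pattern applies to each of the agent's output heads, not just the policy logits.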

Validated

This port achieves 100% catch rate on the reference Catch benchmark, matching the original JAX implementation:

Step   50 | catch=12%      (random)
Step  250 | catch=18%      (meta-RNN warming up)
Step  300 | catch=36%      (transition begins)
Step  350 | catch=95%      (hockey stick)
Step  400 | catch=100%     (converged)
Step 1000 | catch=100%     (stable)

Installation

pip install disco-torch

With optional extras:

pip install disco-torch[hub]       # HuggingFace Hub weight downloads
pip install disco-torch[dev]       # pytest for development

Quick start

Open in Google Colab — train Catch with Disco103 in 3 cells.

# Or run locally
python examples/catch_disco.py

# With A2C baseline for comparison
python examples/catch_disco.py --both

Using DiscoTrainer in your own project

DiscoTrainer handles meta-state management, target networks, replay buffer, optimizer, and loss computation automatically:

from disco_torch import DiscoTrainer, collect_rollout

# Your agent (must output: logits, y, z, q, aux_pi — see Agent Requirements below)
agent = YourAgent(obs_dim=64, num_actions=3).to(device)

# One line setup — downloads weights, creates optimizer, replay buffer, meta-state
trainer = DiscoTrainer(agent, device=device, lr=0.01)

# Wrap your environment
env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

# Training loop
for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
    )
    logs = trainer.step(rollout)  # 1 gradient step per acting step

    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: loss={logs['total_loss']:.4f}")

Low-level API

For advanced users who need full control, the underlying DiscoUpdateRule is also available. See examples/catch_disco.py for the complete validated training loop with all details.

Agent requirements

Your agent's forward pass must return a dict with these keys:

Key      Shape          Description
logits   [B, A]         Policy logits (unnormalized)
y        [B, 600]       Value prediction vector
z        [B, A, 600]    Per-action auxiliary prediction
q        [B, A, 601]    Per-action Q-value (601-bin categorical)
aux_pi   [B, A, A]      1-step policy prediction

The reference agent architecture (feedforward MLP + action-conditional LSTM) is implemented in examples/catch_disco.py as DiscoMLPAgent.
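A shape-only sketch of that dict (numpy placeholders; B=2 and A=3 are illustrative values, matching a 2-env, 3-action Catch setup):

```python
import numpy as np

B, A = 2, 3  # batch (num_envs) and action count, chosen for illustration
outputs = {
    "logits": np.zeros((B, A)),        # policy logits (unnormalized)
    "y":      np.zeros((B, 600)),      # value prediction vector
    "z":      np.zeros((B, A, 600)),   # per-action auxiliary prediction
    "q":      np.zeros((B, A, 601)),   # per-action 601-bin categorical Q-value
    "aux_pi": np.zeros((B, A, A)),     # 1-step policy prediction
}
```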

Training loop details

DiscoTrainer handles all of these automatically. Listed here for reference and for users of the low-level API:

  • 1 gradient step per acting step: Effective replay ratio comes from batch_size=64 with num_envs=2 (each trajectory sampled ~32 times over its buffer lifetime)
  • ClippedAdam optimizer: Adam scaling, then per-element clip to [-1, 1], then learning rate — matching the reference's optax.chain(scale_by_adam(), clip(1.0), scale(-lr))
  • Target network: Polyak averaging with coeff=0.9 — target slowly tracks current: target = 0.9 * old_target + 0.1 * current
  • Loss masking: Mask first step of new episodes (uninformative), NOT terminal steps (most informative)
  • Meta-state persistence: The meta-network's RNN state carries across all learner steps
  • torch.no_grad() on unroll_meta_net: Required — because the meta-RNN state persists across learner steps, tracking gradients through its unroll would grow the autograd graph without bound and run out of memory
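Two of the items above can be sketched directly (numpy, illustrative — a sketch of the stated update rules, not the package's internals):

```python
import numpy as np

def clipped_step(adam_scaled_grad, lr):
    # ClippedAdam ordering: Adam scaling happens first (the input here),
    # then per-element clip to [-1, 1], then the learning rate.
    # Returns the parameter delta, i.e. what gets added to the weights.
    return -lr * np.clip(adam_scaled_grad, -1.0, 1.0)

def polyak_update(target, current, coeff=0.9):
    # Target network slowly tracks the online network:
    # target = 0.9 * old_target + 0.1 * current
    return coeff * target + (1.0 - coeff) * current
```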

Architecture

Outer network (per-trajectory):
  y_net           MLP [600 -> 16 -> 1]       Value prediction embedding
  z_net           MLP [600 -> 16 -> 1]       Auxiliary prediction embedding
  policy_net      Conv1dNet [9 -> 16 -> 2]   Action-conditional embedding
  trajectory_rnn  LSTM(27, 256)              Reverse-unrolled over trajectory
  state_gate      Linear(128 -> 256)         Multiplicative gate from meta-RNN
  y_head/z_head   Linear(256 -> 600)         Loss targets for y and z
  pi_conv + head  Conv1dNet [258 -> 16] -> 1 Policy loss target (per action)

Meta-RNN (per-lifetime):
  input_mlp       Linear(29 -> 16)           Compress batch-time averages
  lstm            LSTMCell(16, 128)          Slow timescale adaptation

The outer network processes each trajectory with a reverse LSTM. The meta-RNN operates at a slower timescale — it sees batch-time averages and modulates the outer network via a multiplicative gate.
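A minimal sketch of that gate under the simplest reading of the table above (a learned Linear(128 -> 256) applied multiplicatively; whether any nonlinearity is applied to the gate output is an open assumption here, and the weights below are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((256, 128)) * 0.01  # state_gate: Linear(128 -> 256)
b = np.zeros(256)

def gate_outer_state(trajectory_h, meta_h):
    # The slow meta-RNN hidden state (128-d) produces a 256-d gate that
    # multiplicatively modulates the trajectory LSTM's hidden state.
    g = W @ meta_h + b
    return trajectory_h * g
```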

Package structure

disco_torch/
  __init__.py          Public API
  types.py             Dataclasses: UpdateRuleInputs, ValueOuts, etc.
  transforms.py        Input transforms and construct_input()
  meta_net.py          DiscoMetaNet — the full LSTM meta-network
  update_rule.py       DiscoUpdateRule — meta-net + value computation + loss
  value_utils.py       V-trace, Retrace Q-values, advantage estimation
  utils.py             batch_lookup, signed_logp1, 2-hot encoding, EMA
  load_weights.py      JAX/Haiku NPZ -> PyTorch weight mapping
  trainer.py           DiscoTrainer high-level API + collect_rollout + ClippedAdam

examples/
  catch_disco.py       Validated Catch benchmark with A2C baseline
  catch_colab.ipynb    Google Colab notebook (3 cells, uses DiscoTrainer)
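Two of the helpers named in utils.py are standard distributional-RL transforms; a sketch of their usual definitions (the exact in-package signatures may differ):

```python
import numpy as np

def signed_logp1(x):
    # Symmetric log transform: sign(x) * log(1 + |x|).
    # Compresses large-magnitude rewards/values while preserving sign.
    return np.sign(x) * np.log1p(np.abs(x))

def two_hot(value, bins):
    # Spread a scalar across the two nearest bins so the categorical
    # encoding is differentiable and reconstructs the value exactly.
    idx = np.clip(np.searchsorted(bins, value) - 1, 0, len(bins) - 2)
    lo, hi = bins[idx], bins[idx + 1]
    w_hi = (value - lo) / (hi - lo)
    enc = np.zeros(len(bins))
    enc[idx], enc[idx + 1] = 1.0 - w_hi, w_hi
    return enc
```

The 601-bin categorical q output in the agent requirements table is the kind of target such a 2-hot encoding produces.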

Requirements

  • Python >= 3.11
  • PyTorch >= 2.0
  • NumPy >= 1.24

License

Apache 2.0 — same as the original disco_rl.

Citation

If you use this port, please cite the original paper:

@article{oh2025disco,
  title={Discovering State-Of-The-Art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}

Acknowledgments

This is a community port of google-deepmind/disco_rl. All credit for the algorithm, architecture, and pretrained weights goes to the original authors.
