
disco-torch

A PyTorch port of DeepMind's Disco103 — the meta-learned reinforcement learning update rule from Discovering State-of-the-art Reinforcement Learning Algorithms (Nature, 2025).

What is DiscoRL?

Instead of a hand-crafted loss function like those of PPO or GRPO, DiscoRL uses a small LSTM network (the "meta-network") that generates loss targets for RL agents. Given a rollout of agent experience — policy logits, rewards, advantages, auxiliary predictions — the meta-network outputs target distributions. The agent then minimizes the KL divergence between its own outputs and these learned targets.
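The distillation step above can be sketched in a few lines of plain PyTorch. This is an illustrative toy, not the package's `agent_loss` implementation: shapes and the exact KL direction are assumptions.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: T timesteps, B batch, A actions.
T, B, A = 5, 4, 3
agent_logits = torch.randn(T, B, A, requires_grad=True)
target_probs = torch.softmax(torch.randn(T, B, A), dim=-1)  # stand-in for meta-net targets

# KL(target || agent): the agent's policy is pulled toward the learned targets.
log_pi = F.log_softmax(agent_logits, dim=-1)
loss = F.kl_div(log_pi, target_probs, reduction="batchmean")
loss.backward()  # gradients flow into the agent only; targets are fixed
```

Because the targets are treated as constants, the agent update is ordinary supervised distillation — all the "RL" lives in how the meta-network produces those targets.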

The Disco103 checkpoint (754,778 parameters) was meta-trained by DeepMind across thousands of Atari-like environments. It generalizes as a drop-in update rule for new tasks — no reward shaping, no hyperparameter-specific loss design.

Why a PyTorch port?

The original implementation uses JAX + Haiku. This port enables using Disco103 in PyTorch training pipelines without any JAX dependency at inference time.

Installation

pip install disco-torch

With optional extras:

pip install disco-torch[hub]       # HuggingFace Hub weight downloads
pip install disco-torch[examples]  # gymnasium for running examples
pip install disco-torch[dev]       # pytest + all extras for development

Weights

Option 1 — Download from HuggingFace Hub (requires pip install disco-torch[hub]):

from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from HuggingFace Hub

Option 2 — Manual download from the disco_rl repo:

cp path/to/disco_103.npz weights/
load_disco103_weights(rule, "weights/disco_103.npz")

Quick start

import torch
from disco_torch import DiscoUpdateRule, UpdateRuleInputs, load_disco103_weights

# Load the meta-network with pretrained weights
rule = DiscoUpdateRule()
load_disco103_weights(rule, "weights/disco_103.npz")

# Initialize meta-RNN state (persists across training steps)
state = rule.meta_net.initial_meta_rnn_state()

# Run the meta-network on a rollout
# (`inputs` is an UpdateRuleInputs built from the rollout; see types.py)
with torch.no_grad():
    meta_out, new_state = rule.meta_net(inputs, state)
    # meta_out["pi"]  — policy loss targets  [T, B, A]
    # meta_out["y"]   — value loss targets   [T, B, 600]
    # meta_out["z"]   — auxiliary loss targets [T, B, 600]
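The 600-dimensional `y` and `z` targets are categorical distributions over binned scalar values, which is why the package's utils include a 2-hot encoder. Below is a generic sketch of 2-hot encoding — the bin layout and any value transform (e.g. a signed-hyperbolic squash) used by Disco103 itself may differ:

```python
import torch

def two_hot(x: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    """Encode scalars as weight split across the two nearest bin centers.

    Generic 2-hot sketch; not the package's exact `utils` implementation.
    """
    x = x.clamp(bins[0], bins[-1])
    idx = torch.searchsorted(bins, x, right=True).clamp(1, len(bins) - 1)
    lo, hi = bins[idx - 1], bins[idx]
    w_hi = (x - lo) / (hi - lo)                      # linear interpolation weight
    out = torch.zeros(*x.shape, len(bins))
    out.scatter_(-1, (idx - 1).unsqueeze(-1), (1 - w_hi).unsqueeze(-1))
    out.scatter_(-1, idx.unsqueeze(-1), w_hi.unsqueeze(-1))
    return out

bins = torch.linspace(-5.0, 5.0, 600)   # assumed bin range for illustration
enc = two_hot(torch.tensor([0.3, -2.0]), bins)
```

Each encoded row sums to 1 and has at most two nonzero entries, so the agent can learn scalar-valued predictions with a cross-entropy / KL objective.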

Full training loop

# At each learner step:
meta_out, new_meta_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)

# Compute agent loss (KL divergence against meta-network targets)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)

# Value function loss (no meta-gradient)
value_loss, value_logs = rule.agent_loss_no_meta(rollout, meta_out, hyper_params)

Architecture

Outer (per-trajectory):
  y_net           MLP [600 -> 16 -> 1]         Value prediction embedding
  z_net           MLP [600 -> 16 -> 1]         Auxiliary prediction embedding
  policy_net      Conv1dNet [9 -> 16 -> 2]     Action-conditional embedding
  trajectory_rnn  LSTM(27, 256)                Reverse-unrolled over trajectory
  state_gate      Linear(128 -> 256)           Multiplicative gate from meta-RNN
  y_head / z_head Linear(256 -> 600)           Loss targets for y and z
  pi_conv + head  Conv1dNet [258 -> 16] -> 1   Policy loss target (per action)

Meta-RNN (per-lifetime):
  Separate y/z/policy nets, input MLP(29 -> 16), LSTMCell(16, 128)

The outer network processes each trajectory with a reverse-unrolled LSTM. The meta-RNN operates at a slower timescale — it sees batch-time averages and modulates the outer network via a multiplicative gate. This two-level architecture lets the update rule adapt its behavior over an agent's lifetime.
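The gating interaction can be sketched as follows. The layer sizes come from the architecture table above, but the exact wiring (and whether the gate passes through a nonlinearity) is an assumption, not the package's code:

```python
import torch
import torch.nn as nn

# Illustrative two-timescale sketch using the sizes from the table above.
traj_rnn = nn.LSTM(input_size=27, hidden_size=256)   # fast: per-trajectory
state_gate = nn.Linear(128, 256)                     # meta-RNN hidden (128) -> gate (256)

T, B = 10, 4
traj_inputs = torch.randn(T, B, 27)
meta_h = torch.randn(B, 128)                         # slow: per-lifetime meta-RNN state

# Reverse-unroll over the trajectory (flip time, run LSTM, flip back) ...
out, _ = traj_rnn(traj_inputs.flip(0))
out = out.flip(0)

# ... then modulate every timestep's features with the multiplicative gate.
gated = out * state_gate(meta_h)                     # gate broadcasts over time: [T, B, 256]
```

The reverse unroll lets each timestep's target depend on future rewards in the trajectory, while the slow meta-RNN state shifts the whole update rule's behavior over the agent's lifetime.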

End-to-end example: Tiny Transformer

The flagship example trains a 1-layer causal transformer (~141K params) to generate strictly increasing digit sequences, using Disco103 as the RL update rule. No supervised data — the agent learns purely from a per-token reward signal.

Open in Google Colab (runs on free T4 GPU)

With curriculum learning over sequence lengths (4 → 6 → 8), the agent produces strictly increasing outputs on 86% of 8-token sequences, discovering sequences like [0, 1, 5, 6, 5, 7, 8, 9]. See experiments.md for the full research log.
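For concreteness, here is one plausible per-token reward for this task — the actual shaping in examples/tinylm_disco.py may differ:

```python
def per_token_rewards(seq):
    """+1 for each token strictly greater than its predecessor, else -1.

    A plausible per-token signal for the strictly-increasing task;
    illustrative only, not necessarily the reward in tinylm_disco.py.
    """
    return [1.0 if b > a else -1.0 for a, b in zip(seq, seq[1:])]

def is_strictly_increasing(seq):
    return all(b > a for a, b in zip(seq, seq[1:]))

rewards = per_token_rewards([0, 1, 5, 6, 5, 7, 8, 9])
# The sample sequence above contains one violation (6 -> 5),
# so it counts against the 86% strictly-increasing metric.
```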

# Standalone script (local GPU or CPU)
python examples/tinylm_disco.py --weights weights/disco_103.npz

A simpler CartPole example is also provided for API reference:

python examples/cartpole_disco.py --weights weights/disco_103.npz

Note: Disco103 was meta-trained on 103 complex environments (Atari, ProcGen, DMLab-30). Both CartPole and token generation are outside this distribution. See experiments.md for analysis of how this affects transfer.

Package structure

disco_torch/
  __init__.py          Public API
  types.py             Dataclasses: UpdateRuleInputs, MetaNetInputOption, ValueOuts, etc.
  transforms.py        Input transforms and construct_input()
  meta_net.py          DiscoMetaNet — the full LSTM meta-network
  update_rule.py       DiscoUpdateRule — meta-net + value computation + loss
  value_utils.py       V-trace, TD-error, advantage estimation, Q-values
  utils.py             batch_lookup, signed_logp1, 2-hot encoding, EMA
  load_weights.py      Maps JAX/Haiku NPZ keys -> PyTorch modules

examples/
  tinylm_disco.py          Train a tiny transformer with Disco103 (standalone)
  tinylm_disco_colab.ipynb Google Colab notebook (recommended)
  cartpole_disco.py        CartPole API reference example

scripts/
  inspect_disco103.py      Print NPZ weight names and shapes
  validate_against_jax.py  Numerical comparison: PyTorch vs JAX reference

tests/
  test_utils.py            Unit tests for utility functions
  test_building_blocks.py  Unit tests for network building blocks
  test_meta_net.py         Snapshot tests for meta-network forward pass

Numerical validation

All outputs match the JAX reference implementation within float32 precision:

Output                 Max diff     Status
pi (policy targets)    < 1.3e-06    PASS
y (value targets)      < 1.3e-06    PASS
z (auxiliary targets)  < 1.3e-06    PASS
meta_input_emb         < 1.3e-06    PASS
meta_rnn_h             < 1.3e-06    PASS

To run the test suite (no JAX required):

pip install disco-torch[dev]
pytest

To run JAX cross-validation (requires JAX + disco_rl):

pip install disco_rl jax dm-haiku rlax distrax
python scripts/validate_against_jax.py

Key implementation details

  • HaikuLSTMCell: Haiku uses gate order [i, g, f, o] with a +1 forget gate bias, vs PyTorch's [i, f, g, o]. This is handled by a custom LSTM cell.
  • Weight mapping: The 42 JAX/Haiku parameters have nested path names (e.g., lstm/~/meta_lstm/~unroll/mlp_2/~/linear_0/w). load_weights.py maps every one to the correct PyTorch module.
  • Conv1dBlock: Each block concatenates per-action features with their mean across actions before the convolution — matching the JAX implementation's broadcast pattern.
  • Value utilities: V-trace, Retrace-style Q-value estimation, signed hyperbolic transforms, and 2-hot categorical encoding are all ported.
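The gate-order conversion in the first bullet amounts to a row permutation plus a bias offset. A minimal sketch (not `load_weights.py` itself — the exact splitting/folding there may differ):

```python
import numpy as np
import torch

def haiku_lstm_weight_to_torch(w_np: np.ndarray, hidden: int) -> torch.Tensor:
    """Reorder Haiku LSTM gates [i, g, f, o] into PyTorch's [i, f, g, o].

    Haiku stores weights as (in_features, 4*hidden); PyTorch expects
    (4*hidden, in_features), so transpose first, then permute gate blocks.
    """
    w = torch.from_numpy(w_np).T                # -> (4*hidden, in_features)
    i, g, f, o = w.split(hidden, dim=0)         # Haiku gate order
    return torch.cat([i, f, g, o], dim=0)       # PyTorch gate order

def haiku_lstm_bias_to_torch(b_np: np.ndarray, hidden: int) -> torch.Tensor:
    """Permute gates and fold Haiku's runtime +1 forget-gate bias into the bias."""
    b = torch.from_numpy(b_np)
    i, g, f, o = b.split(hidden)
    return torch.cat([i, f + 1.0, g, o])

# Toy check: in_features=1, hidden=2, so rows 0-1=i, 2-3=g, 4-5=f, 6-7=o.
w = np.arange(8, dtype=np.float32).reshape(1, 8)
print(haiku_lstm_weight_to_torch(w, 2).squeeze().tolist())  # [0, 1, 4, 5, 2, 3, 6, 7]
```

Folding the +1 into the stored bias means the PyTorch cell can use the standard gate equations with no special-casing at runtime.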

Requirements

  • Python >= 3.11
  • PyTorch >= 2.0
  • NumPy >= 1.24

License

Apache 2.0 — same as the original disco_rl.

Citation

If you use this port, please cite the original paper:

@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}

Acknowledgments

This is a community port of google-deepmind/disco_rl. All credit for the algorithm, architecture, and pretrained weights goes to the original authors.
