disco-torch
A PyTorch port of DeepMind's Disco103 — the meta-learned reinforcement learning update rule from Discovering State-of-the-art Reinforcement Learning Algorithms (Nature, 2025).
What is DiscoRL?
Instead of hand-crafted loss functions like those of PPO or GRPO, DiscoRL uses a small LSTM (the "meta-network") that generates loss targets for RL agents. Given a rollout of agent experience — policy logits, rewards, advantages, auxiliary predictions — the meta-network outputs target distributions. The agent then minimizes the KL divergence between its own outputs and these learned targets.
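To make that concrete, here is a minimal sketch of the KL objective with random stand-in tensors (shapes follow the [T, B, A] convention used later in this README; this is not the package's actual loss code):

```python
import torch
import torch.nn.functional as F

T, B, A = 16, 8, 6                                   # time, batch, actions
agent_logits = torch.randn(T, B, A)                  # agent's policy logits
target_pi = torch.softmax(torch.randn(T, B, A), -1)  # stand-in meta-network targets

# F.kl_div expects log-probabilities as input and probabilities as target,
# giving KL(target || agent) averaged over the leading dimension.
policy_loss = F.kl_div(
    F.log_softmax(agent_logits, dim=-1), target_pi, reduction="batchmean"
)
```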
The Disco103 checkpoint (754,778 parameters) was meta-trained by DeepMind across thousands of Atari-like environments. It generalizes as a drop-in update rule for new tasks — no reward shaping, no hand-tuned loss design.
Why a PyTorch port?
The original implementation uses JAX + Haiku. This port enables using Disco103 in PyTorch training pipelines without any JAX dependency at inference time.
Installation
```
pip install disco-torch
```

With optional extras:

```
pip install disco-torch[hub]       # HuggingFace Hub weight downloads
pip install disco-torch[examples]  # gymnasium for running examples
pip install disco-torch[dev]       # pytest + all extras for development
```
Weights
Option 1 — Download from HuggingFace Hub (requires pip install disco-torch[hub]):

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from HuggingFace Hub
```

Option 2 — Manual download from the disco_rl repo:

```
cp path/to/disco_103.npz weights/
```

```python
load_disco103_weights(rule, "weights/disco_103.npz")
```
Quick start
```python
import torch
from disco_torch import DiscoUpdateRule, UpdateRuleInputs, load_disco103_weights

# Load the meta-network with pretrained weights
rule = DiscoUpdateRule()
load_disco103_weights(rule, "weights/disco_103.npz")

# Initialize meta-RNN state (persists across training steps)
state = rule.meta_net.initial_meta_rnn_state()

# Run the meta-network on a rollout
# (`inputs` is an UpdateRuleInputs built from your rollout; see types.py and transforms.py)
with torch.no_grad():
    meta_out, new_state = rule.meta_net(inputs, state)

# meta_out["pi"] — policy loss targets [T, B, A]
# meta_out["y"]  — value loss targets [T, B, 600]
# meta_out["z"]  — auxiliary loss targets [T, B, 600]
```
Full training loop
```python
# At each learner step:
meta_out, new_meta_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)

# Compute agent loss (KL divergence against meta-network targets)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)

# Value function loss (no meta-gradient)
value_loss, value_logs = rule.agent_loss_no_meta(rollout, meta_out, hyper_params)
```
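A sketch of how one learner step might be assembled from these pieces; the optimizer and the unweighted sum of the two losses are illustrative assumptions, not part of the disco-torch API:

```python
# Continues the snippet above; `optimizer` is any torch.optim optimizer
# over the agent's parameters, and the 1:1 loss weighting is an assumption.
optimizer.zero_grad()
(loss + value_loss).backward()
optimizer.step()
meta_state = new_meta_state  # carry the meta-RNN state into the next learner step
```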
Architecture
Outer (per-trajectory):

| Component | Type | Role |
|---|---|---|
| y_net | MLP [600 -> 16 -> 1] | Value prediction embedding |
| z_net | MLP [600 -> 16 -> 1] | Auxiliary prediction embedding |
| policy_net | Conv1dNet [9 -> 16 -> 2] | Action-conditional embedding |
| trajectory_rnn | LSTM(27, 256) | Reverse-unrolled over trajectory |
| state_gate | Linear(128 -> 256) | Multiplicative gate from meta-RNN |
| y_head / z_head | Linear(256 -> 600) | Loss targets for y and z |
| pi_conv + head | Conv1dNet [258 -> 16] -> 1 | Policy loss target (per action) |

Meta-RNN (per-lifetime): separate y/z/policy nets, input MLP(29 -> 16), LSTMCell(16, 128).
The outer network processes each trajectory with a reverse-unrolled LSTM. The meta-RNN operates at a slower timescale — it sees batch-time averages and modulates the outer network via a multiplicative gate. This two-level architecture lets the update rule adapt its behavior over an agent's lifetime.
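A sketch of the gating pattern just described, using the shapes from the table above; the sigmoid is an assumption, and this is illustrative rather than the port's actual module code:

```python
import torch

state_gate = torch.nn.Linear(128, 256)  # meta-RNN hidden -> trajectory-LSTM width

meta_h = torch.randn(8, 128)  # slow, per-lifetime meta-RNN hidden state [B, 128]
traj_h = torch.randn(8, 256)  # fast, per-trajectory LSTM features [B, 256]
gated = traj_h * torch.sigmoid(state_gate(meta_h))  # multiplicative modulation
```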
End-to-end example: Tiny Transformer
The flagship example trains a 1-layer causal transformer (~141K params) to generate strictly increasing digit sequences, using Disco103 as the RL update rule. No supervised data — the agent learns purely from a per-token reward signal.
Open in Google Colab (runs on free T4 GPU)
Using curriculum learning (sequence lengths 4→6→8), the agent achieves 86% strictly increasing on 8-token sequences, discovering sequences like [0, 1, 5, 6, 5, 7, 8, 9]. See experiments.md for the full research log.
```
# Standalone script (local GPU or CPU)
python examples/tinylm_disco.py --weights weights/disco_103.npz
```
A simpler CartPole example is also provided for API reference:
```
python examples/cartpole_disco.py --weights weights/disco_103.npz
```
Note: Disco103 was meta-trained on 103 complex environments (Atari, ProcGen, DMLab-30). Both CartPole and token generation are outside this distribution. See experiments.md for analysis of how this affects transfer.
Package structure
```
disco_torch/
  __init__.py               Public API
  types.py                  Dataclasses: UpdateRuleInputs, MetaNetInputOption, ValueOuts, etc.
  transforms.py             Input transforms and construct_input()
  meta_net.py               DiscoMetaNet — the full LSTM meta-network
  update_rule.py            DiscoUpdateRule — meta-net + value computation + loss
  value_utils.py            V-trace, TD-error, advantage estimation, Q-values
  utils.py                  batch_lookup, signed_logp1, 2-hot encoding, EMA
  load_weights.py           Maps JAX/Haiku NPZ keys -> PyTorch modules
examples/
  tinylm_disco.py           Train a tiny transformer with Disco103 (standalone)
  tinylm_disco_colab.ipynb  Google Colab notebook (recommended)
  cartpole_disco.py         CartPole API reference example
scripts/
  inspect_disco103.py       Print NPZ weight names and shapes
  validate_against_jax.py   Numerical comparison: PyTorch vs JAX reference
tests/
  test_utils.py             Unit tests for utility functions
  test_building_blocks.py   Unit tests for network building blocks
  test_meta_net.py          Snapshot tests for meta-network forward pass
```
Numerical validation
All outputs match the JAX reference implementation within float32 precision:
| Output | Max diff | Status |
|---|---|---|
| pi (policy targets) | < 1.3e-06 | PASS |
| y (value targets) | < 1.3e-06 | PASS |
| z (auxiliary targets) | < 1.3e-06 | PASS |
| meta_input_emb | < 1.3e-06 | PASS |
| meta_rnn_h | < 1.3e-06 | PASS |
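The following toy check illustrates the kind of tolerance assertion the validation script makes; pi_torch and pi_jax are random stand-ins for the two implementations' outputs:

```python
import torch

# Stand-in tensors; scripts/validate_against_jax.py compares the real outputs.
pi_torch = torch.randn(16, 8, 6)
pi_jax = pi_torch + 1e-7 * torch.randn_like(pi_torch)
torch.testing.assert_close(pi_torch, pi_jax, atol=1.3e-6, rtol=0.0)
```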
To run the test suite (no JAX required):
```
pip install disco-torch[dev]
pytest
```
To run JAX cross-validation (requires JAX + disco_rl):
```
pip install disco_rl jax dm-haiku rlax distrax
python scripts/validate_against_jax.py
```
Key implementation details
- HaikuLSTMCell: Haiku uses gate order [i, g, f, o] with a +1 forget-gate bias, vs PyTorch's [i, f, g, o]. A custom LSTM cell handles both differences (see the gate-reordering sketch after this list).
- Weight mapping: the 42 JAX/Haiku parameters have nested path names (e.g., lstm/~/meta_lstm/~unroll/mlp_2/~/linear_0/w). load_weights.py maps every one to the correct PyTorch module.
- Conv1dBlock: each block concatenates per-action features with their mean across actions before the convolution — matching the JAX implementation's broadcast pattern.
- Value utilities: V-trace, Retrace-style Q-value estimation, signed hyperbolic transforms, and 2-hot categorical encoding are all ported.
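For intuition, here is a minimal sketch of that gate reordering. The helper below is hypothetical (the actual conversion lives in load_weights.py, and the +1 forget-gate bias is handled by the custom HaikuLSTMCell rather than the weights):

```python
import torch

def reorder_haiku_gates(w: torch.Tensor, hidden: int) -> torch.Tensor:
    """Hypothetical helper: reorder stacked LSTM gate blocks from Haiku's
    [i, g, f, o] to PyTorch's [i, f, g, o], assuming gates are stacked
    along dim 0 as in torch.nn.LSTMCell's weight_ih / weight_hh."""
    i, g, f, o = w.split(hidden, dim=0)
    return torch.cat([i, f, g, o], dim=0)

# e.g. for the meta-RNN's LSTMCell(16, 128): weight_ih has shape [4 * 128, 16]
w_torch = reorder_haiku_gates(torch.randn(4 * 128, 16), hidden=128)
```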
Requirements
- Python >= 3.11
- PyTorch >= 2.0
- NumPy >= 1.24
License
Apache 2.0 — same as the original disco_rl.
Citation
If you use this port, please cite the original paper:
```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```
Acknowledgments
This is a community port of google-deepmind/disco_rl. All credit for the algorithm, architecture, and pretrained weights goes to the original authors.