WorldFlux
Unified Interface for World Models in Reinforcement Learning
One API. Multiple Architectures. Infinite Imagination.
WorldFlux provides a unified Python interface for world models used in reinforcement learning. It starts with efficient latent-space models (DreamerV3, TD-MPC2), with planned support for diverse architectures, including autoregressive and diffusion-based world models.
Demo
Imagination Rollout
World model imagines future frames from a single observation:
Left: Real game frames | Right: Model's imagination
Latent Space Dynamics
Visualization of how different episodes traverse the learned latent space:
Features
- Unified API: Common interface across DreamerV3, TD-MPC2, and more
- Simple Usage: One-liner model creation with create_world_model()
- Training Infrastructure: Complete training loop with callbacks, checkpointing, and logging
- Type Safe: Full type annotations and mypy compatibility
Architecture
graph LR
subgraph Input
A[Observation]
end
subgraph WorldModel["World Model"]
B[Encoder]
C[LatentState]
D[Dynamics<br/>RSSM/MLP]
E[Decoder]
end
subgraph Output
F[Predictions<br/>obs, reward, continue]
end
A --> B
B --> C
C --> D
D --> C
C --> E
E --> F
style C fill:#e1f5fe
style D fill:#fff3e0
Key Concepts:
- Encoder: Maps observations to latent states
- LatentState: Compact representation (deterministic + stochastic components)
- Dynamics: Predicts next latent state from current state + action
- Decoder: Reconstructs observations/rewards from latent states
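The encode → dynamics → decode cycle above can be sketched in a few lines of plain Python. This is a toy stand-in for illustration, not the WorldFlux implementation: the latent state carries a deterministic and a stochastic component, dynamics maps (state, action) to the next latent state, and the decoder maps latent states back to predictions.

```python
import random

def encode(obs):
    """Toy encoder: summarize the observation into a deterministic
    component plus a 'stochastic' component (here just noised)."""
    det = sum(obs) / len(obs)
    stoch = [x + random.gauss(0, 0.1) for x in obs]
    return {"det": det, "stoch": stoch}

def dynamics(state, action):
    """Toy dynamics: predict the next latent state from the current
    state and an action, without seeing the next observation."""
    det = state["det"] + action
    stoch = [x * 0.9 + action for x in state["stoch"]]
    return {"det": det, "stoch": stoch}

def decode(state):
    """Toy decoder: reconstruct predictions from a latent state."""
    return {"obs": state["stoch"], "reward": state["det"], "continue": 1.0}

# One imagination step: encode an observation, roll the latent
# forward with an action, then decode predictions.
state = encode([0.5, 1.5])
next_state = dynamics(state, action=0.1)
preds = decode(next_state)
```

The real models replace these toy functions with neural networks, but the data flow matches the diagram: observations only enter through the encoder, and imagination runs entirely in latent space.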
Model Comparison
| Feature | DreamerV3 | TD-MPC2 |
|---|---|---|
| Input Type | Images or State | State vectors |
| Latent Space | Categorical (discrete) | SimNorm (continuous) |
| Architecture | RSSM | Implicit MLP |
| Decoder | Yes | No (implicit) |
| Best For | Atari, visual tasks | MuJoCo, robotics |
| Planning | Policy rollouts | MPC with Q-ensemble |
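The two latent-space choices in the table constrain the latent vector differently. DreamerV3 samples from categorical distributions, while TD-MPC2's SimNorm partitions the latent vector into fixed-size groups and applies a softmax to each, so every group lies on a probability simplex. A minimal SimNorm sketch in plain Python (the group size here is an arbitrary choice for illustration):

```python
import math

def simnorm(z, group_size=4):
    """Simplicial normalization: split z into contiguous groups of
    `group_size` and softmax each group so it sums to 1."""
    assert len(z) % group_size == 0
    out = []
    for i in range(0, len(z), group_size):
        group = z[i:i + group_size]
        m = max(group)  # subtract the max for numerical stability
        exps = [math.exp(x - m) for x in group]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out

z = simnorm([1.0, 2.0, 3.0, 4.0, 0.5, 0.5, 0.5, 0.5])
# Each 4-element group of z now sums to 1.
```

The effect is a continuous but bounded latent space, in contrast to DreamerV3's discrete categorical samples.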
Installation
From Source (recommended)
git clone https://github.com/worldflux/Worldflux.git
cd Worldflux
pip install -e "."
# With training dependencies
pip install -e ".[training]"
# With all optional dependencies
pip install -e ".[all]"
# For development
pip install -e ".[dev]"
From PyPI
pip install worldflux
Quick Start
Create a Model
from worldflux import create_world_model
# DreamerV3 (image observations)
model = create_world_model("dreamerv3:size12m")
# TD-MPC2 (state observations)
model = create_world_model("tdmpc2:5m", obs_shape=(39,), action_dim=6)
Imagination Rollout
import torch
# Encode initial observation
obs = torch.randn(1, 3, 64, 64)
state = model.encode(obs)
# Imagine 15 steps into the future
actions = torch.randn(15, 1, 4) # [horizon, batch, action_dim]
trajectory = model.imagine(state, actions)
# Access predictions
print(f"Predicted rewards: {trajectory.rewards.shape}")
print(f"Continue probs: {trajectory.continues.shape}")
Train a Model
from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer
# Create model
model = create_world_model("dreamerv3:size12m", obs_shape=(4,), action_dim=2)
# Load data
buffer = ReplayBuffer.load("trajectories.npz")
# Train (one-liner)
trained_model = train(model, buffer, total_steps=50_000)
# Save
trained_model.save_pretrained("./my_model")
Full Training Control
from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer
model = create_world_model("tdmpc2:5m", obs_shape=(39,), action_dim=6)
buffer = ReplayBuffer(capacity=100_000, obs_shape=(39,), action_dim=6)
config = TrainingConfig(
total_steps=100_000,
batch_size=256,
learning_rate=1e-4,
)
trainer = Trainer(model, config)
trainer.train(buffer)
Available Models
DreamerV3
| Preset | Parameters | Description |
|---|---|---|
| dreamerv3:size12m | 12M | Small, fast training |
| dreamerv3:size25m | 25M | Balanced |
| dreamerv3:size50m | 50M | Standard |
| dreamerv3:size100m | 100M | Large |
| dreamerv3:size200m | 200M | Maximum capacity |
TD-MPC2
| Preset | Parameters | Description |
|---|---|---|
| tdmpc2:5m | 5M | Small, fast |
| tdmpc2:19m | 19M | Balanced |
| tdmpc2:48m | 48M | Large |
| tdmpc2:317m | 317M | Maximum capacity |
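Preset identifiers follow a family:preset convention, as in create_world_model("tdmpc2:5m"). A hypothetical helper (not part of the WorldFlux API) showing how such a string can be split and validated against the tables above:

```python
# Known presets, taken from the tables above.
PRESETS = {
    "dreamerv3": {"size12m", "size25m", "size50m", "size100m", "size200m"},
    "tdmpc2": {"5m", "19m", "48m", "317m"},
}

def parse_preset(spec):
    """Split a 'family:preset' identifier and check it against the
    known presets. Raises ValueError on unknown values."""
    family, _, preset = spec.partition(":")
    if family not in PRESETS or preset not in PRESETS[family]:
        raise ValueError(f"unknown preset: {spec!r}")
    return family, preset

family, preset = parse_preset("tdmpc2:5m")
```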
API Reference
Core Methods
All world models implement the WorldModel protocol:
# Encode observation to latent state
state = model.encode(obs)
# Predict next state (imagination, no observation)
next_state = model.predict(state, action)
# Update state with observation (posterior)
next_state = model.observe(state, action, obs)
# Decode latent state to predictions
predictions = model.decode(state) # {"obs", "reward", "continue"}
# Multi-step imagination rollout
trajectory = model.imagine(initial_state, actions)
# Compute training losses
losses = model.compute_loss(batch) # {"loss", "kl", "reconstruction", ...}
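The distinction between predict (prior, no observation) and imagine (a multi-step rollout of that prior) is central to the protocol. A toy sketch of how imagine composes predict and decode over a horizon, using dict-based stand-ins for the real model (illustrative only, with made-up dynamics):

```python
def predict(state, action):
    """Toy prior: next latent from state + action, no observation."""
    return {"h": state["h"] * 0.9 + action}

def decode(state):
    """Toy decoder: predictions from a latent state."""
    return {"reward": state["h"], "continue": 1.0}

def imagine(initial_state, actions):
    """Multi-step rollout: repeatedly apply the prior and decode,
    mirroring model.imagine(initial_state, actions)."""
    state, rewards, continues = initial_state, [], []
    for a in actions:
        state = predict(state, a)
        preds = decode(state)
        rewards.append(preds["reward"])
        continues.append(preds["continue"])
    return {"rewards": rewards, "continues": continues}

traj = imagine({"h": 1.0}, actions=[0.0, 0.0, 0.0])
# Rewards decay geometrically (0.9, 0.81, 0.729) under this toy dynamics.
```

During real interaction you would use observe instead of predict at each step, so the latent state is corrected by actual observations rather than rolled forward open-loop.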
Training API
from worldflux.training import (
Trainer,
TrainingConfig,
ReplayBuffer,
train,
)
# Configuration
config = TrainingConfig(
total_steps=100_000,
batch_size=16,
sequence_length=50,
learning_rate=3e-4,
grad_clip=100.0,
)
# Callbacks
from worldflux.training.callbacks import (
LoggingCallback,
CheckpointCallback,
EarlyStoppingCallback,
ProgressCallback,
)
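Callbacks hook into the training loop at step boundaries. A simplified sketch of how a trainer might invoke them; the real WorldFlux callback signatures may differ, and StopAtLoss here is a hypothetical stand-in for EarlyStoppingCallback:

```python
class StopAtLoss:
    """Toy early-stopping callback: asks the loop to halt once the
    reported loss drops below a threshold."""
    def __init__(self, threshold):
        self.threshold = threshold

    def on_step_end(self, step, loss):
        return loss >= self.threshold  # False => request stop

def run_training(losses, callbacks):
    """Minimal loop: report each step's loss to every callback and
    stop as soon as any callback returns False."""
    for step, loss in enumerate(losses):
        if not all(cb.on_step_end(step, loss) for cb in callbacks):
            return step  # step at which training stopped
    return len(losses)

stopped_at = run_training([1.0, 0.5, 0.05, 0.01], [StopAtLoss(0.1)])
# Stops at step 2, where loss 0.05 < 0.1.
```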
Examples
See the examples/ directory:
- worldflux_quickstart.ipynb - Interactive Colab notebook
- train_dreamer.py - DreamerV3 training example
- train_tdmpc2.py - TD-MPC2 training example
- visualize_imagination.py - Imagination rollout visualization
# Quick test with random data
python examples/train_dreamer.py --test
# Train with real data
python examples/train_dreamer.py --data trajectories.npz --steps 100000
# Collect Atari data
python examples/collect_atari.py --env Breakout --episodes 100
# Train on Atari
python examples/train_atari_dreamer.py --data atari_data.npz --steps 100000
Benchmarks
Results on standard benchmarks:
| Environment | Model | Score | Training Steps |
|---|---|---|---|
| Atari Breakout | DreamerV3-50M | - | 200K |
| MuJoCo HalfCheetah | TD-MPC2-19M | - | 1M |
| DMControl Walker | DreamerV3-25M | - | 500K |
Documentation
- Full Documentation - Comprehensive guides and API reference
- Tutorials - Step-by-step learning
- API Reference - Detailed API docs
Community
Join our Discord to discuss world models, get help, and connect with other researchers and developers.
Security
See SECURITY.md for security considerations, especially regarding loading model checkpoints from untrusted sources.
License
Apache License 2.0 - see LICENSE and NOTICE for details.
Contributing
Contributions are welcome! Please read our Contributing Guide before submitting pull requests.
Citation
If you use this library in your research, please cite:
@software{worldflux,
title = {WorldFlux: Unified Interface for World Models},
year = {2026},
url = {https://github.com/worldflux/Worldflux}
}
Acknowledgments
WorldFlux builds on the excellent research behind DreamerV3 and TD-MPC2.