WorldFlux


Unified Interface for World Models in Reinforcement Learning

One API. Multiple Architectures. Infinite Imagination.


WorldFlux provides a unified Python interface for world models used in reinforcement learning. It launches with efficient latent-space models (DreamerV3 and TD-MPC2), with support planned for further architectures, including autoregressive and diffusion-based world models.

Demo

Imagination Rollout

The world model imagines future frames from a single observation:


Left: Real game frames | Right: Model's imagination

Latent Space Dynamics

Visualization of how different episodes traverse the learned latent space.


Features

  • Unified API: Common interface across DreamerV3, TD-MPC2, and more
  • Simple Usage: One-liner model creation with create_world_model()
  • Training Infrastructure: Complete training loop with callbacks, checkpointing, and logging
  • Type Safe: Full type annotations and mypy compatibility

Architecture

graph LR
    subgraph Input
        A[Observation]
    end

    subgraph WorldModel["World Model"]
        B[Encoder]
        C[LatentState]
        D[Dynamics<br/>RSSM/MLP]
        E[Decoder]
    end

    subgraph Output
        F[Predictions<br/>obs, reward, continue]
    end

    A --> B
    B --> C
    C --> D
    D --> C
    C --> E
    E --> F

    style C fill:#e1f5fe
    style D fill:#fff3e0

Key Concepts:

  • Encoder: Maps observations to latent states
  • LatentState: Compact representation (deterministic + stochastic components)
  • Dynamics: Predicts next latent state from current state + action
  • Decoder: Reconstructs observations/rewards from latent states
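The encode → dynamics → decode cycle described above can be sketched with toy stand-ins. This is a minimal illustration of the data flow only, not the WorldFlux API: all functions and the arithmetic inside them are invented placeholders for real networks.

```python
from dataclasses import dataclass


@dataclass
class LatentState:
    # Toy latent with one deterministic and one stochastic component.
    deter: float
    stoch: float


def encode(obs: float) -> LatentState:
    # Stand-in encoder: derive both latent components from the observation.
    return LatentState(deter=obs * 0.5, stoch=obs * 0.1)


def dynamics(state: LatentState, action: float) -> LatentState:
    # Stand-in dynamics: next latent from current latent + action.
    return LatentState(deter=state.deter + action, stoch=state.stoch * 0.9)


def decode(state: LatentState) -> dict:
    # Stand-in decoder: predictions from the latent alone.
    return {"reward": state.deter + state.stoch}


# Open-loop rollout: the observation is encoded once, then only the
# latent state is advanced -- no further observations are needed.
state = encode(1.0)
rewards = []
for action in [0.1, 0.2, 0.3]:
    state = dynamics(state, action)
    rewards.append(decode(state)["reward"])
print(len(rewards))  # one predicted reward per imagined step
```

This open-loop structure is what makes imagination cheap: after the first encode, rollout cost depends only on the latent dynamics, not on decoding or re-encoding full observations.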

Model Comparison

| Feature | DreamerV3 | TD-MPC2 |
| --- | --- | --- |
| Input Type | Images or state | State vectors |
| Latent Space | Categorical (discrete) | SimNorm (continuous) |
| Architecture | RSSM | Implicit MLP |
| Decoder | Yes | No (implicit) |
| Best For | Atari, visual tasks | MuJoCo, robotics |
| Planning | Policy rollouts | MPC with Q-ensemble |
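The latent-space row is the key structural difference: DreamerV3 uses discrete categorical latents, while TD-MPC2 normalizes its continuous latent with SimNorm, a simplicial normalization that applies a softmax over fixed-size groups so each group lies on a probability simplex. A minimal pure-Python sketch of that normalization (the group size here is illustrative):

```python
import math


def simnorm(z: list[float], group_size: int = 4) -> list[float]:
    """Simplicial normalization: softmax each contiguous group of
    group_size entries, so every group sums to 1."""
    out = []
    for i in range(0, len(z), group_size):
        group = z[i:i + group_size]
        m = max(group)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in group]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out


z = simnorm([1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0])
print(sum(z[:4]), sum(z[4:]))  # each group sums to 1
```

A uniform group (the second one above) normalizes to equal weights, while a peaked group stays peaked, which keeps the latent bounded without collapsing it to one-hot codes.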

Installation

From Source (recommended)

git clone https://github.com/worldflux/Worldflux.git
cd Worldflux
pip install -e "."

# With training dependencies
pip install -e ".[training]"

# With all optional dependencies
pip install -e ".[all]"

# For development
pip install -e ".[dev]"

From PyPI

pip install worldflux

Quick Start

Create a Model

from worldflux import create_world_model

# DreamerV3 (image observations)
model = create_world_model("dreamerv3:size12m")

# TD-MPC2 (state observations)
model = create_world_model("tdmpc2:5m", obs_shape=(39,), action_dim=6)

Imagination Rollout

import torch

# Encode initial observation
obs = torch.randn(1, 3, 64, 64)
state = model.encode(obs)

# Imagine 15 steps into the future
actions = torch.randn(15, 1, 4)  # [horizon, batch, action_dim]
trajectory = model.imagine(state, actions)

# Access predictions
print(f"Predicted rewards: {trajectory.rewards.shape}")
print(f"Continue probs: {trajectory.continues.shape}")

Train a Model

from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

# Create model
model = create_world_model("dreamerv3:size12m", obs_shape=(4,), action_dim=2)

# Load data
buffer = ReplayBuffer.load("trajectories.npz")

# Train (one-liner)
trained_model = train(model, buffer, total_steps=50_000)

# Save
trained_model.save_pretrained("./my_model")

Full Training Control

from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer

model = create_world_model("tdmpc2:5m", obs_shape=(39,), action_dim=6)
buffer = ReplayBuffer(capacity=100_000, obs_shape=(39,), action_dim=6)

config = TrainingConfig(
    total_steps=100_000,
    batch_size=256,
    learning_rate=1e-4,
)

trainer = Trainer(model, config)
trainer.train(buffer)
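A `ReplayBuffer` with a fixed `capacity` suggests ring-buffer semantics: once full, new transitions overwrite the oldest ones. A minimal sketch of that behavior (illustrative only, not the WorldFlux implementation):

```python
class RingBuffer:
    """Fixed-capacity buffer: once full, new items overwrite the oldest."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.items: list = []
        self.pos = 0  # next index to overwrite once the buffer is full

    def add(self, item) -> None:
        if len(self.items) < self.capacity:
            self.items.append(item)
        else:
            self.items[self.pos] = item
            self.pos = (self.pos + 1) % self.capacity

    def __len__(self) -> int:
        return len(self.items)


buf = RingBuffer(capacity=3)
for step in range(5):
    buf.add(step)
print(len(buf), buf.items)  # size stays at capacity; oldest overwritten
```

This keeps memory bounded while biasing the buffer toward recent experience, which matters for long training runs like the 100,000-step configuration above.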

Available Models

DreamerV3

| Preset | Parameters | Description |
| --- | --- | --- |
| dreamerv3:size12m | 12M | Small, fast training |
| dreamerv3:size25m | 25M | Balanced |
| dreamerv3:size50m | 50M | Standard |
| dreamerv3:size100m | 100M | Large |
| dreamerv3:size200m | 200M | Maximum capacity |

TD-MPC2

| Preset | Parameters | Description |
| --- | --- | --- |
| tdmpc2:5m | 5M | Small, fast |
| tdmpc2:19m | 19M | Balanced |
| tdmpc2:48m | 48M | Large |
| tdmpc2:317m | 317M | Maximum capacity |

API Reference

Core Methods

All world models implement the WorldModel protocol:

# Encode observation to latent state
state = model.encode(obs)

# Predict next state (imagination, no observation)
next_state = model.predict(state, action)

# Update state with observation (posterior)
next_state = model.observe(state, action, obs)

# Decode latent state to predictions
predictions = model.decode(state)  # {"obs", "reward", "continue"}

# Multi-step imagination rollout
trajectory = model.imagine(initial_state, actions)

# Compute training losses
losses = model.compute_loss(batch)  # {"loss", "kl", "reconstruction", ...}
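A protocol like this can be expressed structurally with `typing.Protocol`: any object providing these six methods conforms, with no inheritance required. The sketch below uses the method names from the protocol above, but the exact signatures are illustrative, not the real WorldFlux definitions.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class WorldModelProtocol(Protocol):
    """Structural type: any object with these six methods conforms."""

    def encode(self, obs: Any) -> Any: ...
    def predict(self, state: Any, action: Any) -> Any: ...
    def observe(self, state: Any, action: Any, obs: Any) -> Any: ...
    def decode(self, state: Any) -> dict: ...
    def imagine(self, initial_state: Any, actions: Any) -> Any: ...
    def compute_loss(self, batch: Any) -> dict: ...


class Dummy:
    # Trivial conforming model: no networks, just matching method names.
    def encode(self, obs): return obs
    def predict(self, state, action): return state
    def observe(self, state, action, obs): return obs
    def decode(self, state): return {"reward": 0.0}
    def imagine(self, initial_state, actions): return []
    def compute_loss(self, batch): return {"loss": 0.0}


print(isinstance(Dummy(), WorldModelProtocol))  # structural check passes
```

Note that `runtime_checkable` only verifies that the methods exist, not their signatures or behavior; static checkers like mypy enforce the full signatures.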

Training API

from worldflux.training import (
    Trainer,
    TrainingConfig,
    ReplayBuffer,
    train,
)

# Configuration
config = TrainingConfig(
    total_steps=100_000,
    batch_size=16,
    sequence_length=50,
    learning_rate=3e-4,
    grad_clip=100.0,
)

# Callbacks
from worldflux.training.callbacks import (
    LoggingCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    ProgressCallback,
)
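Callbacks like these typically hook into the training loop at fixed points (per step, on completion, and so on). A minimal sketch of that dispatch pattern, where hook names, signatures, and the `LossLogger` class are all illustrative rather than the WorldFlux API:

```python
class Callback:
    """Base callback: subclasses override only the hooks they need."""

    def on_step_end(self, step: int, loss: float) -> None: ...
    def on_train_end(self) -> None: ...


class LossLogger(Callback):
    # Record the loss every `every` steps, like a logging callback would.
    def __init__(self, every: int):
        self.every = every
        self.logged: list = []

    def on_step_end(self, step: int, loss: float) -> None:
        if step % self.every == 0:
            self.logged.append((step, loss))


def run_loop(callbacks: list, steps: int) -> None:
    # Toy training loop: fire each hook at the matching point.
    for step in range(steps):
        loss = 1.0 / (step + 1)  # stand-in for a real training loss
        for cb in callbacks:
            cb.on_step_end(step, loss)
    for cb in callbacks:
        cb.on_train_end()


logger = LossLogger(every=2)
run_loop([logger], steps=5)
print(logger.logged)  # entries recorded at steps 0, 2, and 4
```

Checkpointing and early stopping fit the same pattern: they only differ in which hooks they implement and what state they keep between calls.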

Examples

See the examples/ directory:

  • worldflux_quickstart.ipynb - Interactive Colab notebook
  • train_dreamer.py - DreamerV3 training example
  • train_tdmpc2.py - TD-MPC2 training example
  • visualize_imagination.py - Imagination rollout visualization

# Quick test with random data
python examples/train_dreamer.py --test

# Train with real data
python examples/train_dreamer.py --data trajectories.npz --steps 100000

# Collect Atari data
python examples/collect_atari.py --env Breakout --episodes 100

# Train on Atari
python examples/train_atari_dreamer.py --data atari_data.npz --steps 100000

Benchmarks

Results on standard benchmarks:

| Environment | Model | Score | Training Steps |
| --- | --- | --- | --- |
| Atari Breakout | DreamerV3-50M | - | 200K |
| MuJoCo HalfCheetah | TD-MPC2-19M | - | 1M |
| DMControl Walker | DreamerV3-25M | - | 500K |


Community

Join our Discord to discuss world models, get help, and connect with other researchers and developers.

Security

See SECURITY.md for security considerations, especially regarding loading model checkpoints from untrusted sources.

License

Apache License 2.0 - see LICENSE and NOTICE for details.

Contributing

Contributions are welcome! Please read our Contributing Guide before submitting pull requests.

Citation

If you use this library in your research, please cite:

@software{worldflux,
  title = {WorldFlux: Unified Interface for World Models},
  year = {2026},
  url = {https://github.com/worldflux/Worldflux}
}

Acknowledgments

WorldFlux builds on the excellent research behind DreamerV3 and TD-MPC2.
