Improving Transformers World Model for RL

Improving Transformers World Model - PyTorch (wip)

Implementation in PyTorch of the new state of the art (SOTA) for model-based RL, from the paper Improving Transformer World Models for Data-Efficient RL.

Using a transformer world model and a less complicated setup, the authors report significantly outperforming DreamerV3 (as well as human experts) on Craftax, a simplified Minecraft-like environment.

Install

```bash
$ pip install improving-transformers-world-model
```

Usage

```python
import torch

from improving_transformers_world_model import (
    WorldModel
)

world_model = WorldModel(
    image_size = 63,
    patch_size = 7,
    channels = 3,
    transformer = dict(
        dim = 512,
        depth = 4,
        block_size = 81            # (63 / 7) ** 2 = 81 patches, presumably one image per block
    ),
    tokenizer = dict(
        dim = 7 * 7 * 3,           # flattened patch: patch_size * patch_size * channels
        distance_threshold = 0.5
    )
)

# shape is (batch, channels, time, height, width)
# Craftax observations are 3-channel 63x63 images, and the paper used rollouts of 20 frames
state = torch.randn(2, 3, 20, 63, 63)

loss = world_model(state)
loss.backward()

# dream up trajectories, to be mixed with real ones for training PPO

prompts = state[:, :, :2] # the first two frames serve as the prompt

imagined_trajectories = world_model.sample(prompts, time_steps = 20)

assert imagined_trajectories.shape == state.shape
```
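The `tokenizer` settings above (a patch-sized `dim` plus a `distance_threshold`) suggest a nearest-neighbor tokenizer over image patches, which is how we read the paper's approach. The sketch below is a minimal, dependency-free illustration under that assumption, not this library's code: `patchify`, `NearestNeighborTokenizer`, and its methods are hypothetical names, distance is assumed to be Euclidean on raw patch values, and the codebook simply grows whenever no existing code lies within `distance_threshold`. It also shows where the config numbers come from: a 63x63x3 frame with `patch_size = 7` yields 81 patches (matching `block_size`) of 147 values each (matching `dim = 7 * 7 * 3`).

```python
import math
from typing import List

# A frame is (channels, height, width), stored as nested lists of floats.
Frame = List[List[List[float]]]


def patchify(frame: Frame, patch_size: int) -> List[List[float]]:
    """Split a (C, H, W) frame into non-overlapping patches in row-major order.

    Each patch is flattened to a vector of length patch_size * patch_size * C.
    """
    channels = len(frame)
    height, width = len(frame[0]), len(frame[0][0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([
                frame[c][top + i][left + j]
                for c in range(channels)
                for i in range(patch_size)
                for j in range(patch_size)
            ])
    return patches


class NearestNeighborTokenizer:
    """Hypothetical sketch of a codebook that grows on the fly.

    A patch maps to its nearest existing code if that code is within
    `distance_threshold` (Euclidean, an assumption); otherwise the patch
    itself is appended to the codebook as a new code.
    """

    def __init__(self, distance_threshold: float):
        self.distance_threshold = distance_threshold
        self.codebook: List[List[float]] = []

    def encode_patch(self, patch: List[float]) -> int:
        if self.codebook:
            dists = [math.dist(patch, code) for code in self.codebook]
            best = min(range(len(dists)), key=dists.__getitem__)
            if dists[best] <= self.distance_threshold:
                return best
        # No code is close enough: the patch becomes a new code.
        self.codebook.append(list(patch))
        return len(self.codebook) - 1

    def encode_frame(self, frame: Frame, patch_size: int) -> List[int]:
        return [self.encode_patch(p) for p in patchify(frame, patch_size)]

    def decode(self, token_ids: List[int]) -> List[List[float]]:
        return [self.codebook[i] for i in token_ids]


if __name__ == "__main__":
    image_size, patch_size, channels = 63, 7, 3
    frame = [[[0.0] * image_size for _ in range(image_size)] for _ in range(channels)]
    tokenizer = NearestNeighborTokenizer(distance_threshold = 0.5)
    tokens = tokenizer.encode_frame(frame, patch_size)
    print(len(tokens), len(tokenizer.codebook[0]))  # 81 147
```

In this sketch the codebook only grows, so a token id keeps its meaning once assigned and nothing has to be trained; a practical version would likely vectorize the distance computation instead of looping in Python.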

Citations

@inproceedings{Dedieu2025ImprovingTW,
    title   = {Improving Transformer World Models for Data-Efficient RL},
    author  = {Antoine Dedieu and Joseph Ortiz and Xinghua Lou and Carter Wendelken and Wolfgang Lehrach and J. Swaroop Guntupalli and Miguel L{\'a}zaro-Gredilla and Kevin Patrick Murphy},
    year    = {2025},
    url     = {https://api.semanticscholar.org/CorpusID:276107865}
}


Download files

Source Distribution

improving_transformers_world_model-0.0.59.tar.gz (607.5 kB)

Built Distribution

improving_transformers_world_model-0.0.59-py3-none-any.whl

File details

Details for the file improving_transformers_world_model-0.0.59.tar.gz.

File hashes

Hashes for improving_transformers_world_model-0.0.59.tar.gz
Algorithm Hash digest
SHA256 402e09533e5cdab814f26ce3e0be59ce599b4093e0f9b53af63f251a1fe01521
MD5 83866efbf58df5d409003a3ed1891966
BLAKE2b-256 4ab79dcd51a6117a42b9bec068cb7f869e5da37a3327f907e8604ccf8469a05c
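The digests listed here can be checked locally before installing. Below is a minimal sketch using Python's standard `hashlib`; the helper names `sha256_of_file` and `verify` are hypothetical, and the path you pass is assumed to be wherever the download was saved.

```python
import hashlib

# SHA256 digest listed above for the 0.0.59 source distribution.
EXPECTED_SHA256 = "402e09533e5cdab814f26ce3e0be59ce599b4093e0f9b53af63f251a1fe01521"


def sha256_of_file(path, chunk_size = 1 << 20):
    """Stream the file through SHA256 so large archives never need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected = EXPECTED_SHA256):
    """Return True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()
```

For example, `verify("improving_transformers_world_model-0.0.59.tar.gz")` should return `True` for an untampered download. pip can enforce the same check itself in hash-checking mode, using a requirements line such as `improving-transformers-world-model==0.0.59 --hash=sha256:<digest>`.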

File details

Details for the file improving_transformers_world_model-0.0.59-py3-none-any.whl.

File hashes

Hashes for improving_transformers_world_model-0.0.59-py3-none-any.whl
Algorithm Hash digest
SHA256 d53c61bd22fd8042cfdf092ef0637ab8a9a1a366a1aa1f381d0cec287816acfa
MD5 942bb1c26d8aa9397b93d1fc5c6bc26c
BLAKE2b-256 0b72a8ca34002ef4a7b50eccab2ba644fc66e725776c4597e12722ee655d8853
