Implementation of RISE, Self-Improving Robot Policy with Compositional World Model

RISE (wip)

Usage

Below is a complete, minimal example of orchestrating self-improvement with RISE.

It involves three stages:

  1. Fine-tuning an action-conditioned Cosmos Dynamics Model.
  2. Initializing a PiZero policy and SigLIP Value Network.
  3. Running the RISE loop: an imagination rollout followed by policy fine-tuning.

import torch

from RISE import (
    RISE,
    CosmosPredictWrapper,
    DynamicsTrainer,
    MockOfflineRoboticFrameDataset
)

from value_network import SigLIPValueNetwork

from pi_zero_pytorch.pi_zero import PiZero, SigLIP as PiZeroSigLIP

# 1. Provide an offline seed dataset (video + proprioception)
dataset = MockOfflineRoboticFrameDataset(num_samples = 10, image_size = 224)

# 2. Initialize and Fine-Tune the Action-Conditioned Dynamics Model
dynamics_model = CosmosPredictWrapper(
    model_name = 'nvidia/Cosmos-1.0-Diffusion-7B-Video2World',
    action_dim = 10,
    action_chunk_len = 8
)

trainer = DynamicsTrainer(
    model = dynamics_model,
    dataset = dataset,
    batch_size = 1,
    lr = 1e-4
)

# Learn system dynamics from the offline dataset
trainer.train(num_steps = 1000)

# 3. Initialize Policy (PiZero) and Value Network Evaluator
policy = PiZero(
    dim = 256,
    num_tokens = 1000,
    dim_action_input = 10,
    dim_joint_state = 14,
    depth = 4,
    pi05 = True,
    vit = PiZeroSigLIP(
        image_size = 224,
        patch_size = 16,
        dim = 256,
        depth = 4,
        heads = 4,
        mlp_dim = 512
    ),
    vit_dim = 256
)

value_model = SigLIPValueNetwork(
    siglip_image_size = 224,
    siglip_patch_size = 16,
    siglip_dim = 256,
    siglip_depth = 4,
    siglip_heads = 4,
    siglip_mlp_dim = 512
)

# 4. Instantiate the RISE Orchestrator
rise = RISE(
    policy = policy,
    dynamics_model = dynamics_model,
    value_model = value_model,
    trajectory_length = 8,
    num_prompt_tokens = 12,
    imagination_steps = 5
)

# 5. Imagination Rollout Stage
# The policy generates actions, the dynamics model predicts the future,
# the value network evaluates the advantage, and experience is stored.
replay_buffer = rise.imagination_rollout(
    seed_dataset = dataset,
    num_episodes = 2,
    batch_size = 1,
    buffer_folder = './rise_experience_buffer'
)

# 6. Self-Improvement Finetuning Stage
# The policy learns to imitate high-advantage imagined trajectories.
rise.finetune_with_advantage_conditioning(
    replay_buffer = replay_buffer,
    num_steps = 1000,
    batch_size = 2,
    lr = 1e-4
)

# 7. Save the Self-Improved Policy
torch.save(rise.policy.state_dict(), './improved_pi05_policy.pt')
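
The data flow of the rollout stage above can be illustrated with a toy, dependency-free sketch. It does not use or reproduce the RISE API: `toy_policy`, `toy_dynamics`, `toy_value`, the `Transition` record, and the binary `is_positive` label are all hypothetical stand-ins. The advantage estimate (value after the imagined step minus value before it) is likewise an assumption made for illustration, not necessarily how the library computes it.

```python
import random
from dataclasses import dataclass

GOAL = 1.0  # hypothetical 1-D target state


@dataclass
class Transition:
    state: float
    action: float
    next_state: float
    advantage: float
    is_positive: bool  # binary advantage label a policy could be conditioned on


def toy_policy(state, rng):
    # Stand-in policy: a noisy step with no knowledge of the goal.
    return rng.uniform(-0.5, 0.5)


def toy_dynamics(state, action):
    # Stand-in world model: returns the imagined next state.
    return state + action


def toy_value(state):
    # Stand-in value model: higher when the state is closer to the goal.
    return -abs(GOAL - state)


def imagination_rollout(seed_states, imagination_steps, rng):
    buffer = []
    for state in seed_states:
        for _ in range(imagination_steps):
            action = toy_policy(state, rng)
            next_state = toy_dynamics(state, action)
            # Assumed advantage estimate: change in value across the imagined step.
            advantage = toy_value(next_state) - toy_value(state)
            buffer.append(
                Transition(state, action, next_state, advantage, advantage > 0)
            )
            state = next_state
    return buffer


rng = random.Random(0)
buffer = imagination_rollout(seed_states=[0.0, 0.2], imagination_steps=5, rng=rng)
positives = [t for t in buffer if t.is_positive]
print(f"{len(positives)} of {len(buffer)} imagined transitions have positive advantage")
```

In a sketch like this, a fine-tuning stage would consume `buffer` together with the `is_positive` label, so that the policy can be steered toward high-advantage behavior. How RISE itself estimates and conditions on advantage is defined by the library and the paper, not by this example.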

Citations

@misc{yang2026riseselfimprovingrobotpolicy,
    title   = {RISE: Self-Improving Robot Policy with Compositional World Model}, 
    author  = {Jiazhi Yang and Kunyang Lin and Jinwei Li and Wencong Zhang and Tianwei Lin and Longyan Wu and Zhizhong Su and Hao Zhao and Ya-Qin Zhang and Li Chen and Ping Luo and Xiangyu Yue and Hongyang Li},
    year    = {2026},
    eprint  = {2602.11075},
    archivePrefix = {arXiv},
    primaryClass = {cs.RO},
    url     = {https://arxiv.org/abs/2602.11075}, 
}
