
Robotic Environments with jaX (REX)


Rex: Robotic Environments with jaX


Rex is a JAX-powered framework for sim-to-real robotics.

Key features:

  • Graph-based design: Model asynchronous systems with nodes for sensing, actuation, and computation.
  • Latency-aware modeling: Simulate delay effects for hardware, computation, and communication channels.
  • Real-time and parallelized runtimes: Run real-world experiments or accelerated parallelized simulations.
  • Seamless integration with JAX: Utilize JAX's autodiff, JIT compilation, and GPU/TPU acceleration.
  • System identification tools: Estimate dynamics and delays directly from real-world data.
  • Modular and extensible: Compatible with various simulation engines (e.g., Brax, MuJoCo).
  • Unified sim2real pipeline: Train delay-aware policies in simulation and deploy them on real-world systems.
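To make the latency-aware modeling idea concrete, here is a minimal, framework-independent sketch in plain Python. It does not use Rex's API. It assumes a sender publishing at a fixed rate, a receiver stepping at its own rate, and a Normal-distributed communication delay; the function `simulate_channel` and its parameters are hypothetical. At each receiver tick, only messages whose arrival time has been reached are visible, so the newest visible message is always somewhat stale.

```python
import random
from typing import List, Optional, Tuple


def simulate_channel(
    send_rate: float,
    recv_rate: float,
    mean_delay: float,
    std_delay: float,
    duration: float,
    seed: int = 0,
) -> List[Tuple[float, Optional[float]]]:
    """Return (tick_time, ts_send of newest visible message or None) per receiver tick.

    Hypothetical helper for illustration; not part of Rex.
    """
    rng = random.Random(seed)
    # Sender emits at fixed intervals; each message arrives after a sampled delay.
    arrivals = []
    for i in range(int(duration * send_rate)):
        ts_send = i / send_rate
        delay = max(0.0, rng.gauss(mean_delay, std_delay))  # clip negative samples
        arrivals.append((ts_send, ts_send + delay))
    ticks = []
    for k in range(int(duration * recv_rate)):
        t = k / recv_rate
        # Only messages whose arrival time has been reached are visible to the receiver.
        visible = [ts_send for ts_send, ts_recv in arrivals if ts_recv <= t]
        ticks.append((t, max(visible) if visible else None))
    return ticks


ticks = simulate_channel(send_rate=50, recv_rate=30, mean_delay=0.02, std_delay=0.005, duration=1.0)
ages = [t - ts for t, ts in ticks if ts is not None]
print(f"mean age of the newest visible message: {sum(ages) / len(ages):.4f} s")
```

Setting `std_delay=0.0` makes the result deterministic, which is handy for checking the logic: the first tick then sees no message at all, because the first message is still in transit.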

Sim-to-Real Workflow

  1. Interface Real Systems: Define nodes for sensors, actuators, and computation to represent real-world systems.
  2. Build Simulation: Swap real-world nodes with simulated ones (e.g., physics engines, motor dynamics).
  3. System Identification: Estimate system dynamics and delays from real-world data.
  4. Policy Training: Train delay-aware policies in simulation, accounting for realistic dynamics and delays.
  5. Evaluation: Evaluate trained policies on the real-world system, and iterate on the design.
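Step 3 is the least obvious part of this workflow. Rex's own tools estimate dynamics and delays from recorded data; the sketch below shows only the simplest ingredient of delay estimation: fitting a Normal distribution to measured communication delays `ts_recv - ts_send`. It assumes both timestamps come from a common clock, and `estimate_delay` is a hypothetical helper, not part of Rex.

```python
import statistics
from typing import Sequence, Tuple


def estimate_delay(ts_send: Sequence[float], ts_recv: Sequence[float]) -> Tuple[float, float]:
    """Fit Normal(mean, std) to per-message delays ts_recv - ts_send (hypothetical helper)."""
    if len(ts_send) != len(ts_recv):
        raise ValueError("ts_send and ts_recv must have the same length")
    if len(ts_send) < 2:
        raise ValueError("need at least two messages to estimate a spread")
    delays = [recv - send for send, recv in zip(ts_send, ts_recv)]
    return statistics.mean(delays), statistics.stdev(delays)  # sample std (n - 1)


mean, std = estimate_delay(
    ts_send=[0.000, 0.020, 0.040, 0.060],
    ts_recv=[0.011, 0.030, 0.052, 0.069],
)
print(f"Normal({mean:.4f}, {std:.4f})")
```

The fitted mean and standard deviation are the kind of values one would pass as `delay_dist=Normal(mean, std)` when building the simulated graph.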

Installation

pip install rex-lib

Requires Python 3.9+ and JAX 0.4.30+.
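A quick way to check these requirements in an existing environment is sketched below. It uses only the standard library (`sys.version_info` and `importlib.metadata.version`); the helpers `parse_version` and `meets_minimum` are illustrative and compare only the leading numeric components, so pre-release tags are ignored.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def parse_version(v: str) -> tuple:
    """Turn '0.4.30' (or '0.4.30.dev1') into (0, 4, 30), keeping leading numeric parts."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component such as 'dev1'
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    # Tuple comparison orders versions numerically, component by component
    return parse_version(installed) >= parse_version(minimum)


print("Python 3.9+:", sys.version_info >= (3, 9))
try:
    print("JAX 0.4.30+:", meets_minimum(version("jax"), "0.4.30"))
except PackageNotFoundError:
    print("JAX is not installed")
```

Comparing tuples rather than strings matters here: as strings, "0.4.9" sorts after "0.4.30". For full PEP 440 handling, the `packaging` library's `Version` class is the more robust choice.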

Documentation

Available at https://bheijden.github.io/rex/.

Quick example

Here's a simple example of a pendulum system. The real-world system is defined with nodes that interface with the hardware for sensing and actuation, plus an agent node that runs the policy:

from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import Actuator, Agent, Sensor

sensor = Sensor(rate=50)        # 50 Hz sampling rate
agent = Agent(rate=30)          # 30 Hz policy execution rate
actuator = Actuator(rate=50)    # 50 Hz control rate
nodes = dict(sensor=sensor, agent=agent, actuator=actuator)

agent.connect(sensor)       # Agent receives sensor data
actuator.connect(agent)     # Actuator receives agent commands
graph = AsyncGraph(nodes, agent) # Graph for real-world execution

graph_state = graph.init()  # Initial states of all nodes
graph.warmup(graph_state)   # Jit-compiles the graph (only once).
for _ in range(100):        # Run the graph for 100 steps
    graph_state = graph.run(graph_state) # Run for one step
graph.stop()                # Stop asynchronous nodes
data = graph.get_record()   # Get recorded data from the graph
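Because the nodes run at different rates (sensor 50 Hz, agent 30 Hz, actuator 50 Hz), a node can find several new messages, or none, waiting on an input when it steps. The sketch below counts this in plain Python under simplifying assumptions: all nodes start at t = 0, rates are integers, and communication delays are ignored. It is not how Rex schedules nodes internally, and `messages_per_step` is a hypothetical helper.

```python
from typing import List


def messages_per_step(producer_rate: int, consumer_rate: int, n_steps: int) -> List[int]:
    """New producer messages available at each consumer step (no delays).

    The producer emits at t = i / producer_rate and the consumer steps at
    t = k / consumer_rate, both starting at t = 0. Integer arithmetic avoids
    floating-point rounding at coinciding timestamps.
    """
    counts = []
    seen = 0
    for k in range(n_steps):
        # Messages with i / producer_rate <= k / consumer_rate, i.e. i <= k * producer_rate / consumer_rate
        available = (k * producer_rate) // consumer_rate + 1
        counts.append(available - seen)
        seen = available
    return counts


print(messages_per_step(50, 30, 6))  # sensor -> agent:   [1, 1, 2, 2, 1, 2]
print(messages_per_step(30, 50, 5))  # agent -> actuator: [1, 0, 1, 0, 1]
```

The agent sometimes sees two new sensor readings in one step, while the actuator sometimes receives no new command and has to act on the previous one.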

In simulation, we replace the hardware-interfacing nodes with simulated ones, add delay models, and add a physics simulation node:

from distrax import Normal
from rex.constants import Clock, RealTimeFactor
from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import SimActuator, Agent, SimSensor, BraxWorld

sensor = SimSensor(rate=50, delay_dist=Normal(0.01, 0.001))     # Process delay
agent = Agent(rate=30, delay_dist=Normal(0.02, 0.005))          # Computational delay
actuator = SimActuator(rate=50, delay_dist=Normal(0.01, 0.001)) # Process delay
world = BraxWorld(rate=100)  # 100 Hz physics simulation
nodes = dict(sensor=sensor, agent=agent, actuator=actuator, world=world)

sensor.connect(world, delay_dist=Normal(0.001, 0.001)) # Sensor delay
agent.connect(sensor, delay_dist=Normal(0.001, 0.001)) # Communication delay
actuator.connect(agent, delay_dist=Normal(0.001, 0.001)) # Communication delay
world.connect(actuator, delay_dist=Normal(0.001, 0.001), # Actuator delay
              skip=True) # Breaks algebraic loop in the graph
graph = AsyncGraph(nodes, agent,
                   clock=Clock.SIMULATED, # Simulates based on delay_dist
                   real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE)

graph_state = graph.init()  # Initial states of all nodes
graph.warmup(graph_state)   # Jit-compiles the graph
for _ in range(100):        # Run the graph for 100 steps
    graph_state = graph.run(graph_state) # Run for one step
graph.stop()                # Stop asynchronous nodes
data = graph.get_record()   # Get recorded data from the graph
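The `skip=True` flag on the world-actuator connection is described as breaking an algebraic loop: the connections world → sensor → agent → actuator → world form a cycle, so no node can be ordered before all of its inputs. Our reading, which is an assumption rather than a description of Rex internals, is that a skipped connection does not have to be resolved first within the same step. The sketch below uses Python's standard `graphlib` to show the ordering problem and how ignoring one edge resolves it; `connections` and `execution_order` are hypothetical names.

```python
from graphlib import CycleError, TopologicalSorter

# Each node maps to the nodes it receives from, mirroring the connect() calls above.
connections = {
    "sensor": {"world"},
    "agent": {"sensor"},
    "actuator": {"agent"},
    "world": {"actuator"},
}


def execution_order(edges, skipped=()):
    """Order nodes so inputs come first, ignoring (receiver, sender) pairs in `skipped`."""
    graph = {
        node: {src for src in inputs if (node, src) not in skipped}
        for node, inputs in edges.items()
    }
    return list(TopologicalSorter(graph).static_order())


try:
    execution_order(connections)
except CycleError:
    print("cycle: no valid order without skipping a connection")

print(execution_order(connections, skipped={("world", "actuator")}))
# ['world', 'sensor', 'agent', 'actuator']
```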

Nodes are defined by subclassing BaseNode, with their parameters, state, and outputs represented as JAX PyTrees:

from rex.node import BaseNode

# SomePyTree is a placeholder for any JAX PyTree type holding your fields

class Agent(BaseNode):
    def init_params(self, rng=None, graph_state=None):
        return SomePyTree(a=..., b=...)

    def init_state(self, rng=None, graph_state=None):
        return SomePyTree(x1=..., x2=...)

    def init_output(self, rng=None, graph_state=None):
        return SomePyTree(y1=..., y2=...)
    
    # Jit-compiled via graph.warmup for faster execution
    def step(self, step_state):  # Called at the node's rate
        ss = step_state  # Shorten name
        # Read params and current state
        params, state = ss.params, ss.state
        # Current episode, sequence number, and timestamp
        eps, seq, ts = ss.eps, ss.seq, ss.ts
        # Received messages from the "sensor" input, with their I/O timestamps
        cam = ss.inputs["sensor"]
        data, ts_send, ts_recv = cam.data, cam.ts_send, cam.ts_recv
        ... # Some computation for new_state, output
        new_state = SomePyTree(x1=..., x2=...)
        output = SomePyTree(y1=..., y2=...)
        # Update step_state for next step call
        new_ss = ss.replace(state=new_state)
        return new_ss, output # Sends output
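The step pattern above is functional: `step` does not mutate its input but returns an updated copy via `replace`, which fits JAX's requirement that JIT-compiled functions be pure. The plain-Python sketch below mirrors that pattern with frozen dataclasses so it runs without Rex or JAX. The classes (`Message`, `Params`, `State`, `Output`, `StepState`) and the proportional-control logic are hypothetical stand-ins; in Rex these would be JAX PyTrees provided by the framework. Here each input holds a single message for simplicity.

```python
from dataclasses import dataclass, replace as dc_replace
from typing import Dict, Tuple


@dataclass(frozen=True)
class Message:
    data: float      # e.g. measured pendulum angle (hypothetical)
    ts_send: float
    ts_recv: float


@dataclass(frozen=True)
class Params:
    kp: float        # hypothetical proportional gain


@dataclass(frozen=True)
class State:
    n_steps: int     # how many times step() has run


@dataclass(frozen=True)
class Output:
    action: float


@dataclass(frozen=True)
class StepState:
    params: Params
    state: State
    inputs: Dict[str, Message]
    eps: int
    seq: int
    ts: float

    def replace(self, **changes) -> "StepState":
        # Functional update: returns a new object, the original stays unchanged
        return dc_replace(self, **changes)


def step(ss: StepState) -> Tuple[StepState, Output]:
    cam = ss.inputs["sensor"]
    action = -ss.params.kp * cam.data  # drive the angle towards zero
    new_state = State(n_steps=ss.state.n_steps + 1)
    return ss.replace(state=new_state), Output(action=action)


ss = StepState(
    params=Params(kp=2.0),
    state=State(n_steps=0),
    inputs={"sensor": Message(data=0.5, ts_send=0.0, ts_recv=0.011)},
    eps=0, seq=0, ts=0.02,
)
new_ss, output = step(ss)
print(output.action, new_ss.state.n_steps, ss.state.n_steps)  # → -1.0 1 0
```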

Next steps

If this quick start has caught your interest, have a look at the sim2real.ipynb notebook for an example of a sim-to-real workflow using Rex.

Citation

If you use Rex in your scientific publications, please cite:

@article{heijden2024rex,
  title={{REX: GPU-Accelerated Sim2Real Framework with Delay and Dynamics Estimation}},
  author={van der Heijden, Bas and Kober, Jens and Babuska, Robert and Ferranti, Laura},
  journal={Transactions on Machine Learning Research (TMLR)},
  year={2025}
}


