Robotic Environments with jaX (REX)

Rex: Robotic Environments with jaX

Rex is a JAX-powered framework for sim-to-real robotics.

Key features:

  • Graph-based design: Model asynchronous systems with nodes for sensing, actuation, and computation.
  • Latency-aware modeling: Simulate delay effects for hardware, computation, and communication channels.
  • Real-time and parallelized runtimes: Run real-world experiments or accelerated parallelized simulations.
  • Seamless integration with JAX: Utilize JAX's autodiff, JIT compilation, and GPU/TPU acceleration.
  • System identification tools: Estimate dynamics and delays directly from real-world data.
  • Modular and extensible: Compatible with various simulation engines (e.g., Brax, MuJoCo).
  • Unified sim2real pipeline: Train delay-aware policies in simulation and deploy them on real-world systems.
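To make the latency-aware modeling above more concrete, here is a minimal, library-independent sketch of how a message's timestamps can be built from sampled delays. A computation delay separates the start of a step from the moment its output is sent (`ts_send`), and a communication delay separates sending from receiving (`ts_recv`). The `Message` class, the `transmit` helper, the normal distributions, and the clipping at zero are illustrative assumptions, not Rex's API. The default means and standard deviations simply mirror the agent's values in the simulation example further down.

```python
import random
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Message:
    data: float
    ts_send: float  # when the producing node finished its step and sent the output
    ts_recv: float  # when the message reached the consuming node


def sample_delay(rng: random.Random, mean: float, std: float) -> float:
    """Draw a delay from N(mean, std), clipped at zero (assumption: no negative delays)."""
    return max(0.0, rng.gauss(mean, std))


def transmit(
    rng: random.Random,
    data: float,
    ts_start: float,
    compute: Tuple[float, float] = (0.02, 0.005),  # (mean, std) of computation delay
    comm: Tuple[float, float] = (0.001, 0.001),    # (mean, std) of communication delay
) -> Message:
    ts_send = ts_start + sample_delay(rng, *compute)  # output leaves after computing
    ts_recv = ts_send + sample_delay(rng, *comm)      # and arrives after the channel delay
    return Message(data, ts_send, ts_recv)


rng = random.Random(0)
msg = transmit(rng, data=1.0, ts_start=0.0)
print(f"sent at {msg.ts_send:.4f} s, received at {msg.ts_recv:.4f} s")
```

Because the delays are random, two runs with different seeds yield different timestamps; simulating many such draws is what lets a policy experience realistic timing variation during training.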

Sim-to-Real Workflow

  1. Interface Real Systems: Define nodes for sensors, actuators, and computation to represent real-world systems.
  2. Build Simulation: Swap real-world nodes with simulated ones (e.g., physics engines, motor dynamics).
  3. System Identification: Estimate system dynamics and delays from real-world data.
  4. Policy Training: Train delay-aware policies in simulation, accounting for realistic dynamics and delays.
  5. Evaluation: Evaluate trained policies on the real-world system, and iterate on the design.
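As a rough illustration of step 3, the simplest form of delay estimation fits a normal distribution to the differences between recorded receive and send timestamps. This is only a sketch of the basic idea under that assumption; Rex's own system identification tools are more capable and also estimate dynamics. The `estimate_delay` helper and the timestamp values are made up for illustration.

```python
import statistics
from typing import Sequence, Tuple


def estimate_delay(ts_send: Sequence[float], ts_recv: Sequence[float]) -> Tuple[float, float]:
    """Fit a normal delay model N(mean, std) to paired send/receive timestamps."""
    if len(ts_send) != len(ts_recv):
        raise ValueError("ts_send and ts_recv must have the same length")
    if len(ts_send) < 2:
        raise ValueError("need at least two samples to estimate a standard deviation")
    delays = [recv - send for send, recv in zip(ts_send, ts_recv)]
    return statistics.mean(delays), statistics.stdev(delays)


# Timestamps as they might appear in a recording (made-up values for illustration).
ts_send = [0.00, 0.02, 0.04, 0.06]
ts_recv = [0.011, 0.029, 0.051, 0.070]
mean, std = estimate_delay(ts_send, ts_recv)
print(f"delay ~ N({mean:.5f}, {std:.5f})")
```

The fitted mean and standard deviation could then parameterize a delay distribution such as the `Normal(...)` objects used in the simulation example.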

Installation

pip install rex-lib

Requires Python 3.9+ and JAX 0.4.30+.
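A quick standard-library check of these requirements might look like the sketch below. `check_requirements` and `version_tuple` are hypothetical helper names, and the version parsing keeps only the leading numeric release segment (so suffixes such as `.dev1` or `rc1` are ignored), which is a simplification of full PEP 440 version ordering.

```python
import re
import sys
from importlib import metadata
from typing import List, Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Keep only the leading numeric release segment: '0.4.31.dev1' -> (0, 4, 31)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_requirements(min_python=(3, 9), min_jax=(0, 4, 30)) -> List[str]:
    """Return a list of unmet requirements (empty if everything looks fine)."""
    problems = []
    if sys.version_info[:2] < tuple(min_python):
        wanted = ".".join(map(str, min_python))
        problems.append(f"Python {sys.version.split()[0]} is older than {wanted}")
    try:
        jax_version = metadata.version("jax")  # installed distribution version
    except metadata.PackageNotFoundError:
        problems.append("jax is not installed")
    else:
        if version_tuple(jax_version) < tuple(min_jax):
            wanted = ".".join(map(str, min_jax))
            problems.append(f"jax {jax_version} is older than {wanted}")
    return problems


print(check_requirements() or "requirements satisfied")
```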

Documentation

Available at https://bheijden.github.io/rex/.

Quick example

Here's a simple example of a pendulum system. The real-world system is defined by nodes that interface with the hardware for sensing and actuation, plus an agent node that runs the policy:

from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import Actuator, Agent, Sensor

sensor = Sensor(rate=50)        # 50 Hz sampling rate
agent = Agent(rate=30)          # 30 Hz policy execution rate
actuator = Actuator(rate=50)    # 50 Hz control rate
nodes = dict(sensor=sensor, agent=agent, actuator=actuator)

agent.connect(sensor)       # Agent receives sensor data
actuator.connect(agent)     # Actuator receives agent commands
graph = AsyncGraph(nodes, agent) # Graph for real-world execution

graph_state = graph.init()  # Initial states of all nodes
graph.warmup(graph_state)   # Jit-compiles the graph (only once).
for _ in range(100):        # Run the graph for 100 steps
    graph_state = graph.run(graph_state) # Run for one step
graph.stop()                # Stop asynchronous nodes
data = graph.get_record()   # Get recorded data from the graph
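The three nodes above run at different rates. To build some intuition for what that implies, the sketch below ignores all delays and asks which sensor message the 30 Hz agent would see if it always consumed the most recent one from the 50 Hz sensor. Both the "latest message" policy and the zero delays are assumptions for illustration; Rex's actual input handling may differ. Exact fractions are used so that step times do not suffer from floating-point rounding.

```python
import bisect
import math
from fractions import Fraction
from typing import List


def tick_times(rate_hz: int, duration_s: Fraction) -> List[Fraction]:
    """Start times k / rate of a node's steps within [0, duration), ignoring delays."""
    n_steps = math.ceil(duration_s * rate_hz)
    return [Fraction(k, rate_hz) for k in range(n_steps)]


def latest_seq(sender_ticks: List[Fraction], t: Fraction) -> int:
    """Sequence number of the newest message sent at or before t (-1 if none yet)."""
    return bisect.bisect_right(sender_ticks, t) - 1


duration = Fraction(1, 10)               # 100 ms window
sensor_ticks = tick_times(50, duration)  # 50 Hz sensor: 5 steps
agent_ticks = tick_times(30, duration)   # 30 Hz agent: 3 steps
seen = [latest_seq(sensor_ticks, t) for t in agent_ticks]
print(seen)  # [0, 1, 3]: sensor message 2 is never the newest one at an agent step
```

Even without any delays, the rate mismatch means the agent sees sensor data of varying age and some messages are never consumed as the newest one.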

In simulation, we replace the hardware-interfacing nodes with simulated ones, add delay models, and add a physics simulation node:

from distrax import Normal
from rex.constants import Clock, RealTimeFactor
from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import SimActuator, Agent, SimSensor, BraxWorld

sensor = SimSensor(rate=50, delay_dist=Normal(0.01, 0.001))     # Process delay
agent = Agent(rate=30, delay_dist=Normal(0.02, 0.005))          # Computational delay
actuator = SimActuator(rate=50, delay_dist=Normal(0.01, 0.001)) # Process delay
world = BraxWorld(rate=100)  # 100 Hz physics simulation
nodes = dict(sensor=sensor, agent=agent, actuator=actuator, world=world)

sensor.connect(world, delay_dist=Normal(0.001, 0.001)) # Sensor delay
agent.connect(sensor, delay_dist=Normal(0.001, 0.001)) # Communication delay
actuator.connect(agent, delay_dist=Normal(0.001, 0.001)) # Communication delay
world.connect(actuator, delay_dist=Normal(0.001, 0.001), # Actuator delay
              skip=True) # Breaks algebraic loop in the graph
graph = AsyncGraph(nodes, agent,
                   clock=Clock.SIMULATED, # Simulates based on delay_dist
                   real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE)

graph_state = graph.init()  # Initial states of all nodes
graph.warmup(graph_state)   # Jit-compiles the graph
for _ in range(100):        # Run the graph for 100 steps
    graph_state = graph.run(graph_state) # Run for one step
graph.stop()                # Stop asynchronous nodes
data = graph.get_record()   # Get recorded data from the graph
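The simulated graph contains a cycle: world → sensor → agent → actuator → world. Without `skip=True`, no node in that cycle could step first at a given instant, because each would be waiting on another. The sketch below models the graph with the standard library's `graphlib` and treats a skipped edge as adding no ordering constraint within the same instant. That reading of `skip` is an assumption made for illustration; see the Rex documentation for its exact semantics. `EDGES` and `step_order` are hypothetical names.

```python
from graphlib import CycleError, TopologicalSorter
from typing import List, Tuple

# (sender, receiver, skip), mirroring the connect() calls in the simulation example.
EDGES: List[Tuple[str, str, bool]] = [
    ("world", "sensor", False),
    ("sensor", "agent", False),
    ("agent", "actuator", False),
    ("actuator", "world", True),  # world.connect(actuator, ..., skip=True)
]


def step_order(edges, honor_skip: bool = True) -> List[str]:
    """Order in which nodes can step at the same instant, or raise CycleError."""
    sorter = TopologicalSorter()
    for sender, receiver, skip in edges:
        if honor_skip and skip:
            # Assumed semantics: a skipped input is not waited for within the same
            # instant, so the edge only registers both nodes.
            sorter.add(receiver)
            sorter.add(sender)
        else:
            sorter.add(receiver, sender)  # receiver must step after sender
    return list(sorter.static_order())


print(step_order(EDGES))  # ['world', 'sensor', 'agent', 'actuator']
try:
    step_order(EDGES, honor_skip=False)
except CycleError:
    print("without skip=True the graph contains an algebraic loop")
```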

Nodes subclass BaseNode, and their parameters, state, and outputs are JAX PyTrees:

from rex.node import BaseNode

class Agent(BaseNode):
    def init_params(self, rng=None, graph_state=None):
        return SomePyTree(a=..., b=...)

    def init_state(self, rng=None, graph_state=None):
        return SomePyTree(x1=..., x2=...)

    def init_output(self, rng=None, graph_state=None):
        return SomePyTree(y1=..., y2=...)
    
    # Jit-compiled via graph.warmup for faster execution
    def step(self, step_state): # Called at Node's rate
        ss = step_state  # Shorten name
        # Read params and current state
        params, state = ss.params, ss.state
        # Current episode, sequence number, and timestamp
        eps, seq, ts = ss.eps, ss.seq, ss.ts
        # Received sensor messages, with their send/receive timestamps
        cam = ss.inputs["sensor"]
        data, ts_send, ts_recv = cam.data, cam.ts_send, cam.ts_recv
        ... # Some computation for new_state, output
        new_state = SomePyTree(x1=..., x2=...)
        output = SomePyTree(y1=..., y2=...)
        # Update step_state for next step call
        new_ss = ss.replace(state=new_state)
        return new_ss, output # Sends output
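The essential contract of `step` is that it is pure: it returns an updated copy of the step state together with an output, rather than mutating anything in place, which is what makes it suitable for JIT compilation. The sketch below mimics that pattern with frozen dataclasses. `StepState`, `InputMsg`, and the proportional-controller computation are hypothetical stand-ins, not Rex classes; Rex's real step state carries more fields (such as `eps` and `seq`) and uses PyTrees rather than plain dataclasses.

```python
import dataclasses
from typing import Dict, Tuple


@dataclasses.dataclass(frozen=True)
class InputMsg:  # hypothetical stand-in for ss.inputs["sensor"]
    data: float
    ts_send: float
    ts_recv: float


@dataclasses.dataclass(frozen=True)
class StepState:  # hypothetical stand-in for Rex's step_state (eps, seq, ... omitted)
    params: float  # here: a proportional gain
    state: int     # here: number of steps taken so far
    ts: float
    inputs: Dict[str, InputMsg]

    def replace(self, **changes) -> "StepState":
        # Return a modified copy; the original is never mutated.
        return dataclasses.replace(self, **changes)


def step(ss: StepState) -> Tuple[StepState, float]:
    cam = ss.inputs["sensor"]
    output = -ss.params * cam.data           # placeholder computation: P-controller
    new_ss = ss.replace(state=ss.state + 1)  # thread the updated state forward
    return new_ss, output


ss0 = StepState(params=2.0, state=0, ts=0.05,
                inputs={"sensor": InputMsg(data=0.5, ts_send=0.03, ts_recv=0.035)})
ss1, action = step(ss0)
print(ss0.state, ss1.state, action)  # 0 1 -1.0
```

Note that `ss0` is unchanged after the call; the caller is responsible for passing `ss1` into the next step.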

Next steps

For an end-to-end sim-to-real workflow with Rex, see the sim2real.ipynb notebook.

Citation

If you use Rex in your scientific publications, please cite:

@article{heijden2024rex,
  title={{REX: GPU-Accelerated Sim2Real Framework with Delay and Dynamics Estimation}},
  author={van der Heijden, Bas and Kober, Jens and Babuska, Robert and Ferranti, Laura},
  journal={Transactions on Machine Learning Research (TMLR)},
  year={2025}
}


Download files

Download the file for your platform.

Source Distribution

rex_lib-0.0.12.tar.gz (112.6 kB)

Built Distribution

rex_lib-0.0.12-py3-none-any.whl (128.0 kB)

File details

Details for the file rex_lib-0.0.12.tar.gz.

File metadata

  • Download URL: rex_lib-0.0.12.tar.gz
  • Upload date:
  • Size: 112.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.12

File hashes

Hashes for rex_lib-0.0.12.tar.gz
Algorithm     Hash digest
SHA256        7ea02517d8c137404a1d042e3cf8d26f9a9dd635f2742ed37dccfdd800b591aa
MD5           f7fffd69e10a60ee09e9d52136d4c65c
BLAKE2b-256   f5d3395ee1f2e98834da498647dfb44a30646a7be9472b61ff1886fb834180c4

File details

Details for the file rex_lib-0.0.12-py3-none-any.whl.

File metadata

  • Download URL: rex_lib-0.0.12-py3-none-any.whl
  • Upload date:
  • Size: 128.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.12

File hashes

Hashes for rex_lib-0.0.12-py3-none-any.whl
Algorithm     Hash digest
SHA256        38a5a315f3b5c08eb631dee47bb0e041af038dba47c9c82b9169cfd18e307291
MD5           60f7e50325903b23080493d7a49a9ad3
BLAKE2b-256   6c356a4af9421f58853424ce49dd32ce3e639e70ec173bde967ce6e32845c231
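To check a downloaded file against the SHA256 digests listed above, a small standard-library sketch could look like this. `sha256_of` and `verify_sha256` are hypothetical helper names; the file is read in chunks so large archives do not need to fit in memory.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks to keep memory use small."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 digest matches the published one."""
    return sha256_of(path) == expected_hex.strip().lower()


# Example (the wheel must have been downloaded first):
# verify_sha256("rex_lib-0.0.12-py3-none-any.whl",
#               "38a5a315f3b5c08eb631dee47bb0e041af038dba47c9c82b9169cfd18e307291")
```

Alternatively, pip can enforce the digest itself in hash-checking mode: list `rex-lib==0.0.12 --hash=sha256:<digest>` in a requirements file and install with `pip install --require-hashes -r requirements.txt`.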
