Robotic Environments with jaX (REX)
Rex: Robotic Environments with jaX
Rex is a JAX-powered framework for sim-to-real robotics.
Key features:
- Graph-based design: Model asynchronous systems with nodes for sensing, actuation, and computation.
- Latency-aware modeling: Simulate delay effects for hardware, computation, and communication channels.
- Real-time and parallelized runtimes: Run real-world experiments or accelerated parallelized simulations.
- Seamless integration with JAX: Utilize JAX's autodiff, JIT compilation, and GPU/TPU acceleration.
- System identification tools: Estimate dynamics and delays directly from real-world data.
- Modular and extensible: Compatible with various simulation engines (e.g., Brax, MuJoCo).
- Unified sim2real pipeline: Train delay-aware policies in simulation and deploy them on real-world systems.
Sim-to-Real Workflow
- Interface Real Systems: Define nodes for sensors, actuators, and computation to represent real-world systems.
- Build Simulation: Swap real-world nodes with simulated ones (e.g., physics engines, motor dynamics).
- System Identification: Estimate system dynamics and delays from real-world data.
- Policy Training: Train delay-aware policies in simulation, accounting for realistic dynamics and delays.
- Evaluation: Evaluate trained policies on the real-world system, and iterate on the design.
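As a rough illustration of the system identification step, the sketch below fits a Gaussian to per-message delays computed from send and receive timestamps. It is not Rex's estimator: the function name `estimate_delay` and the timestamp values are hypothetical, and only the Python standard library is used. The result has the same (mean, std) shape as the `Normal(...)` delay distributions in the quick example.

```python
import statistics


def estimate_delay(ts_send, ts_recv):
    """Fit a Gaussian (mean, std) to observed per-message delays in seconds."""
    if len(ts_send) != len(ts_recv):
        raise ValueError("ts_send and ts_recv must have the same length")
    delays = [recv - send for send, recv in zip(ts_send, ts_recv)]
    if len(delays) < 2:
        raise ValueError("need at least two messages to estimate a spread")
    # Sample mean and sample standard deviation of the delays.
    return statistics.mean(delays), statistics.stdev(delays)


# Hypothetical recorded timestamps (seconds) for five messages.
ts_send = [0.00, 0.02, 0.04, 0.06, 0.08]
ts_recv = [0.011, 0.030, 0.052, 0.069, 0.091]

mean, std = estimate_delay(ts_send, ts_recv)
print(f"delay ~ Normal({mean:.4f}, {std:.4f})")
```

A moment-matched Gaussian is the simplest possible choice; real delay data is often skewed and strictly positive, so a richer distribution may fit better.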
Installation
pip install rex-lib
Requires Python 3.9+ and JAX 0.4.30+.
Documentation
Available at https://bheijden.github.io/rex/.
Quick example
Here's a simple example of a pendulum system. The real-world system is defined with nodes that interface with hardware for sensing and actuation:
from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import Actuator, Agent, Sensor
sensor = Sensor(rate=50) # 50 Hz sampling rate
agent = Agent(rate=30) # 30 Hz policy execution rate
actuator = Actuator(rate=50) # 50 Hz control rate
nodes = dict(sensor=sensor, agent=agent, actuator=actuator)
agent.connect(sensor) # Agent receives sensor data
actuator.connect(agent) # Actuator receives agent commands
graph = AsyncGraph(nodes, agent) # Graph for real-world execution
graph_state = graph.init() # Initial states of all nodes
graph.warmup(graph_state) # Jit-compiles the graph (only once).
for _ in range(100): # Run the graph for 100 steps
    graph_state = graph.run(graph_state) # Run for one step
graph.stop() # Stop asynchronous nodes
data = graph.get_record() # Get recorded data from the graph
In simulation, we replace the hardware-interfacing nodes with simulated ones, add delay models, and add a physics simulation node:
from distrax import Normal
from rex.constants import Clock, RealTimeFactor
from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import SimActuator, Agent, SimSensor, BraxWorld
sensor = SimSensor(rate=50, delay_dist=Normal(0.01, 0.001)) # Process delay
agent = Agent(rate=30, delay_dist=Normal(0.02, 0.005)) # Computational delay
actuator = SimActuator(rate=50, delay_dist=Normal(0.01, 0.001)) # Process delay
world = BraxWorld(rate=100) # 100 Hz physics simulation
nodes = dict(sensor=sensor, agent=agent, actuator=actuator, world=world)
sensor.connect(world, delay_dist=Normal(0.001, 0.001)) # Sensor delay
agent.connect(sensor, delay_dist=Normal(0.001, 0.001)) # Communication delay
actuator.connect(agent, delay_dist=Normal(0.001, 0.001)) # Communication delay
world.connect(actuator, delay_dist=Normal(0.001, 0.001), # Actuator delay
              skip=True) # Breaks algebraic loop in the graph
graph = AsyncGraph(nodes, agent,
                   clock=Clock.SIMULATED, # Simulates based on delay_dist
                   real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE)
graph_state = graph.init() # Initial states of all nodes
graph.warmup(graph_state) # Jit-compiles the graph
for _ in range(100): # Run the graph for 100 steps
    graph_state = graph.run(graph_state) # Run for one step
graph.stop() # Stop asynchronous nodes
data = graph.get_record() # Get recorded data from the graph
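To build some intuition for how mismatched rates and delays interact, the following standalone sketch computes when the outputs of a 50 Hz sensor would reach a 30 Hz agent and counts how many messages are available at each agent tick. It is not Rex's scheduler: the helper names `arrival_times` and `messages_per_step` are hypothetical, and clipping sampled delays at zero is an assumption made here because a Gaussian such as Normal(0.001, 0.001) assigns noticeable probability to negative values.

```python
import random


def arrival_times(rate, n, proc_delay, comm_delay, rng):
    """Times (s) at which a node's first n outputs reach a downstream node."""
    times = []
    for k in range(n):
        t_start = k / rate  # the node ticks at a fixed rate
        # Sample Gaussian delays; clip at zero so time never runs backwards
        # (an assumption of this sketch).
        proc = max(0.0, rng.gauss(*proc_delay))
        comm = max(0.0, rng.gauss(*comm_delay))
        times.append(t_start + proc + comm)
    return times


def messages_per_step(arrivals, rate, n_steps):
    """Count messages arriving in the window before each consumer tick."""
    counts = []
    for k in range(n_steps):
        lo, hi = (k - 1) / rate, k / rate
        counts.append(sum(lo < t <= hi for t in arrivals))
    return counts


rng = random.Random(0)  # seeded for repeatability
# (mean, std) pairs mirror the delay_dist values used in the example above.
sensor_arrivals = arrival_times(50, 50, (0.01, 0.001), (0.001, 0.001), rng)
counts = messages_per_step(sensor_arrivals, 30, 30)
print(counts)
```

Because the sensor runs faster than the agent, most agent ticks see one or two new sensor messages, and the very first tick sees none since nothing has arrived yet.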
Nodes are defined using JAX's PyTrees:
from rex.node import BaseNode
class Agent(BaseNode):
    def init_params(self, rng=None, graph_state=None):
        return SomePyTree(a=..., b=...)
    def init_state(self, rng=None, graph_state=None):
        return SomePyTree(x1=..., x2=...)
    def init_output(self, rng=None, graph_state=None):
        return SomePyTree(y1=..., y2=...)
    # Jit-compiled via graph.warmup for faster execution
    def step(self, step_state): # Called at Node's rate
        ss = step_state # Shorten name
        # Read params, and current state
        params, state = ss.params, ss.state
        # Current episode, sequence, timestamp
        eps, seq, ts = ss.eps, ss.seq, ss.ts
        # Grab the data, and I/O timestamps
        cam = ss.inputs["sensor"] # Received messages
        cam.data, cam.ts_send, cam.ts_recv
        ... # Some computation for new_state, output
        new_state = SomePyTree(x1=..., x2=...)
        output = SomePyTree(y1=..., y2=...)
        # Update step_state for next step call
        new_ss = ss.replace(state=new_state)
        return new_ss, output # Sends output
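The step method above follows a functional-update pattern: it never mutates step_state and instead returns a modified copy, which is what JAX transformations such as jit expect from pure functions. The sketch below mimics that pattern with frozen dataclasses from the standard library. The classes `State` and `StepState` are hypothetical stand-ins, not Rex's actual types, and the integration rule is only a placeholder computation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class State:  # hypothetical stand-in for a node's state PyTree
    angle: float
    velocity: float


@dataclass(frozen=True)
class StepState:  # hypothetical stand-in for Rex's step_state
    seq: int
    ts: float
    state: State


def step(ss: StepState, dt: float = 0.02):
    """Pure step: integrate the angle and return (new_step_state, output)."""
    # Build a new state instead of mutating the old one.
    new_state = replace(ss.state, angle=ss.state.angle + dt * ss.state.velocity)
    output = new_state.angle
    # Copy the step state with only the `state` field swapped out.
    new_ss = replace(ss, state=new_state)
    return new_ss, output


ss0 = StepState(seq=0, ts=0.0, state=State(angle=0.0, velocity=1.0))
ss1, out = step(ss0)
print(ss0.state.angle, ss1.state.angle, out)
```

The original `ss0` is left untouched, so the same input always yields the same output. In this sketch `seq` and `ts` are simply carried over unchanged.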
Next steps
If this quick start has piqued your interest, have a look at the sim2real.ipynb notebook for an example of a sim-to-real workflow using Rex.
Citation
If you use Rex in your scientific publications, please cite:
@article{heijden2024rex,
  title={{REX: GPU-Accelerated Sim2Real Framework with Delay and Dynamics Estimation}},
  author={van der Heijden, Bas and Kober, Jens and Babuska, Robert and Ferranti, Laura},
  journal={Transactions on Machine Learning Research (TMLR)},
  year={2025}
}