A JAX 2D Rigid-Body Physics Engine
Jax2D
Jax2D is a 2D rigid-body physics engine written entirely in JAX and based on the Box2D engine.
Unlike other JAX physics engines, Jax2D is dynamic with respect to scene configuration, allowing heterogeneous scenes to be parallelised with vmap.
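To illustrate what being dynamic with respect to scene configuration makes possible, here is a minimal toy sketch. It is not Jax2D's actual state layout: it assumes each scene is stored in fixed-size arrays sized for a maximum entity count, with a boolean mask marking which slots are in use. Because every scene then has identical array shapes, a single jit-compiled step can be vmapped over scenes containing different numbers of bodies. The names ToyScene, make_scene, step and MAX_BODIES are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

MAX_BODIES = 4  # hypothetical fixed capacity shared by every scene


class ToyScene(NamedTuple):
    # Hypothetical toy state, NOT Jax2D's real sim_state
    pos: jnp.ndarray     # (MAX_BODIES, 2) body positions
    vel: jnp.ndarray     # (MAX_BODIES, 2) body velocities
    active: jnp.ndarray  # (MAX_BODIES,) True for occupied slots


def make_scene(positions):
    """Pad a variable-length list of 2D positions into a fixed-size scene."""
    n = len(positions)
    pos = jnp.zeros((MAX_BODIES, 2)).at[:n].set(jnp.asarray(positions, dtype=jnp.float32))
    active = jnp.arange(MAX_BODIES) < n
    return ToyScene(pos=pos, vel=jnp.zeros((MAX_BODIES, 2)), active=active)


def step(scene, dt=0.1, gravity=-9.81):
    # Semi-implicit Euler under gravity; inactive (padding) slots are left untouched
    new_vel = scene.vel + jnp.array([0.0, gravity]) * dt
    new_pos = scene.pos + new_vel * dt
    mask = scene.active[:, None]
    return ToyScene(
        pos=jnp.where(mask, new_pos, scene.pos),
        vel=jnp.where(mask, new_vel, scene.vel),
        active=scene.active,
    )


# Two scenes with different numbers of bodies, stacked along a leading batch axis
scenes = [make_scene([[0.0, 1.0]]), make_scene([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])]
batch = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *scenes)

# One compiled function serves both scenes
batched_step = jax.jit(jax.vmap(step))
batch = batched_step(batch)
```

The price of this design is that every scene pays for all MAX_BODIES slots even when most are empty, which fits a workload of many small scenes rather than a few large ones.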
Jax2D was initially created as the physics backend for the Kinetix project and was developed by Michael Matthews and Michael Beukman.
Why should I use Jax2D?
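The main reason to use Jax2D over other JAX physics engines such as Brax or MJX is that Jax2D scenes are (largely) dynamically specified: array sizes are fixed by StaticSimParams (for example num_joints and num_thrusters), while which entities exist and how they are arranged is ordinary state data, so a single compiled step function can simulate many different scenes. The trade-off is that Jax2D always has O(n^2) runtime with respect to the number of entities in a scene, since we must always calculate the full collision resolution for every pair of entities. This means it is usually not appropriate for simulating scenes with large numbers (>100) of entities.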
In short: Jax2D excels at simulating lots of small and diverse scenes in parallel very fast.
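The quadratic cost comes from considering every pair of entities. The sketch below uses toy circles only and is not Jax2D's actual collision code: it builds an n×n overlap matrix with broadcasting, masking out padding slots and self-pairs, so the work grows with n^2 regardless of how many circles actually touch. The function name pairwise_circle_overlaps is hypothetical.

```python
import jax
import jax.numpy as jnp


def pairwise_circle_overlaps(pos, radius, active):
    """Return an (n, n) bool matrix that is True where circles i and j overlap.

    pos: (n, 2) centres, radius: (n,), active: (n,) bool mask.
    All n * n pairs are evaluated, so the cost is O(n^2).
    """
    diff = pos[:, None, :] - pos[None, :, :]    # (n, n, 2) pairwise offsets
    dist = jnp.linalg.norm(diff, axis=-1)       # (n, n) centre distances
    touching = dist < radius[:, None] + radius[None, :]
    both_active = active[:, None] & active[None, :]
    not_self = ~jnp.eye(pos.shape[0], dtype=bool)
    return touching & both_active & not_self


pos = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0], [0.2, 0.0]])
radius = jnp.array([0.35, 0.35, 0.35, 0.35])
active = jnp.array([True, True, True, False])  # last slot is padding

overlaps = jax.jit(pairwise_circle_overlaps)(pos, radius, active)
```

Circles 0 and 1 overlap; circle 3 would overlap circle 0 but is ignored because its slot is inactive.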
Example Usage
The example below shows how to use Jax2D to create and run a scene. For the full code, see examples/car.py.
```python
import jax
import jax.numpy as jnp

# Module paths as used in examples/car.py; check the repository if they have moved
from jax2d.engine import PhysicsEngine, create_empty_sim
from jax2d.scene import (
    add_circle_to_scene,
    add_polygon_to_scene,
    add_rectangle_to_scene,
    add_revolute_joint_to_scene,
)
from jax2d.sim_state import SimParams, StaticSimParams

# Create engine with default parameters
static_sim_params = StaticSimParams()
sim_params = SimParams()
engine = PhysicsEngine(static_sim_params)

# Create scene
sim_state = create_empty_sim(static_sim_params, floor_offset=0.0)

# Create a rectangle for the car body
sim_state, (_, r_index) = add_rectangle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.0, 1.0]),
    dimensions=jnp.array([1.0, 0.4])
)

# Create circles for the wheels of the car
sim_state, (_, c1_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([1.5, 1.0]), radius=0.35
)
sim_state, (_, c2_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.5, 1.0]), radius=0.35
)

# Join the wheels to the car body with revolute joints
# Relative positions are measured from the centre of mass of each object
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c1_index,
    a_relative_pos=jnp.array([-0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c2_index,
    a_relative_pos=jnp.array([0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)

# Add a triangle for a ramp; fixated=True stops the ramp from moving
triangle_vertices = jnp.array([[0.5, 0.1], [0.5, -0.1], [-0.5, -0.1]])
sim_state, _ = add_polygon_to_scene(
    sim_state,
    static_sim_params,
    position=jnp.array([2.7, 0.1]),
    vertices=triangle_vertices,
    n_vertices=3,
    fixated=True,
)

# Run scene
step_fn = jax.jit(engine.step)

while True:
    # We activate all motors and thrusters
    actions = jnp.ones(static_sim_params.num_joints + static_sim_params.num_thrusters)
    sim_state, _ = step_fn(sim_state, sim_params, actions)

    # Do rendering...
```
This produces the following scene (rendered with JaxGL)
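As a sanity check on the joint coordinates in the example, here is a small plain-JAX sketch that maps a body-local offset to world space as anchor = position + R(angle) @ offset. It assumes, as in Box2D, that local anchors rotate with their body and that each body's centre of mass is at its position; local_to_world and rotation_matrix are hypothetical helpers, not Jax2D functions. At zero rotation, the anchors on the car body land exactly on the wheel centres, which is what a consistent revolute joint needs.

```python
import jax.numpy as jnp


def rotation_matrix(angle):
    """2D rotation matrix for an angle in radians."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s], [s, c]])


def local_to_world(body_pos, body_angle, local_offset):
    """Hypothetical helper: map a body-local offset to a world-space point."""
    return body_pos + rotation_matrix(body_angle) @ local_offset


# Values from the car example, with every body at zero rotation
body_pos = jnp.array([2.0, 1.0])
wheel_positions = [jnp.array([1.5, 1.0]), jnp.array([2.5, 1.0])]
body_offsets = [jnp.array([-0.5, 0.0]), jnp.array([0.5, 0.0])]

for wheel_pos, offset in zip(wheel_positions, body_offsets):
    anchor_on_body = local_to_world(body_pos, 0.0, offset)
    anchor_on_wheel = local_to_world(wheel_pos, 0.0, jnp.zeros(2))
    # Both ends of each joint start at the same world point
    assert jnp.allclose(anchor_on_body, anchor_on_wheel)

# If the body tilts by 90 degrees, the rear anchor swings to just below its centre
rear_anchor_tilted = local_to_world(body_pos, jnp.pi / 2, body_offsets[0])
```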
More Complex Levels
For creating and using more complicated levels, we recommend using the built-in editors provided in Kinetix.
Installation
To use Jax2D in your work, install it from PyPI:
```shell
pip install jax2d
```
If you want to extend Jax2D, clone the repository, install it in editable mode with the development dependencies, and set up the pre-commit hooks:
```shell
git clone https://github.com/MichaelTMatthews/Jax2D
cd Jax2D
pip install -e ".[dev]"
pre-commit install
```
See Also
- 🍎 Box2D: the original C physics engine
- 🤖 Kinetix: Jax2D as a reinforcement learning environment
- 🌐 KinetixJS: Jax2D reimplemented in JavaScript
- 🦾 Brax: 3D physics in JAX
- 🦿 MJX: MuJoCo in JAX
- 👨‍💻 JaxGL: rendering in JAX
Citation
If you use Jax2D in your work, please cite it as follows:
```bibtex
@article{matthews2024kinetix,
  title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks},
  author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
  year={2024},
  eprint={2410.23208},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.23208},
}
```
Acknowledgements
We would like to thank Erin Catto and Randy Gaul for their invaluable online materials, which made the creation of this engine possible. If you would like to develop your own physics engine, we recommend starting with their materials.
File details
Details for the file jax2d-1.0.0.tar.gz.
File metadata
- Download URL: jax2d-1.0.0.tar.gz
- Size: 22.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | e808015b45d435be5d24e82943d827dd3c0bfa1944dd0a3b343a70c86a17d8c4
MD5 | 21988c632c79bdb219c15d96ee128d76
BLAKE2b-256 | f01498434207bab3fdcf953de39ba659a4ac5bea3a39e628c552f35150e5b034
File details
Details for the file jax2d-1.0.0-py3-none-any.whl.
File metadata
- Download URL: jax2d-1.0.0-py3-none-any.whl
- Size: 21.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 87486d356933d01225bcc443ca110dd2ab107d2871c322554822116f83ed310c
MD5 | 5f23fd37cede46bedce659e5dad53e6f
BLAKE2b-256 | abe7ddd6d305bff91886cd8af36bb9823b6eb44bfa09d0befbcd993643a9cacb
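If you download a distribution file manually, you can check it against the SHA256 digests listed above before installing. The following is a minimal standard-library sketch; sha256_of_file and matches_expected are hypothetical helper names, and the expected value is the digest listed for jax2d-1.0.0.tar.gz. pip's hash-checking mode (--require-hashes) is an alternative way to enforce the same check at install time.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for the source distribution jax2d-1.0.0.tar.gz
EXPECTED_SDIST_SHA256 = "e808015b45d435be5d24e82943d827dd3c0bfa1944dd0a3b343a70c86a17d8c4"


def sha256_of_file(path, chunk_size=1 << 16):
    """Compute a file's SHA256 hex digest, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected_hex):
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Example, after downloading the file into the current directory:
# matches_expected("jax2d-1.0.0.tar.gz", EXPECTED_SDIST_SHA256)
```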