A JAX 2D Rigid-Body Physics Engine
Jax2D
Jax2D is a 2D rigid-body physics engine written entirely in JAX and based on the Box2D engine.
Unlike other JAX physics engines, Jax2D is dynamic with respect to scene configuration, allowing heterogeneous scenes to be parallelised with vmap.
Jax2D was initially created for the backend of the Kinetix project and was developed by Michael_{Matthews, Beukman}.
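Under `jax.vmap`, every scene in a batch must have the same array shapes. A common way to reconcile this with heterogeneous scenes is to give every scene the same fixed capacity and mark which slots are in use. The car example in this README reads its capacities (such as `num_joints` and `num_thrusters`) from `StaticSimParams`, which is consistent with this picture. The toy sketch below is not Jax2D code: `ToyScene`, `toy_step`, `MAX_BODIES` and the boolean `active` mask are hypothetical names chosen for illustration, and the "physics" is only gravity integrated with semi-implicit Euler.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

MAX_BODIES = 4  # hypothetical fixed capacity shared by every scene in the batch


class ToyScene(NamedTuple):
    position: jax.Array  # (MAX_BODIES, 2)
    velocity: jax.Array  # (MAX_BODIES, 2)
    active: jax.Array    # (MAX_BODIES,) bool: which slots hold a real body


def toy_step(scene: ToyScene, dt: float) -> ToyScene:
    """Semi-implicit Euler under gravity; inactive slots are left untouched."""
    gravity = jnp.array([0.0, -9.81])
    mask = scene.active[:, None]  # broadcast the mask over the x/y axis
    new_vel = jnp.where(mask, scene.velocity + dt * gravity, scene.velocity)
    new_pos = jnp.where(mask, scene.position + dt * new_vel, scene.position)
    return ToyScene(new_pos, new_vel, scene.active)


# Two different scenes with identical shapes: one has 1 active body, the other 3.
scenes = ToyScene(
    position=jnp.zeros((2, MAX_BODIES, 2)),
    velocity=jnp.zeros((2, MAX_BODIES, 2)),
    active=jnp.array([[True, False, False, False],
                      [True, True, True, False]]),
)

# vmap over the leading (scene) axis of every field; dt is shared by all scenes.
batched_step = jax.jit(jax.vmap(toy_step, in_axes=(0, None)))
next_scenes = batched_step(scenes, 0.1)
```

Because both scenes share one set of array shapes, changing which bodies a scene contains only flips mask entries, so the compiled function is reused without recompilation.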
When should I use Jax2D?
The main reason to use Jax2D over other JAX physics engines such as Brax or MJX is that Jax2D scenes are (largely) dynamically specified. However, Jax2D always has O(n^2) runtime with respect to the number of entities in a scene, since we must always calculate the full collision resolution for every pair of entities. This means it is usually not appropriate for simulating scenes with large numbers (>100) of entities.
In short: Jax2D excels at simulating lots of small and diverse scenes in parallel very fast.
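The quadratic cost mentioned above follows from testing every pair of entities: n entities give n(n-1)/2 distinct pairs. The sketch below shows the pattern for circles only, using broadcasting to build an n × n overlap matrix in JAX. It illustrates the scaling argument and is not Jax2D's collision code; the function name `pairwise_circle_overlap` is hypothetical.

```python
import jax
import jax.numpy as jnp


def pairwise_circle_overlap(centres: jax.Array, radii: jax.Array) -> jax.Array:
    """Return an (n, n) boolean matrix: True where circles i and j overlap.

    Every pair is tested, so work and memory grow as O(n^2) in the number of circles.
    """
    diff = centres[:, None, :] - centres[None, :, :]   # (n, n, 2) pairwise offsets
    dist = jnp.linalg.norm(diff, axis=-1)               # (n, n) pairwise distances
    touching = dist < radii[:, None] + radii[None, :]   # (n, n) overlap test
    # A circle trivially overlaps itself, so exclude the diagonal.
    return touching & ~jnp.eye(centres.shape[0], dtype=bool)


centres = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
radii = jnp.array([0.35, 0.35, 0.35])
overlap = jax.jit(pairwise_circle_overlap)(centres, radii)
```

Doubling the number of circles quadruples the size of `diff` and `dist`, which is why scenes with many entities become expensive even when most pairs are far apart.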
Example Usage
The example below shows how to use Jax2D to create and run a scene. For the full code, see examples/car.py; see also our docs for more details on how Jax2D works.
import jax
import jax.numpy as jnp
# Jax2D imports (PhysicsEngine, StaticSimParams, SimParams, create_empty_sim and the
# add_*_to_scene helpers) are omitted here; see examples/car.py for the exact import paths.
# Create engine with default parameters
static_sim_params = StaticSimParams()
sim_params = SimParams()
engine = PhysicsEngine(static_sim_params)
# Create scene
sim_state = create_empty_sim(static_sim_params, floor_offset=0.0)
# Create a rectangle for the car body
sim_state, (_, r_index) = add_rectangle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.0, 1.0]),
    dimensions=jnp.array([1.0, 0.4])
)
# Create circles for the wheels of the car
sim_state, (_, c1_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([1.5, 1.0]), radius=0.35
)
sim_state, (_, c2_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.5, 1.0]), radius=0.35
)
# Join the wheels to the car body with revolute joints
# Relative positions are measured from each object's centre of mass
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c1_index,
    a_relative_pos=jnp.array([-0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c2_index,
    a_relative_pos=jnp.array([0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)
# Add a triangle for a ramp; it is fixated so it can't move
triangle_vertices = jnp.array([[0.5, 0.1], [0.5, -0.1], [-0.5, -0.1]])
sim_state, _ = add_polygon_to_scene(
    sim_state,
    static_sim_params,
    position=jnp.array([2.7, 0.1]),
    vertices=triangle_vertices,
    n_vertices=3,
    fixated=True,
)
# Run scene
step_fn = jax.jit(engine.step)
# Activate all motors and thrusters (the action vector is constant, so build it once)
actions = jnp.ones(static_sim_params.num_joints + static_sim_params.num_thrusters)
while True:
    sim_state, _ = step_fn(sim_state, sim_params, actions)
    # Do rendering...
This produces the following scene (rendered with JaxGL)
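The `while True` loop above suits interactive rendering. For training or batch evaluation, a common JAX pattern is to roll out a fixed number of steps inside one compiled function with `jax.lax.scan`. The sketch below assumes only the calling convention used in the example, `step(state, params, actions)` returning a `(state, info)` pair; `rollout` and the scalar `toy_step` used to exercise it are hypothetical. With Jax2D you would presumably pass `engine.step`, `sim_params` and an action sequence of shape `(num_steps, num_joints + num_thrusters)`.

```python
from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp

# Assumed signature, matching how engine.step is called in the example above.
StepFn = Callable[[Any, Any, jax.Array], Tuple[Any, Any]]


def rollout(step: StepFn, state: Any, params: Any, action_seq: jax.Array):
    """Apply `step` once per row of `action_seq`; return (final_state, stacked_info)."""

    def body(carry, actions):
        # scan threads `carry` through the steps and stacks the second output.
        return step(carry, params, actions)

    return jax.lax.scan(body, state, action_seq)


# Toy stand-in for engine.step: the "state" is a scalar accumulating params * sum(actions).
def toy_step(state, params, actions):
    new_state = state + params * actions.sum()
    return new_state, new_state


action_seq = jnp.ones((5, 3))  # 5 steps, 3 actuators per step
# The step function is a Python callable, so it is marked static for jit.
final_state, trajectory = jax.jit(rollout, static_argnums=0)(
    toy_step, jnp.zeros(()), 2.0, action_seq
)
```

Keeping the whole rollout inside one `jit`-compiled `scan` avoids a Python-level dispatch per step.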
More Complex Levels
For creating and using more complicated levels, we recommend using the built-in editors provided in Kinetix (or the online version available here).
Installation
To use Jax2D in your work, you can install it from PyPI:
pip install jax2d
If you want to extend Jax2D, you can install it from source in editable mode:
git clone https://github.com/MichaelTMatthews/Jax2D
cd Jax2D
pip install -e ".[dev]"
pre-commit install
See Also
- 🍎 Box2D The original C physics engine
- 🤖 Kinetix Jax2D as a reinforcement learning environment
- 🌐 Kinetix.js Jax2D reimplemented in JavaScript, with a live demo here.
- 🦾 Brax 3D physics in JAX
- 🦿 MJX MuJoCo in JAX
- 👨‍💻 JaxGL Rendering in JAX
Citation
If you use Jax2D in your work, please cite it as follows:
@inproceedings{matthews2024kinetix,
title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks},
author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://arxiv.org/abs/2410.23208}
}
Acknowledgements
We would like to thank Erin Catto and Randy Gaul for their invaluable online materials that allowed the creation of this engine. If you would like to develop your own physics engine, we recommend starting here.
Download files
File details
Details for the file jax2d-1.0.2.tar.gz.
File metadata
- Download URL: jax2d-1.0.2.tar.gz
- Upload date:
- Size: 24.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cf3b9dce5845f177e31c0830e6afcbda4eb18f686a86f9e3dfec15f682c78133 |
| MD5 | e957479d129c6ccf298649a51ea387f1 |
| BLAKE2b-256 | 165b1046bdf4773207ba3ee4ba31c3077427bc6fb10389172bdbdba6771b428c |
File details
Details for the file jax2d-1.0.2-py3-none-any.whl.
File metadata
- Download URL: jax2d-1.0.2-py3-none-any.whl
- Upload date:
- Size: 23.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7eec20669461e6da4b3e5673a93376d48fa2806422361d07462fbbe36d0b27cb |
| MD5 | c3bbbf35b20cf61199e805139ac6c631 |
| BLAKE2b-256 | e01e621adb89a29cea85c479fefa5fa0110251dec11d046ac1905eb0d972ca9e |