A JAX 2D Rigid-Body Physics Engine

Jax2D

Jax2D is a 2D rigid-body physics engine written entirely in JAX and based on the Box2D engine. Unlike other JAX physics engines, Jax2D is dynamic with respect to scene configuration, allowing heterogeneous scenes to be parallelised with vmap. Jax2D was initially created as the backend of the Kinetix project and was developed by Michael Matthews and Michael Beukman.

When should I use Jax2D?

The main reason to use Jax2D over other JAX physics engines such as Brax or MJX is that Jax2D scenes are (largely) dynamically specified. However, Jax2D always has O(n^2) runtime in the number of entities in a scene, because it must compute the full collision resolution for every pair of entities. It is therefore usually not appropriate for simulating scenes with large numbers (>100) of entities.
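To make the quadratic scaling concrete, the short sketch below counts the unordered entity pairs in scenes of different sizes. It is only a proxy for the cost described above, not Jax2D's collision code; `num_entity_pairs` is a hypothetical helper introduced for illustration.

```python
from itertools import combinations


def num_entity_pairs(n: int) -> int:
    """Number of unordered pairs among n entities: n * (n - 1) / 2."""
    return n * (n - 1) // 2


# Cross-check the closed form against explicit enumeration for a small scene.
assert num_entity_pairs(5) == len(list(combinations(range(5), 2)))

# Ten times more entities means roughly a hundred times more pairs.
pair_counts = {n: num_entity_pairs(n) for n in (10, 100, 1000)}
print(pair_counts)  # {10: 45, 100: 4950, 1000: 499500}
```

Under this count, a scene with 100 entities already has 4950 pairs to resolve on every step, which is consistent with the guidance to keep scenes small.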

In short: Jax2D excels at simulating lots of small and diverse scenes in parallel very fast.
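One common way to let differently configured scenes share a single vmap call is to pad every scene to the same fixed capacity and mark unused slots with a mask. The sketch below illustrates that general idea on a toy integrator. It is not Jax2D's implementation, and `MAX_BODIES`, `step` and the `active` mask are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

MAX_BODIES = 4  # hypothetical fixed capacity shared by every scene


def step(positions, velocities, active, dt=0.1):
    """Toy integrator: advance only the active slots, leave padding untouched."""
    gravity = jnp.array([0.0, -9.81])
    new_vel = velocities + dt * gravity
    new_pos = positions + dt * new_vel
    mask = active[:, None]  # broadcast the per-body flag over x and y
    return jnp.where(mask, new_pos, positions), jnp.where(mask, new_vel, velocities)


# Two scenes with different numbers of bodies, padded to the same shape.
positions = jnp.zeros((2, MAX_BODIES, 2))
velocities = jnp.zeros((2, MAX_BODIES, 2))
active = jnp.array(
    [
        [True, False, False, False],  # scene 0: one body
        [True, True, True, False],  # scene 1: three bodies
    ]
)

# vmap maps the single-scene step over the leading (scene) axis.
batched_step = jax.jit(jax.vmap(step))
new_positions, new_velocities = batched_step(positions, velocities, active)
```

Because every scene has the same array shapes, the batch compiles once, and the mask decides which slots take part in each scene.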

Example Usage

The example below shows how to use Jax2D to create and run a scene. For the full code, see examples/car.py. Also see our docs for more details on how Jax2D works.

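import jax
import jax.numpy as jnp

# The Jax2D names used below (StaticSimParams, SimParams, PhysicsEngine,
# create_empty_sim and the add_*_to_scene helpers) are imported in examples/car.py

# Create engine with default parameters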
static_sim_params = StaticSimParams()
sim_params = SimParams()
engine = PhysicsEngine(static_sim_params)

# Create scene
sim_state = create_empty_sim(static_sim_params, floor_offset=0.0)

# Create a rectangle for the car body
sim_state, (_, r_index) = add_rectangle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.0, 1.0]),
    dimensions=jnp.array([1.0, 0.4])
)

# Create circles for the wheels of the car
sim_state, (_, c1_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([1.5, 1.0]), radius=0.35
)
sim_state, (_, c2_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.5, 1.0]), radius=0.35
)

# Join the wheels to the car body with revolute joints
# Relative positions are measured from each object's centre of mass
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c1_index,
    a_relative_pos=jnp.array([-0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c2_index,
    a_relative_pos=jnp.array([0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)

# Add a triangle for a ramp - we fixate the ramp so it can't move
triangle_vertices = jnp.array([[0.5, 0.1], [0.5, -0.1], [-0.5, -0.1]])
sim_state, _ = add_polygon_to_scene(
    sim_state,
    static_sim_params,
    position=jnp.array([2.7, 0.1]),
    vertices=triangle_vertices,
    n_vertices=3,
    fixated=True,
)


# Run scene
step_fn = jax.jit(engine.step)

while True:
    # We activate all motors and thrusters
    actions = jnp.ones(static_sim_params.num_joints + static_sim_params.num_thrusters)
    sim_state, _ = step_fn(sim_state, sim_params, actions)
    
    # Do rendering...

This produces the following scene (rendered with JaxGL)
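When rendering every step is not needed, a fixed-length rollout can be expressed with jax.lax.scan instead of a Python loop. The sketch below uses a hypothetical stand-in, `toy_step`, with the same call shape as the step function in the example above (state, params, actions in; new state plus an extra value out). Whether the real engine step can be dropped into scan unchanged is an assumption that the text above does not confirm.

```python
import jax
import jax.numpy as jnp


def toy_step(state, params, actions):
    """Hypothetical stand-in that mirrors the (state, params, actions) call shape."""
    new_state = state + params["dt"] * actions
    return new_state, None


def rollout(state, params, actions, num_steps):
    """Apply toy_step num_steps times and collect every intermediate state."""

    def body(carry, _):
        next_state, _ = toy_step(carry, params, actions)
        return next_state, next_state  # (new carry, per-step output)

    return jax.lax.scan(body, state, xs=None, length=num_steps)


# num_steps fixes the scan length, so it is marked static for jit.
final_state, trajectory = jax.jit(rollout, static_argnames="num_steps")(
    jnp.zeros(3), {"dt": 0.5}, jnp.ones(3), num_steps=4
)
```

Here `trajectory` stacks the state after each of the four steps, and `final_state` is the last of them.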

More Complex Levels

For creating and using more complicated levels, we recommend using the built-in editors provided in Kinetix (or the online version available here).

Installation

To use Jax2D in your work, you can install it from PyPI:

pip install jax2d

If you want to extend Jax2D, you can install it from source as follows:

git clone https://github.com/MichaelTMatthews/Jax2D
cd Jax2D
pip install -e ".[dev]"
pre-commit install

See Also

  • 🍎 Box2D: The original C physics engine
  • 🤖 Kinetix: Jax2D as a reinforcement learning environment
  • 🌐 Kinetix.js: Jax2D reimplemented in JavaScript, with a live demo here.
  • 🦾 Brax: 3D physics in JAX
  • 🦿 MJX: MuJoCo in JAX
  • 👨‍💻 JaxGL: Rendering in JAX

Citation

If you use Jax2D in your work, please cite it as follows:

@inproceedings{matthews2024kinetix,
      title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks},
      author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
      booktitle={The Thirteenth International Conference on Learning Representations},
      year={2025},
      url={https://arxiv.org/abs/2410.23208}
}

Acknowledgements

We would like to thank Erin Catto and Randy Gaul for their invaluable online materials that allowed the creation of this engine. If you would like to develop your own physics engine, we recommend starting here.
