A JAX 2D Rigid-Body Physics Engine

Jax2D

Jax2D is a 2D rigid-body physics engine written entirely in JAX and based on the Box2D engine. Unlike other JAX physics engines, Jax2D is dynamic with respect to scene configuration, allowing heterogeneous scenes to be parallelised with vmap. Jax2D was initially created as the backend of the Kinetix project and was developed by Michael Matthews and Michael Beukman.

Why should I use Jax2D?

The main reason to use Jax2D over other JAX physics engines such as Brax or MJX is that Jax2D scenes are (largely) dynamically specified, so scenes with different contents can be simulated in parallel with vmap. The trade-off is that Jax2D always has O(n^2) runtime with respect to the number of entities in a scene, since it must always calculate the full collision resolution for every pair of entities. As a result, it is usually not appropriate for simulating scenes with large numbers (>100) of entities.

In short: Jax2D excels at simulating lots of small and diverse scenes in parallel very fast.
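To make the trade-off above concrete, here is a minimal sketch of one common way to run differently sized scenes under a single vmap: pad every scene to a fixed capacity and carry an "active" mask. This illustrates the general idea only and is not a description of Jax2D's internals. The names MAX_CIRCLES, make_scene and count_overlaps are hypothetical and are not part of the Jax2D API.

```python
import jax
import jax.numpy as jnp

MAX_CIRCLES = 4  # hypothetical fixed capacity shared by every scene


def make_scene(positions, radii):
    """Pad a list of circles to MAX_CIRCLES slots and mark which slots are used."""
    n = len(radii)
    pos = jnp.zeros((MAX_CIRCLES, 2)).at[:n].set(jnp.asarray(positions, dtype=jnp.float32))
    rad = jnp.zeros(MAX_CIRCLES).at[:n].set(jnp.asarray(radii, dtype=jnp.float32))
    active = jnp.arange(MAX_CIRCLES) < n
    return {"pos": pos, "rad": rad, "active": active}


def count_overlaps(scene):
    """Test every pair of slots, then mask out pairs that involve an unused slot."""
    diff = scene["pos"][:, None, :] - scene["pos"][None, :, :]
    dist = jnp.linalg.norm(diff, axis=-1)  # shape (MAX_CIRCLES, MAX_CIRCLES)
    touching = dist < scene["rad"][:, None] + scene["rad"][None, :]
    both_active = scene["active"][:, None] & scene["active"][None, :]
    idx = jnp.arange(MAX_CIRCLES)
    upper = idx[:, None] < idx[None, :]  # count each unordered pair once
    return jnp.sum(touching & both_active & upper)


# Two scenes with different numbers of circles, padded to the same shape
scene_a = make_scene([[0.0, 0.0], [1.5, 0.0]], [1.0, 1.0])
scene_b = make_scene([[0.0, 0.0], [0.8, 0.0], [1.6, 0.0]], [0.5, 0.5, 0.5])

# Stack the scenes leaf by leaf and evaluate both in one vmapped call
batch = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), scene_a, scene_b)
counts = jax.vmap(count_overlaps)(batch)
print(counts)  # [1 2]
```

Because every scene has the same array shapes, the pairwise test always covers all MAX_CIRCLES * (MAX_CIRCLES - 1) / 2 slot pairs, whether or not the slots are in use. That is the price of being able to batch heterogeneous scenes, and it is why the cost grows quadratically with capacity.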

Example Usage

The example below shows how to use Jax2D to create and run a scene. For the full code, see examples/car.py.

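import jax
import jax.numpy as jnp

# Assumed import paths, following examples/car.py; verify against your installed version
from jax2d.engine import PhysicsEngine, create_empty_sim
from jax2d.scene import (
    add_circle_to_scene,
    add_polygon_to_scene,
    add_rectangle_to_scene,
    add_revolute_joint_to_scene,
)
from jax2d.sim_state import SimParams, StaticSimParams

# Create engine with default parameters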
static_sim_params = StaticSimParams()
sim_params = SimParams()
engine = PhysicsEngine(static_sim_params)

# Create scene
sim_state = create_empty_sim(static_sim_params, floor_offset=0.0)

# Create a rectangle for the car body
sim_state, (_, r_index) = add_rectangle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.0, 1.0]),
    dimensions=jnp.array([1.0, 0.4])
)

# Create circles for the wheels of the car
sim_state, (_, c1_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([1.5, 1.0]), radius=0.35
)
sim_state, (_, c2_index) = add_circle_to_scene(
    sim_state, static_sim_params, position=jnp.array([2.5, 1.0]), radius=0.35
)

# Join the wheels to the car body with revolute joints
# Relative positions are measured from the centre of mass of each object
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c1_index,
    a_relative_pos=jnp.array([-0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)
sim_state, _ = add_revolute_joint_to_scene(
    sim_state,
    static_sim_params,
    a_index=r_index,
    b_index=c2_index,
    a_relative_pos=jnp.array([0.5, 0.0]),
    b_relative_pos=jnp.zeros(2),
    motor_on=True,
)

# Add a triangle for a ramp - we fixate the ramp so it can't move
triangle_vertices = jnp.array([[0.5, 0.1], [0.5, -0.1], [-0.5, -0.1]])
sim_state, _ = add_polygon_to_scene(
    sim_state,
    static_sim_params,
    position=jnp.array([2.7, 0.1]),
    vertices=triangle_vertices,
    n_vertices=3,
    fixated=True,
)


# Run scene
step_fn = jax.jit(engine.step)

while True:
    # We activate all motors and thrusters
    actions = jnp.ones(static_sim_params.num_joints + static_sim_params.num_thrusters)
    sim_state, _ = step_fn(sim_state, sim_params, actions)
    
    # Do rendering...

This produces the following scene (rendered with JaxGL):

[Image: the rendered scene]
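The loop above steps the engine from Python, which is convenient when rendering every frame. For a fixed-length rollout without rendering, a common JAX pattern is to put the step function inside jax.lax.scan so that the whole rollout is compiled as one program. The sketch below is self-contained and uses a hypothetical toy_step as a stand-in; it only mimics the call shape step_fn(sim_state, sim_params, actions) -> (sim_state, aux) shown above and is not the Jax2D engine.

```python
import jax
import jax.numpy as jnp


def toy_step(state, params, actions):
    """Hypothetical stand-in for engine.step: returns (new_state, aux)."""
    velocity = state["velocity"] + params["dt"] * actions
    position = state["position"] + params["dt"] * velocity
    return {"position": position, "velocity": velocity}, None


def rollout(state, params, actions_per_step):
    """Apply the step function once per row of actions_per_step."""

    def body(carry, actions):
        new_state, _ = toy_step(carry, params, actions)
        # Carry the state forward and record the position at every step
        return new_state, new_state["position"]

    return jax.lax.scan(body, state, actions_per_step)


state = {"position": jnp.zeros(2), "velocity": jnp.zeros(2)}
params = {"dt": 0.5}
actions = jnp.ones((4, 2))  # four steps of constant actions

final_state, positions = jax.jit(rollout)(state, params, actions)
print(positions.shape)  # (4, 2)
```

If the same pattern is applied to the real engine, the toy state and parameters would be replaced by sim_state and sim_params, and each row of actions would need the length used in the example above (num_joints + num_thrusters).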

More Complex Levels

For creating and using more complicated levels, we recommend using the built-in editors provided in Kinetix.

Installation

To use Jax2D in your work, you can install it via PyPI:

pip install jax2d

If you want to extend Jax2D, you can install it from source as follows:

git clone https://github.com/MichaelTMatthews/Jax2D
cd Jax2D
pip install -e ".[dev]"
pre-commit install

See Also

  • 🍎 Box2D The original C physics engine
  • 🤖 Kinetix Jax2D as a reinforcement learning environment
  • 🌐 KinetixJS Jax2D reimplemented in JavaScript
  • 🦾 Brax 3D physics in JAX
  • 🦿 MJX MuJoCo in JAX
  • 👨‍💻 JaxGL Rendering in JAX

Citation

If you use Jax2D in your work, please cite it as follows:

@article{matthews2024kinetix,
      title={Kinetix: Investigating the Training of General Agents through Open-Ended Physics-Based Control Tasks}, 
      author={Michael Matthews and Michael Beukman and Chris Lu and Jakob Foerster},
      year={2024},
      eprint={2410.23208},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.23208}, 
}

Acknowledgements

We would like to thank Erin Catto and Randy Gaul for their invaluable online materials that allowed the creation of this engine. If you would like to develop your own physics engine, we recommend starting here.
