
A differentiable physics engine and multibody dynamics library for control and robot learning.


JaxSim

JaxSim is a differentiable physics engine and multibody dynamics library built with JAX, tailored for control and robotic learning applications.



Features

  • Reduced-coordinate physics engine for fixed-base and floating-base robots.
  • Multibody dynamics library for model-based control algorithms.
  • Fully Python-based, built on JAX and following a functional programming paradigm.
  • Seamless execution on CPUs, GPUs, and TPUs.
  • Supports JIT compilation and automatic vectorization for high performance.
  • Compatible with SDF models and URDF (via sdformat conversion).
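The JIT compilation and automatic vectorization mentioned above are provided by JAX itself. The minimal sketch below illustrates the pattern in plain JAX on a hypothetical pendulum; the names `pendulum_step`, `GRAVITY`, `LENGTH` and `DT` are illustrative and not part of JaxSim. The step is written as a pure function, compiled with `jax.jit`, and batched over many initial states with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy system (not part of JaxSim): a frictionless pendulum
# with state x = [theta, theta_dot], integrated with semi-implicit Euler.
GRAVITY = 9.81
LENGTH = 1.0
DT = 1e-3


def pendulum_step(x: jax.Array) -> jax.Array:
    """Advance the pendulum state by one time step (pure function)."""
    theta, theta_dot = x
    theta_ddot = -(GRAVITY / LENGTH) * jnp.sin(theta)
    theta_dot_next = theta_dot + DT * theta_ddot
    theta_next = theta + DT * theta_dot_next
    return jnp.array([theta_next, theta_dot_next])


# JIT-compile the single-system step.
step_jit = jax.jit(pendulum_step)
# Vectorize the same step over a batch of states, then JIT-compile it.
step_batched = jax.jit(jax.vmap(pendulum_step))

x0 = jnp.array([0.1, 0.0])
batch = jnp.stack([jnp.array([0.1 * i, 0.0]) for i in range(8)])  # shape (8, 2)

x1 = step_jit(x0)
batch1 = step_batched(batch)
```

Because the step has no side effects, the batched version is just the single-system function lifted over a leading batch axis.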

Usage

Using JaxSim as a simulator

import pathlib

import icub_models
import jax.numpy as jnp

import jaxsim.api as js

# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")

joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
          'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
          'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
          'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
          'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
          'r_ankle_roll')

# Build and reduce the model
model_description = pathlib.Path(model_path)

full_model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_description, time_step=0.0001, is_urdf=True
)

model = js.model.reduce(model=full_model, considered_joints=joints)

# Get the number of degrees of freedom
ndof = model.dofs()

# Initialize data and simulation
# By default, velocities are expressed in the mixed representation
data = js.data.JaxSimModelData.build(
    model=model, base_position=jnp.array([0.0, 0.0, 1.0])
)

T = jnp.arange(start=0, stop=1.0, step=model.time_step)

tau = jnp.zeros(ndof)

# Simulate
for _ in T:
    data = js.model.step(
        model=model, data=data, link_forces=None, joint_force_references=tau
    )
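The loop above calls js.model.step once per Python iteration, about 10,000 times for one second at this time step. A common JAX idiom is to express the whole rollout with `jax.lax.scan` and JIT-compile it once. The sketch below shows that idiom on a hypothetical point-mass step (`toy_step`, a stand-in for js.model.step, not a JaxSim function); whether js.model.step can be dropped into the same structure, with data as the scan carry, is an assumption to check against the JaxSim documentation.

```python
import jax
import jax.numpy as jnp

DT = 1e-4  # same time step as the example above


def toy_step(state: jax.Array, tau: jax.Array) -> jax.Array:
    """Hypothetical stand-in for js.model.step: a unit point mass under
    gravity plus an applied force, integrated with semi-implicit Euler."""
    pos, vel = state
    acc = tau - 9.81
    vel_next = vel + DT * acc
    pos_next = pos + DT * vel_next
    return jnp.array([pos_next, vel_next])


@jax.jit
def rollout(state0: jax.Array, taus: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Apply one step per entry of `taus`; return the final state and the trajectory."""

    def body(state, tau):
        next_state = toy_step(state, tau)
        return next_state, next_state  # (new carry, per-step output)

    return jax.lax.scan(body, state0, taus)


T = jnp.arange(start=0, stop=1.0, step=DT)
taus = jnp.zeros_like(T)  # no applied force, mirroring tau = 0 above
final_state, trajectory = rollout(jnp.array([1.0, 0.0]), taus)
```

The whole loop is traced and compiled once, instead of dispatching a separate call per step from Python.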

Using JaxSim as a multibody dynamics library

import pathlib

import icub_models
import jax.numpy as jnp

import jaxsim.api as js

# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")

joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
          'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
          'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
          'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
          'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
          'r_ankle_roll')

# Build and reduce the model
model_description = pathlib.Path(model_path)

full_model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_description, time_step=0.0001, is_urdf=True
)

model = js.model.reduce(model=full_model, considered_joints=joints)

# Initialize model data
data = js.data.JaxSimModelData.build(
    model=model,
    base_position=jnp.array([0.0, 0.0, 1.0]),
)

# Frame and dynamics computations
frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")

# Frame transformation
W_H_F = js.frame.transform(
    model=model, data=data, frame_index=frame_index
)

# Frame Jacobian
W_J_F = js.frame.jacobian(
    model=model, data=data, frame_index=frame_index
)

# Dynamics properties
M = js.model.free_floating_mass_matrix(model=model, data=data)  # Mass matrix
h = js.model.free_floating_bias_forces(model=model, data=data)  # Bias forces
g = js.model.free_floating_gravity_forces(model=model, data=data)  # Gravity forces
C = js.model.free_floating_coriolis_matrix(model=model, data=data)  # Coriolis matrix

# Print dynamics results
print(f"{M.shape=} \n{h.shape=} \n{g.shape=} \n{C.shape=}")
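For reference, these quantities enter the free-floating equations of motion. The notation below follows the standard convention from the rigid-body dynamics literature (e.g. the Featherstone book cited in the credits); the symbols B, J_L and f_L are introduced here for illustration, and the exact velocity representation used by JaxSim should be checked in its documentation.

```latex
\begin{aligned}
M(\mathbf{q})\,\dot{\boldsymbol{\nu}} + \mathbf{h}(\mathbf{q}, \boldsymbol{\nu})
  &= B\,\boldsymbol{\tau} + \sum_{L} J_L^\top(\mathbf{q})\,\mathbf{f}_L, \\
\mathbf{h}(\mathbf{q}, \boldsymbol{\nu})
  &= C(\mathbf{q}, \boldsymbol{\nu})\,\boldsymbol{\nu} + \mathbf{g}(\mathbf{q}),
\qquad
B = \begin{bmatrix} 0_{6 \times n} \\ I_{n} \end{bmatrix}
\end{aligned}
```

Here n = ndof and ν stacks the 6D base velocity and the n joint velocities, so M and C are (6+n)×(6+n) while h and g have 6+n entries; the bias forces h bundle the Coriolis/centrifugal term Cν and the gravity term g.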

Additional features

  • Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX.
  • Support for automatic differentiation with respect to kinematic and dynamic parameters.
  • All fixed-step integrators are forward and reverse differentiable.
  • Check the example folder for additional use cases!
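The sketch below illustrates what differentiating through a fixed-step integrator looks like in plain JAX. It uses a hypothetical mass-spring-damper integrated with explicit Euler (`final_position` is illustrative, not a JaxSim function, and JaxSim's own integrators may differ). The gradient is taken with respect to both a state quantity (the initial velocity) and a dynamic parameter (the spring stiffness), in reverse mode with `jax.grad` and in forward mode with `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp

DT = 1e-3
N_STEPS = 500


def final_position(params: jax.Array) -> jax.Array:
    """Hypothetical mass-spring-damper rolled out with explicit Euler.

    params = [initial_velocity, stiffness]; returns the final position.
    """
    v0, k = params
    mass, damping = 1.0, 0.1

    def body(state, _):
        x, v = state
        a = (-k * x - damping * v) / mass
        return (x + DT * v, v + DT * a), None

    init = (jnp.zeros_like(v0), v0)  # start at x = 0 with velocity v0
    (x_final, _), _ = jax.lax.scan(body, init, None, length=N_STEPS)
    return x_final


params = jnp.array([1.0, 4.0])
grad_reverse = jax.grad(final_position)(params)    # reverse mode
grad_forward = jax.jacfwd(final_position)(params)  # forward mode
```

Both modes differentiate the same unrolled computation, so they agree up to floating-point error; reverse mode is usually preferable with many inputs and a scalar output, forward mode with few inputs.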

Warning

This project is still experimental: APIs may change between releases without notice.

Note

JaxSim currently focuses on locomotion applications. Only contacts between bodies and smooth ground surfaces are supported.
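To make this contact assumption concrete, here is a minimal sketch of a compliant point-versus-flat-ground contact in plain JAX. It uses a generic linear spring-damper normal force, clipped so the ground can only push. It is not JaxSim's actual contact model (the supported models live in api/contact_model.py and rbda/contacts/); `ground_normal_force` and its default stiffness and damping values are hypothetical.

```python
import jax
import jax.numpy as jnp


def ground_normal_force(
    p: jax.Array, v: jax.Array, stiffness: float = 1e6, damping: float = 2e3
) -> jax.Array:
    """Hypothetical compliant contact between a point and the flat ground z = 0.

    p, v: world-frame position and linear velocity of the point, shape (3,).
    Returns the 3D contact force; only the normal (z) component is modeled.
    """
    penetration = jnp.maximum(0.0, -p[2])  # depth below the ground plane
    in_contact = penetration > 0.0
    # Linear spring-damper along the normal, active only while penetrating.
    fz = jnp.where(in_contact, stiffness * penetration - damping * v[2], 0.0)
    fz = jnp.maximum(fz, 0.0)  # unilateral: the ground can only push
    return jnp.array([0.0, 0.0, fz])


# Evaluate all collidable points at once, e.g. the four corners of a foot.
points = jnp.array([[0.1, 0.05, -0.001], [0.1, -0.05, -0.001],
                    [-0.1, 0.05, 0.002], [-0.1, -0.05, 0.002]])
velocities = jnp.zeros_like(points)
forces = jax.vmap(ground_normal_force)(points, velocities)
```

The two corners 1 mm below the plane receive a normal force, while the two above it receive none.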

Installation

With conda

You can install the project using conda as follows:

conda install jaxsim -c conda-forge

If needed, you can enforce GPU support by also specifying "jaxlib = * = *cuda*".

With pixi

Note

The minimum version of pixi required is 0.39.0.

You can add the jaxsim dependency to a pixi project as follows:

pixi add jaxsim

If you are on Linux and want to use a CUDA-powered version of JAX, add the corresponding entry to the system-requirements table. If you are using a pixi.toml file, add:

[system-requirements]
cuda = "12"

If you are using a pyproject.toml file, add:

[tool.pixi.system-requirements]
cuda = "12"

With pip

You can install the project using pypa/pip, preferably in a virtual environment, as follows:

pip install jaxsim

Check pyproject.toml for the complete list of optional dependencies. You can obtain a full installation using jaxsim[all].

If you need GPU support, follow the official installation instructions of JAX.
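After installation, a quick way to check whether JAX (and therefore JaxSim) will run on a GPU is to query the default backend. This sketch uses only standard JAX calls and does not depend on JaxSim.

```python
import jax

# Backend JAX selected by default (e.g. "cpu" or "gpu") and the devices it sees.
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
```

If this reports "cpu" on a machine with a CUDA GPU, the installed jaxlib is most likely a CPU-only build.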

Contributors installation (with conda)

If you want to contribute to the project, we recommend creating the following jaxsim conda environment first:

conda env create -f environment.yml

Then, activate the environment and install the project in editable mode:

conda activate jaxsim
pip install --no-deps -e .

Contributors installation (with pixi)

Note

The minimum version of pixi required is 0.39.0.

You can install the default dependencies of the project using pixi as follows:

pixi install

Run pixi task list to see the available tasks.

Documentation

The JaxSim API documentation is available at jaxsim.readthedocs.io.

Overview

Structure of the Python package
# tree -L 2 -I "__pycache__" -I "__init__*" -I "__main__*" src/jaxsim

src/jaxsim
|-- api..........................# Package containing the main functional APIs.
|   |-- actuation_model.py.......# |-- APIs for computing quantities related to the actuation model.
|   |-- common.py................# |-- Common utilities used in the current package.
|   |-- com.py...................# |-- APIs for computing quantities related to the center of mass.
|   |-- contact_model.py.........# |-- APIs for computing quantities related to the contact model.
|   |-- contact.py...............# |-- APIs for computing quantities related to the collidable points.
|   |-- data.py..................# |-- Class storing the data of a simulated model.
|   |-- frame.py.................# |-- APIs for computing quantities related to additional frames.
|   |-- integrators.py...........# |-- APIs for integrating the system dynamics.
|   |-- joint.py.................# |-- APIs for computing quantities related to the joints.
|   |-- kin_dyn_parameters.py....# |-- Class storing kinematic and dynamic parameters of a model.
|   |-- link.py..................# |-- APIs for computing quantities related to the links.
|   |-- model.py.................# |-- Class defining a simulated model and APIs for computing related quantities.
|   |-- ode.py...................# |-- APIs for computing quantities related to the system dynamics.
|   `-- references.py............# `-- Helper class to create references (link forces and joint torques).
|-- exceptions.py................# Module containing functions to raise exceptions from JIT-compiled functions.
|-- logging.py...................# Module containing logging utilities.
|-- math.........................# Package containing mathematical utilities.
|   |-- adjoint.py...............# |-- APIs for creating and manipulating 6D transformations.
|   |-- cross.py.................# |-- APIs for computing cross products of 6D quantities.
|   |-- inertia.py...............# |-- APIs for creating and manipulating 6D inertia matrices.
|   |-- joint_model.py...........# |-- APIs defining the supported joint model and the corresponding transformations.
|   |-- quaternion.py............# |-- APIs for creating and manipulating quaternions.
|   |-- rotation.py..............# |-- APIs for creating and manipulating rotation matrices.
|   |-- skew.py..................# |-- APIs for creating and manipulating skew-symmetric matrices.
|   |-- transform.py.............# |-- APIs for creating and manipulating homogeneous transformations.
|   `-- utils.py.................# `-- Common utilities used in the current package.
|-- mujoco.......................# Package containing utilities to interact with the MuJoCo passive viewer.
|   |-- loaders.py...............# |-- Utilities for converting JaxSim models to MuJoCo models.
|   |-- model.py.................# |-- Class providing high-level methods to compute quantities using MuJoCo.
|   `-- visualizer.py............# `-- Class that simplifies opening the passive viewer and recording videos.
|-- parsers......................# Package containing utilities to parse model descriptions (SDF and URDF models).
|   |-- descriptions/............# |-- Package containing the intermediate representation of a model description.
|   |-- kinematic_graph.py.......# |-- Definition of the kinematic graph associated with a parsed model description.
|   `-- rod/.....................# `-- Package to create the intermediate representation from model descriptions using ROD.
|-- rbda.........................# Package containing the low-level rigid body dynamics algorithms.
|   |-- aba.py...................# |-- The Articulated Body Algorithm.
|   |-- collidable_points.py.....# |-- Kinematics of collidable points.
|   |-- contacts/................# |-- Package containing the supported contact models.
|   |-- crba.py..................# |-- The Composite Rigid Body Algorithm.
|   |-- forward_kinematics.py....# |-- Forward kinematics of the model.
|   |-- jacobian.py..............# |-- Full Jacobian and full Jacobian derivative.
|   |-- rnea.py..................# |-- The Recursive Newton-Euler Algorithm.
|   `-- utils.py.................# `-- Common utilities used in the current package.
|-- terrain......................# Package containing resources to specify the terrain.
|   `-- terrain.py...............# `-- Classes defining the supported terrains.
|-- typing.py....................# Module containing type hints.
`-- utils........................# Package of common utilities.
    |-- jaxsim_dataclass.py......# |-- Utilities to operate on pytree dataclasses.
    |-- tracing.py...............# |-- Utilities to use when JAX is tracing functions.
    `-- wrappers.py..............# `-- Utilities to wrap objects for specific use cases on pytree dataclass attributes.

Credits

The RBDAs are based on the theory of the Rigid Body Dynamics Algorithms book by Roy Featherstone. The algorithms and some simulation features were inspired by its accompanying code.

The development of JaxSim started in late 2021, inspired by early versions of google/brax. At that time, Brax was implemented in maximal coordinates, and we wanted a physics engine in reduced coordinates. We are grateful to the Brax team for their work and for showing the potential of JAX in this field.

Brax v2 was later reimplemented in reduced coordinates, following an approach comparable to JaxSim's. Development then shifted to MJX, which provides a JAX-based implementation of the MuJoCo APIs.

The main differences between MJX/Brax and JaxSim are as follows:

  • JaxSim supports, out of the box, all SDF models with Pose Frame Semantics.
  • JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.

Contributing

We welcome contributions from the community. Please read the contributing guide to get started.

Citing

@software{ferigo_jaxsim_2022,
  author = {Diego Ferigo and Filippo Luca Ferretti and Silvio Traversaro and Daniele Pucci},
  title = {{JaxSim}: A Differentiable Physics Engine and Multibody Dynamics Library for Control and Robot Learning},
  url = {http://github.com/ami-iit/jaxsim},
  year = {2022},
}

Theoretical aspects of JaxSim are based on Chapters 7 and 8 of the following Ph.D. thesis:

@phdthesis{ferigo_phd_thesis_2022,
  title = {Simulation Architectures for Reinforcement Learning applied to Robotics},
  author = {Diego Ferigo},
  school = {University of Manchester},
  type = {PhD Thesis},
  month = {July},
  year = {2022},
}


License

BSD3
