JaxADi is a Python library that bridges the gap between `casadi.Function` and JAX-compatible functions. By combining the strengths of CasADi and JAX, JaxADi makes it possible to build efficient, batchable code that runs on CPUs, GPUs, and TPUs.

JAXADI can be particularly useful in scenarios involving:

  • Robotics simulations
  • Optimal control problems
  • Machine learning models with complex dynamics

See the quick tutorial to get up to speed.

Installation

You can install JAXADI using pip:

pip install jaxadi

For a complete environment setup for examples, we recommend using Conda/Mamba:

mamba env create -f environment.yml

Usage

JAXADI provides a simple and intuitive API:

import casadi as cs
import numpy as np
from jaxadi import translate, convert
from jax import numpy as jnp

x = cs.SX.sym("x", 2, 2)
y = cs.SX.sym("y", 2, 2)
# Define a complex nonlinear function
z = x @ y  # Matrix multiplication
z_squared = z * z  # Element-wise squaring
z_sin = cs.sin(z)  # Element-wise sine
result = z_squared + z_sin  # Element-wise addition
# Create the CasADi function
casadi_fn = cs.Function("complex_nonlinear_func", [x, y], [result])
# Get JAX-compatible function string representation
jax_fn_string = translate(casadi_fn)
print(jax_fn_string)
# Define JAX function from CasADi one
jax_fn = convert(casadi_fn, compile=True)
# Run compiled function
input_x = jnp.array(np.random.rand(2, 2))
input_y = jnp.array(np.random.rand(2, 2))
output = jax_fn(input_x, input_y)
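
The batching mentioned above can be sketched with `jax.vmap`. To stay runnable without CasADi installed, this sketch uses a hand-written JAX function, `complex_nonlinear_func`, as a stand-in for the `jax_fn` returned by `convert`; it is not the code JAXADI generates. It assumes the converted function composes with `jax.vmap` and `jax.jit` like any ordinary JAX function.

```python
import jax
import jax.numpy as jnp


def complex_nonlinear_func(x, y):
    """Hand-written stand-in for the converted CasADi function (assumption)."""
    z = x @ y  # Matrix multiplication
    return z * z + jnp.sin(z)  # Element-wise square plus element-wise sine


# Map over a leading batch axis of both inputs, then compile the result.
batched_fn = jax.jit(jax.vmap(complex_nonlinear_func))

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
batch_x = jax.random.uniform(key_x, (1024, 2, 2))
batch_y = jax.random.uniform(key_y, (1024, 2, 2))

batch_out = batched_fn(batch_x, batch_y)
print(batch_out.shape)
```

Each slice `batch_out[i]` is the result of evaluating the function on `batch_x[i]` and `batch_y[i]`, so a single call replaces a Python loop over the batch.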

Note: Translation does not currently support functions with a very large number of operations. The limitation comes from the translation implementation: its key component is work-tree expansion, which can lead to a large overhead in the number of symbols. We are working on a compromise between speed and support for large functions.
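
To give a feel for why expansion can inflate the number of symbols, here is a toy illustration in plain Python. It is not JAXADI's translation code, and the helper names `expand_as_tree` and `count_leaves` are hypothetical. It assumes only that expansion inlines reused intermediate results instead of sharing them.

```python
def expand_as_tree(num_steps):
    """Inline the recurrence v <- v * v + v into one fully expanded expression."""
    expr = "x"
    for _ in range(num_steps):
        # The previous expression is copied three times instead of being reused.
        expr = f"(({expr}*{expr})+{expr})"
    return expr


def count_leaves(expr):
    """Count how many times the input symbol appears in the expanded expression."""
    return expr.count("x")


for steps in (1, 5, 10):
    shared_ops = 2 * steps  # one multiply and one add per step when results are shared
    tree_leaves = count_leaves(expand_as_tree(steps))
    print(f"steps={steps} shared_ops={shared_ops} expanded_leaves={tree_leaves}")
```

With shared intermediates the operation count grows linearly in the number of steps, while the fully expanded expression triples in size at every step.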

Examples

JAXADI comes with several examples to help you get started:

  1. Basic Translation: Learn how to translate CasADi functions to JAX.

  2. Lowering Operations: Understand the lowering process in JaxADi.

  3. Function Conversion: See how to fully convert CasADi functions to JAX.

  4. Pendulum Rollout: Batched rollout of a passive nonlinear pendulum.

  5. Pinocchio Integration: Explore how to convert Pinocchio-based CasADi functions to JAX.

  6. MJX Comparison: Compare the transformed Pinocchio forward kinematics with the one provided by MuJoCo MJX.

Note: To run the Pinocchio and MJX examples, ensure you have them properly installed in your environment.
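
As a rough idea of what a batched pendulum rollout looks like in JAX, here is a minimal sketch. It is not the repository's example: the dynamics are written by hand in JAX rather than converted from CasADi, and the constants, the semi-implicit Euler integrator, and the names `pendulum_step` and `rollout` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

GRAVITY = 9.81  # m/s^2 (assumed)
LENGTH = 1.0    # pendulum length in meters (assumed)
DT = 0.01       # integration step in seconds (assumed)


def pendulum_step(state, _):
    """One semi-implicit Euler step of a passive (unactuated) pendulum."""
    theta, omega = state[0], state[1]
    omega_next = omega - DT * (GRAVITY / LENGTH) * jnp.sin(theta)
    theta_next = theta + DT * omega_next
    next_state = jnp.stack([theta_next, omega_next])
    return next_state, next_state  # (carry, value stored in the trajectory)


def rollout(initial_state, num_steps=100):
    """Roll the dynamics forward from a single initial state."""
    _, trajectory = jax.lax.scan(pendulum_step, initial_state, None, length=num_steps)
    return trajectory


# vmap turns the single-trajectory rollout into a batched one.
batched_rollout = jax.jit(jax.vmap(rollout))

initial_angles = jnp.linspace(-1.0, 1.0, 8)
initial_states = jnp.stack([initial_angles, jnp.zeros(8)], axis=1)
trajectories = batched_rollout(initial_states)
print(trajectories.shape)  # (batch, time step, state)
```

In the actual workflow, the step function would come from a CasADi model converted with `convert`, with the same `scan` and `vmap` structure wrapped around it.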

Performance Benchmarks

The process of benchmarking and evaluating the performance of Jaxadi is described in the benchmarks directory.
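
For quick local measurements, a generic timing sketch may help. This is not the procedure from the benchmarks directory; `time_call` is a hypothetical helper, and the timed function is a hand-written JAX stand-in rather than a converted CasADi function. The sketch warms up once so compilation is excluded, and calls `block_until_ready()` so that JAX's asynchronous dispatch does not hide the actual compute time.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, repeats=10):
    """Return the mean wall-clock seconds per call, excluding compilation."""
    fn(*args).block_until_ready()  # Warm-up call triggers JIT compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # Wait for the device to finish
    return (time.perf_counter() - start) / repeats


@jax.jit
def stand_in_fn(x, y):
    """Hand-written stand-in for a converted function (assumption)."""
    z = x @ y
    return z * z + jnp.sin(z)


x = jnp.ones((2, 2))
y = jnp.ones((2, 2))
mean_seconds = time_call(stand_in_fn, x, y)
print(f"mean time per call: {mean_seconds:.2e} s")
```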

Contributing

We welcome contributions! Please see our Contributing Guide for more details.

Citation

If you use JaxADi in your research, please cite it as follows:

@misc{jaxadi2024,
  title = {JaxADi: Bridging CasADi and JAX for Efficient Numerical Computing},
  author = {Alentev, Igor and Kozlov, Lev and Nedelchev, Simeon},
  year = {2024},
  url = {https://github.com/based-robotics/jaxadi},
  note = {Accessed: [Insert Access Date]}
}

Acknowledgements

This project draws inspiration from cusadi, with a focus on simplicity and JAX integration.

Contact

For questions, issues, or suggestions, please open an issue on our GitHub repository.

We hope JAXADI empowers your numerical computing and optimization tasks! Happy coding!
