
**JaxADi** is a powerful Python library designed to bridge the gap between `casadi.Function` and JAX-compatible functions.



By combining CasADi's symbolic modeling with JAX's compilation and vectorization, JaxADi lets you build efficient, batchable code that runs on CPUs, GPUs, and TPUs.
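As a rough sketch of what "batchable" means in practice, the snippet below uses a hand-written JAX function as a stand-in for a converted CasADi function. The pendulum model `pendulum_dynamics` and its parameters are hypothetical and not part of JaxADi. Because the function is plain JAX, `jax.vmap` can evaluate it over a whole batch of states in one call, and `jax.jit` compiles that batched call for whichever backend is available.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a function produced by JaxADi: a damped pendulum,
# state = [angle, angular velocity], u = applied torque.
def pendulum_dynamics(state, u, g=9.81, length=1.0, damping=0.1):
    theta, omega = state[0], state[1]
    theta_dot = omega
    omega_dot = -(g / length) * jnp.sin(theta) - damping * omega + u
    return jnp.stack([theta_dot, omega_dot])


# vmap maps over the leading axis of `state` and `u`; jit compiles the batched
# function for the default backend (CPU, GPU, or TPU).
batched_dynamics = jax.jit(jax.vmap(pendulum_dynamics))

states = jnp.stack([jnp.linspace(-1.0, 1.0, 1024), jnp.zeros(1024)], axis=1)  # (1024, 2)
torques = jnp.zeros(1024)  # (1024,)
derivs = batched_dynamics(states, torques)
print(derivs.shape)  # (1024, 2)
```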

JaxADi is particularly useful in scenarios involving:

  • Robotics simulations
  • Optimal control problems
  • Machine learning models with complex dynamics
  • Large-scale numerical optimizations

Installation

You can install JaxADi with pip:

```bash
pip install jaxadi
```

For a complete environment setup, we recommend using Conda/Mamba:

```bash
mamba env create -f environment.yml
```

Usage

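The example below uses two entry points: `translate`, which returns the generated JAX code as a string, and `convert`, which returns a callable JAX function (optionally compiled):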

```python
import casadi as cs
import numpy as np
from jax import numpy as jnp
from jaxadi import translate, convert

x = cs.SX.sym("x", 2, 2)
y = cs.SX.sym("y", 2, 2)

# Define a nonlinear function of two 2x2 matrices
z = x @ y                   # matrix multiplication
z_squared = z * z           # element-wise squaring
z_sin = cs.sin(z)           # element-wise sine
result = z_squared + z_sin  # element-wise addition

# Create the CasADi function
casadi_fn = cs.Function("complex_nonlinear_func", [x, y], [result])

# Get the generated JAX code as a string
jax_fn_string = translate(casadi_fn)
print(jax_fn_string)

# Convert the CasADi function into a JAX function; compile=True requests a compiled version
jax_fn = convert(casadi_fn, compile=True)

# Call the compiled function on random 2x2 inputs
input_x = jnp.array(np.random.rand(2, 2))
input_y = jnp.array(np.random.rand(2, 2))
output = jax_fn(input_x, input_y)
```
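When checking a translation, an independent reference is handy. The sketch below recomputes `complex_nonlinear_func` directly in NumPy; the name `reference_complex_nonlinear` is hypothetical and not part of JaxADi. It also includes a case that can be checked by hand: with `x = y = I`, the product is `I`, so the diagonal equals `1 + sin(1)` and the off-diagonal entries are `0`.

```python
import numpy as np


def reference_complex_nonlinear(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """NumPy re-implementation of complex_nonlinear_func from the usage example."""
    z = x @ y                 # matrix product, as in the CasADi definition
    return z * z + np.sin(z)  # element-wise square plus element-wise sine


rng = np.random.default_rng(0)
x = rng.random((2, 2))
y = rng.random((2, 2))
ref = reference_complex_nonlinear(x, y)

# Hand-checkable case: x = y = I gives 1 + sin(1) on the diagonal, 0 elsewhere.
eye = np.eye(2)
print(reference_complex_nonlinear(eye, eye))
```

To cross-check the converted function, pass `np.asarray(input_x)` and `np.asarray(input_y)` to the reference and compare against `output` with `np.allclose`. Whether `jax_fn` returns a bare array or wraps its single output in a list or tuple is an assumption to verify before comparing.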

Examples

JaxADi ships with several examples to help you get started:

  1. Basic Translation: Learn how to translate CasADi functions to JAX.

  2. Lowering Operations: Understand the lowering process in JaxADi.

  3. Function Conversion: See how to fully convert CasADi functions to JAX.

  4. Pinocchio Integration: Explore how to convert Pinocchio-based CasADi functions to JAX.

Note: To run the Pinocchio example, ensure you have Pinocchio properly installed in your environment.
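Assuming "lowering" is meant in JAX's usual sense (staging a traced function out to an intermediate representation before compilation), that step can be observed with JAX's own ahead-of-time API: `jax.jit(...).lower(...)`, `.as_text()`, and `.compile()`. The sketch uses a hand-written function `f` that mirrors the usage example; it does not call any JaxADi function, and the lowering example shipped with JaxADi may differ.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Same computation as complex_nonlinear_func in the usage example.
    z = x @ y
    return z * z + jnp.sin(z)


x = jnp.ones((2, 2))
y = jnp.ones((2, 2))

lowered = jax.jit(f).lower(x, y)  # trace and lower for these input shapes/dtypes
ir_text = lowered.as_text()       # textual IR (StableHLO in recent JAX versions)
compiled = lowered.compile()      # compile the lowered module for the current backend

print(ir_text[:300])
result = compiled(x, y)           # call the compiled executable directly
```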

Performance Benchmarks

(Consider adding a section about performance comparisons between CasADi and JaxADi-translated functions)
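Until such a comparison is written up, one could start from a small timing harness like the sketch below. It is not an official benchmark: `time_call` and `_wait` are hypothetical helpers. The harness reports the median wall-clock time over repeated calls and runs a few warm-up calls first so one-off JIT compilation is excluded. Because JAX dispatches work asynchronously, it waits on every output leaf that has a `block_until_ready` method before stopping the clock; other outputs (for example, CasADi `DM` results) are left as-is.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def _wait(out):
    # JAX dispatches asynchronously: block on every array in the output pytree.
    # Leaves without block_until_ready (e.g. NumPy arrays, CasADi DM) are left alone.
    for leaf in jax.tree_util.tree_leaves(out):
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()
    return out


def time_call(fn, *args, warmup=3, repeats=20):
    """Return the median wall-clock time in seconds of fn(*args)."""
    for _ in range(warmup):  # warm-up calls absorb one-off JIT compilation
        _wait(fn(*args))
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        _wait(fn(*args))
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


@jax.jit
def f(x, y):
    z = x @ y
    return z * z + jnp.sin(z)


x = jnp.ones((2, 2))
y = jnp.ones((2, 2))
print(f"jitted JAX function: {time_call(f, x, y) * 1e6:.1f} us per call")
```

The same helper could time the original `casadi_fn` on NumPy inputs and the converted `jax_fn` on JAX arrays, giving a like-for-like comparison on the same machine.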

Acknowledgements

This project draws inspiration from cusadi, with a focus on simplicity and JAX integration.

Contact

For questions, issues, or suggestions, please open an issue on our GitHub repository.

We hope JaxADi helps with your numerical computing and optimization tasks. Happy coding!



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

jaxadi-0.2.0-py3-none-any.whl (8.7 kB)

Uploaded Python 3

File details

Details for the file jaxadi-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: jaxadi-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 8.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.5

File hashes

Hashes for jaxadi-0.2.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8f3bd7869a2867ccc44999b1a682329ec5807aa449c4c77179c987094a5f5d04 |
| MD5 | 6911b3024156a852d82835cc4abaf8be |
| BLAKE2b-256 | 958b9fc64eb24cf79136b1555f637ad220650e351defe3871c73af95f3a26f4b |

