
Scipy spatial API for JAX


jax-scipy-spatial

This package implements the scipy.spatial API for JAX.

Currently, the following classes are implemented:

  • scipy.spatial.transform.Rotation
  • scipy.spatial.transform.Slerp
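In scipy, Slerp interpolates between key rotations with constant angular velocity. For just two unit quaternions this reduces to the textbook slerp formula. The following is a minimal sketch of that formula only; slerp_quat is a hypothetical helper, not this package's implementation. It assumes scalar-last (x, y, z, w) quaternions, the layout returned by scipy's Rotation.as_quat.

```python
import jax.numpy as jnp

def slerp_quat(q0, q1, t):
    # Hypothetical helper: spherical linear interpolation between unit quaternions
    # q0 and q1 (scalar-last layout) at fraction t in [0, 1].
    dot = jnp.dot(q0, q1)
    # q and -q encode the same rotation; flip q1 so we follow the shorter arc.
    q1 = jnp.where(dot < 0.0, -q1, q1)
    dot = jnp.abs(dot)
    theta = jnp.arccos(jnp.clip(dot, 0.0, 1.0))
    sin_theta = jnp.sin(theta)
    # Fall back to linear weights when the quaternions are (nearly) identical,
    # using a safe denominator so no division by zero is ever evaluated.
    small = sin_theta < 1e-6
    safe_sin = jnp.where(small, 1.0, sin_theta)
    w0 = jnp.where(small, 1.0 - t, jnp.sin((1.0 - t) * theta) / safe_sin)
    w1 = jnp.where(small, t, jnp.sin(t * theta) / safe_sin)
    out = w0 * q0 + w1 * q1
    return out / jnp.linalg.norm(out)

identity = jnp.array([0.0, 0.0, 0.0, 1.0])
half_turn_z = jnp.array([0.0, 0.0, 1.0, 0.0])  # 180 degrees about z
midway = slerp_quat(identity, half_turn_z, 0.5)  # 90 degrees about z
```

Halfway between the identity and a half turn about z is a quarter turn about z, i.e. approximately (0, 0, 0.7071, 0.7071).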

Note that much of scipy.spatial may be too difficult to implement properly in JAX (e.g. nearest-neighbor search). We request that any submissions to this repository be fully compatible with both jax.vmap and jax.grad.
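To illustrate what the vmap and grad requirement means in practice, here is a minimal sketch using a hypothetical helper rotate_z (not part of this package). Because it is written purely with jax.numpy operations, it can be batched with jax.vmap and differentiated with jax.grad without modification.

```python
import jax
import jax.numpy as jnp

def rotate_z(angle, vec):
    # Hypothetical helper: rotate a 3-vector about the z axis by `angle` radians.
    c, s = jnp.cos(angle), jnp.sin(angle)
    zero, one = jnp.zeros_like(c), jnp.ones_like(c)
    matrix = jnp.stack([
        jnp.stack([c, -s, zero]),
        jnp.stack([s, c, zero]),
        jnp.stack([zero, zero, one]),
    ])
    return matrix @ vec

vec = jnp.array([1.0, 0.0, 0.0])
angles = jnp.linspace(0.0, jnp.pi, 4)

# vmap: map over a batch of angles while sharing the same input vector.
batched = jax.vmap(rotate_z, in_axes=(0, None))(angles, vec)  # shape (4, 3)

# grad: derivative of the rotated x component with respect to the angle,
# which analytically equals -sin(angle).
dx = jax.grad(lambda a: rotate_z(a, vec)[0])(0.5)
```

Code built from Python control flow on traced values, or from non-JAX libraries, would typically fail under one or both of these transformations, which is why the guideline asks contributors to check both.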

Install

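From a checkout of the source repository:

pip install .

The released package can also be installed from PyPI:

pip install jax-scipy-spatial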

Usage

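import jax.numpy as jnp
import jax_scipy_spatial.transform as jtr

# Same call signature as scipy.spatial.transform.Rotation.from_euler:
# extrinsic x-y-z Euler angles, here a 180-degree rotation about the z axis.
rotation = jtr.Rotation.from_euler('xyz', jnp.array([0., 0., 180.]), degrees=True)
print(rotation)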
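In scipy, the lowercase sequence 'xyz' means extrinsic rotations about the fixed x, y, then z axes, so the quaternions compose as q = qz * qy * qx. The following minimal sketch shows that standard construction in plain JAX; euler_xyz_to_quat and its helpers are hypothetical and are not this package's internals. It uses the scalar-last (x, y, z, w) layout of scipy's Rotation.as_quat, for which the Usage example above corresponds to approximately (0, 0, 1, 0).

```python
import jax
import jax.numpy as jnp

def _axis_quat(axis_index, angle):
    # Quaternion (x, y, z, w) for a rotation of `angle` radians about one coordinate axis.
    half = 0.5 * angle
    vec = jnp.zeros(3).at[axis_index].set(jnp.sin(half))
    return jnp.concatenate([vec, jnp.cos(half)[None]])

def _quat_mul(p, q):
    # Hamilton product p * q in scalar-last (x, y, z, w) layout.
    px, py, pz, pw = p
    qx, qy, qz, qw = q
    return jnp.stack([
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
        pw * qw - px * qx - py * qy - pz * qz,
    ])

def euler_xyz_to_quat(angles_deg):
    # Hypothetical helper: extrinsic 'xyz' Euler angles in degrees -> unit quaternion.
    # Rotating about x first, then y, then z composes as q = qz * qy * qx.
    a = jnp.deg2rad(angles_deg)
    qx = _axis_quat(0, a[0])
    qy = _axis_quat(1, a[1])
    qz = _axis_quat(2, a[2])
    return _quat_mul(qz, _quat_mul(qy, qx))

q = euler_xyz_to_quat(jnp.array([0.0, 0.0, 180.0]))
# Because the sketch is pure jax.numpy, it also batches with jax.vmap.
batch = jax.vmap(euler_xyz_to_quat)(jnp.array([[0.0, 0.0, 180.0], [90.0, 0.0, 0.0]]))
```

The w component of q is not exactly zero in float32 (cos of the rounded half-angle is about -4e-8), so comparisons should use a tolerance.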

Documentation

Please refer to the SciPy documentation for scipy.spatial.transform, whose API this package follows.



Download files


Source Distribution

jax-scipy-spatial-0.1.0.tar.gz (15.8 kB)

Built Distribution

jax_scipy_spatial-0.1.0-py3-none-any.whl (12.6 kB, Python 3)
