SciPy spatial API for JAX
jax-scipy-spatial
This package implements the scipy.spatial API for JAX.
Currently, the following functions/classes are implemented:
scipy.spatial.transform.Rotation
scipy.spatial.transform.Slerp
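SciPy's Slerp interpolates between key rotations along the shortest great-circle arc on the unit-quaternion sphere. To illustrate the underlying math (this is not this package's implementation), here is a minimal sketch in plain jax.numpy for two unit quaternions. The helper name `quat_slerp` is hypothetical, and quaternions are assumed to be scalar-last (x, y, z, w), matching SciPy's default.

```python
import jax
import jax.numpy as jnp


def quat_slerp(q0, q1, t):
    """Hypothetical helper: slerp between unit quaternions q0 and q1 (x, y, z, w), t in [0, 1]."""
    dot = jnp.dot(q0, q1)
    # q and -q encode the same rotation; flip q1 so we follow the shorter arc.
    q1 = jnp.where(dot < 0.0, -q1, q1)
    # Clamp to guard arccos against round-off slightly above 1.
    theta = jnp.arccos(jnp.clip(jnp.abs(dot), 0.0, 1.0))
    sin_theta = jnp.sin(theta)
    # For nearly identical quaternions, fall back to linear weights
    # and avoid dividing by a vanishing sin(theta).
    small = sin_theta < 1e-6
    safe_sin = jnp.where(small, 1.0, sin_theta)
    w0 = jnp.where(small, 1.0 - t, jnp.sin((1.0 - t) * theta) / safe_sin)
    w1 = jnp.where(small, t, jnp.sin(t * theta) / safe_sin)
    q = w0 * q0 + w1 * q1
    return q / jnp.linalg.norm(q)


# Usage: halfway from the identity to a 180-degree turn about z
# is a 90-degree turn about z, i.e. (0, 0, sin(pi/4), cos(pi/4)).
q_identity = jnp.array([0.0, 0.0, 0.0, 1.0])
q_z180 = jnp.array([0.0, 0.0, 1.0, 0.0])
ts = jnp.array([0.0, 0.5, 1.0])
# vmap over the interpolation times only; the endpoints are shared.
path = jax.vmap(quat_slerp, in_axes=(None, None, 0))(q_identity, q_z180, ts)
```

The sketch is written with `jnp.where` rather than Python `if` statements, so it stays traceable under `jax.vmap`. It makes no claim about gradient behavior when the two quaternions coincide.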
Note that much of scipy.spatial may be too difficult to implement properly in JAX (e.g. nearest-neighbor search). We request that any submissions to this repo be fully compatible with both vmap and grad.
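The vmap/grad requirement can be checked with a small test along the lines of the sketch below. The function `rotz_matrix` is a hypothetical stand-in for a contributed function and is not part of this package. The check is simply that `jax.vmap` batches it, `jax.grad` differentiates through it, and the two compose.

```python
import jax
import jax.numpy as jnp


def rotz_matrix(angle):
    """Hypothetical example: 3x3 matrix for a rotation of `angle` radians about z."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    zero, one = jnp.zeros_like(angle), jnp.ones_like(angle)
    # Build with jnp.stack so the function stays traceable under vmap/grad.
    return jnp.stack([
        jnp.stack([c, -s, zero]),
        jnp.stack([s, c, zero]),
        jnp.stack([zero, zero, one]),
    ])


# vmap: a batch of N angles should map to a batch of N matrices, shape (N, 3, 3).
angles = jnp.linspace(0.0, jnp.pi, 5)
mats = jax.vmap(rotz_matrix)(angles)


# grad: differentiate a scalar function of the angle. The x-component of the
# rotated unit x-vector is cos(angle), so its derivative is -sin(angle).
def x_component(angle):
    return (rotz_matrix(angle) @ jnp.array([1.0, 0.0, 0.0]))[0]


dx = jax.grad(x_component)(0.5)

# Composition: per-example gradients over the whole batch.
batched_grads = jax.vmap(jax.grad(x_component))(angles)
```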
Install
pip install .
Usage
import jax.numpy as jnp
import jax_scipy_spatial.transform as jtr

# Extrinsic x-y-z Euler angles in degrees; here a 180-degree rotation about z.
rotation = jtr.Rotation.from_euler('xyz', jnp.array([0., 0., 180.]), degrees=True)
print(rotation)
Documentation
Please refer to the SciPy documentation for scipy.spatial.transform.Rotation and scipy.spatial.transform.Slerp.
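Because the API follows SciPy's, SciPy itself can serve as a reference when checking results. The sketch below repeats the Usage call with scipy.spatial.transform.Rotation (assuming SciPy 1.4 or later, where `as_matrix` is available) and shows the expected values: a 180-degree rotation about z negates the x and y axes.

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Same call as in the Usage section, using SciPy as the reference implementation.
ref = Rotation.from_euler('xyz', [0.0, 0.0, 180.0], degrees=True)

# Rotation matrix: approximately diag(-1, -1, 1) up to floating-point round-off.
matrix = ref.as_matrix()

# Quaternion in SciPy's scalar-last (x, y, z, w) order: approximately (0, 0, 1, 0).
quat = ref.as_quat()
```

One way to test a port is to compare the JAX results against these values, for example with `np.allclose`. This README does not state whether every SciPy method is available on the JAX classes.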
Download files
Source Distribution
jax-scipy-spatial-0.1.0.tar.gz (15.8 kB)
Built Distribution
jax_scipy_spatial-0.1.0-py3-none-any.whl
Hashes for jax_scipy_spatial-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | ff0716bd466fed74a23e4be60620ef9afb0be4ad5bebfce57eb44c856ee7047c
MD5 | d703595189a1c04fd5f17ae3d51b7fd9
BLAKE2b-256 | 9690f524a19016892e28f74e58a52f8fb04076978828ef152ba3df3835ca57e5