A JAX wrapper around ceviche to make interoperability easier. In the future it might make sense to update ceviche itself to use JAX internally.
javiche
A small package that makes it easy to use ceviche with a JAX optimizer.
Install
This package is not yet published. Once it is, install it with:
pip install javiche
or
conda install javiche
How to use
Import the decorator:
from javiche import jaxit
Decorate your function (it will be differentiated using ceviche's jacobian, i.e. HIPS autograd):
@jaxit()
def square(A):
    """squares number/array"""
    return A**2
Now you can use jax as usual:
import jax

grad_fn = jax.grad(square)
grad_fn(2.0)
Array(4., dtype=float32, weak_type=True)
In this toy example, the same was already possible without the jaxit() decorator. However, jaxit()-decorated functions can contain autograd operators (but not jax operators):
import autograd.numpy as npa

# Not decorated with jaxit(): jax.grad passes a JAX tracer straight into autograd's numpy.
def sin(A):
    """computes sin of number/array using autograd's numpy"""
    return npa.sin(A)

grad_sin = jax.grad(sin)
try:
    print(grad_sin(0.0))
except Exception as e:
    print(e)
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
primal = 0.0
tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[], weak_type=True), None)
recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
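The error above is raised by JAX, not by autograd: jax.grad calls the function with a tracer, autograd.numpy forwards anything that is not one of its own boxes to plain NumPy, and NumPy then tries to convert the tracer into an ndarray. A minimal way to reproduce the same failure with plain NumPy is sketched below; `numpy_sin` and `grad_error_name` are hypothetical helpers for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_sin(A):
    """Plain NumPy sin: fine on concrete values, opaque to JAX tracing."""
    return np.sin(A)


def grad_error_name(fn, x):
    """Hypothetical helper: name of the conversion error jax.grad hits, else None."""
    try:
        jax.grad(fn)(x)
    except jax.errors.TracerArrayConversionError as err:
        return type(err).__name__
    return None


print(numpy_sin(0.0))                    # 0.0: a concrete Python float is fine
print(grad_error_name(numpy_sin, 0.0))   # TracerArrayConversionError
print(grad_error_name(jnp.sin, 0.0))     # None: jax.numpy ops can be traced
```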
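With the jaxit() decorator, the same kind of autograd-based function can be differentiated by JAX:
@jaxit()
def cos(A):
    """computes cos of number/array using autograd's numpy"""
    return npa.cos(A)

grad_cos = jax.grad(cos)
try:
    print(grad_cos(0.0))
except Exception as e:
    print(e)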
-0.0
Use case
This library is intended for use with ceviche while running a JAX optimization stack, as demonstrated in the inverse design example.
File details
Details for the file javiche-0.0.6.tar.gz.
File metadata
- Download URL: javiche-0.0.6.tar.gz
- Upload date:
- Size: 10.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | b429dbbede453ef3bd5d0b5fb52fb516bb3e5c3249fbc4b8a8cd15280974bc4f
MD5 | 24a29b8421a963144ca0d3dac979ee21
BLAKE2b-256 | 552e2f9b5a30129894743ea116982a4d1984ba09e2d6d5d5ad24ededea1355d5
File details
Details for the file javiche-0.0.6-py3-none-any.whl.
File metadata
- Download URL: javiche-0.0.6-py3-none-any.whl
- Upload date:
- Size: 9.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 79b29842e01c0befe13e10e2c262a5454525f843160fb7f15b4d8e155cbc90bd
MD5 | d874ce8c1a5e07e8c0cb7f75f0461723
BLAKE2b-256 | 1f1fadd69e214fcef83cd74aaf10ded79a9e367ed551c5a8732e665809c8ee60