A JAX wrapper around ceviche to make interoperability easier. In the future it might make sense to update ceviche itself to use JAX internally.
javiche
A small package that makes it easy to use ceviche with a JAX optimizer.
Install
This package is not yet published. Once it is, install it with:
pip install javiche
or
conda install javiche
How to use
Import the decorator
import jax
from javiche import jaxit
Decorate your function (it will be differentiated using ceviche's jacobian, which is built on HIPS autograd):
@jaxit()
def square(A):
    """squares number/array"""
    return A**2
Now you can use jax as usual:
grad_fn = jax.grad(square)
grad_fn(2.0)
Array(4., dtype=float32, weak_type=True)
In this toy example, that was already possible without the jaxit()
decorator. However, jaxit()-decorated functions can contain autograd
operators (but no JAX operators):
import autograd.numpy as npa
def sin(A):
    """computes sin of number/array using autograd's numpy"""
    return npa.sin(A)
grad_sin = jax.grad(sin)
try:
    print(grad_sin(0.0))
except Exception as e:
    print(e)
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
primal = 0.0
tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[], weak_type=True), None)
recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
@jaxit()
def cos(A):
    """computes cos of number/array using autograd's numpy"""
    return npa.cos(A)
grad_cos = jax.grad(cos)
try:
    print(grad_cos(0.0))
except Exception as e:
    print(e)
-0.0
Use case
This library is intended for use with ceviche while running a JAX optimization stack, as demonstrated in the inverse design example.
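As a rough illustration of what "a JAX optimization stack" can look like, the sketch below runs plain gradient descent with jax.grad. The objective function is a stand-in written in pure JAX; in the intended use case it would be a jaxit()-decorated function that calls ceviche. The names objective and gradient_descent, the target values, and the step size are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def objective(params):
    """Stand-in loss (hypothetical); a jaxit()-decorated ceviche
    simulation would take its place in a real inverse-design run."""
    target = jnp.array([1.0, -2.0, 0.5])
    return jnp.sum((params - target) ** 2)


def gradient_descent(loss_fn, params, step_size=0.1, num_steps=200):
    """Plain gradient descent driven by jax.grad."""
    grad_fn = jax.grad(loss_fn)
    for _ in range(num_steps):
        # Step against the gradient of the loss.
        params = params - step_size * grad_fn(params)
    return params


optimized = gradient_descent(objective, jnp.zeros(3))
print(optimized)  # should approach the target used in objective
```

Because the optimizer only needs jax.grad to work on the loss, any function that JAX can differentiate, including a decorated one, can be dropped in for objective.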