
A JAX wrapper around ceviche to make interoperability easier. In the future it might make sense to update ceviche itself to use JAX internally.

Project description

javiche

A small package that makes it easy to use ceviche with a JAX optimizer.

Install

This package is not yet published. Once it is, install it with:

pip install javiche

or

conda install javiche

How to use

Import the decorator

from javiche import jaxit

Decorate your function (it will be differentiated using ceviche's jacobian, which relies on HIPS autograd):

@jaxit()
def square(A):
  """squares number/array"""
  return A**2

Now you can use jax as usual:

import jax

grad_fn = jax.grad(square)
grad_fn(2.0)
Array(4., dtype=float32, weak_type=True)
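
Note that jax.grad only accepts functions with a scalar output, so for array inputs the result has to be reduced first. The sketch below illustrates this with an undecorated copy of square; sum_of_squares is a hypothetical helper name introduced only for this illustration.

```python
import jax
import jax.numpy as jnp


def square(A):
    """squares number/array (undecorated copy, for illustration only)"""
    return A**2


def sum_of_squares(A):
    """Reduces the elementwise squares to a scalar so jax.grad accepts it."""
    return jnp.sum(square(A))


grad_fn = jax.grad(sum_of_squares)
A = jnp.array([1.0, 2.0, 3.0])
gradient = grad_fn(A)  # elementwise derivative, equal to 2 * A
print(gradient)
```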

In this toy example, that was already possible without the jaxit() decorator. However, jaxit()-decorated functions can contain autograd operators (but no JAX operators):

import autograd.numpy as npa
def sin(A):
  """computes sin of number/array using autograd's numpy"""
  # Not decorated: jax.grad cannot trace through autograd's numpy
  return npa.sin(A)
grad_sin = jax.grad(sin)
try:
  print(grad_sin(0.0))
except Exception as e:
  print(e)
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 0.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

With the jaxit() decorator, the same pattern works:

@jaxit()
def cos(A):
  """computes cos of number/array using autograd's numpy"""
  return npa.cos(A)

grad_cos = jax.grad(cos)
try:
  print(grad_cos(0.0))
except Exception as e:
  print(e)
-0.0

Use case

This library is intended for use with ceviche while running a JAX optimization stack, as demonstrated in the inverse design example.
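
As a rough picture of what a JAX optimization loop around such a function can look like, the sketch below runs plain gradient descent with jax.value_and_grad. The objective here is a hypothetical stand-in written in pure JAX; in a real setup it would be replaced by a jaxit()-decorated function that calls ceviche. The names objective and gradient_descent, the learning rate and the step count are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

TARGET = jnp.array([1.0, -2.0, 0.5])


def objective(params):
    """Hypothetical stand-in for a jaxit()-decorated ceviche simulation."""
    return jnp.sum((params - TARGET) ** 2)


def gradient_descent(loss_fn, params, learning_rate=0.1, steps=100):
    """Plain gradient descent; returns the final parameters and the loss history."""
    value_and_grad = jax.value_and_grad(loss_fn)
    history = []
    for _ in range(steps):
        value, grads = value_and_grad(params)
        params = params - learning_rate * grads
        history.append(float(value))
    return params, history


params, history = gradient_descent(objective, jnp.zeros(3))
print(params, history[-1])
```

Because the loop only needs the value and gradient of the objective, any function that jax.grad can differentiate can be dropped in.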

