Tesseract JAX executes Tesseracts as part of JAX programs, with full support for function transformations like JIT, `grad`, and more.
Tesseract-JAX
Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable, differentiable, and composable.
The API of Tesseract-JAX consists of a single function, `apply_tesseract(tesseract_client, inputs)`, which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
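```python
import jax
from tesseract_jax import apply_tesseract

@jax.jit
def vector_sum(x, y):
    # vectoradd_tesseract: a served Tesseract; x, y: JAX arrays
    res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()
jax.grad(vector_sum)(x, y)  # 🎉
```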
Quick start
[!NOTE] Before proceeding, make sure you have a working installation of Docker and a modern Python installation (Python 3.10+).
[!IMPORTANT] For more detailed installation instructions, please refer to the Tesseract Core documentation.
- Install Tesseract-JAX:

  ```bash
  $ pip install tesseract-jax
  ```
- Build an example Tesseract:

  ```bash
  $ git clone https://github.com/pasteurlabs/tesseract-jax
  $ tesseract build tesseract-jax/examples/simple/vectoradd_jax
  ```
- Use it as part of a JAX program via the JAX-native `apply_tesseract` function:

  ```python
  import jax
  import jax.numpy as jnp
  from tesseract_core import Tesseract
  from tesseract_jax import apply_tesseract

  # Load the Tesseract
  t = Tesseract.from_image("vectoradd_jax")
  t.serve()

  # Run it with JAX
  x = jnp.ones((1000,))
  y = jnp.ones((1000,))

  def vector_sum(x, y):
      res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
      return res["vector_add"]["result"].sum()

  vector_sum(x, y)  # success!

  # You can also use it with JAX transformations like JIT and grad
  vector_sum_jit = jax.jit(vector_sum)
  vector_sum_jit(x, y)

  vector_sum_grad = jax.grad(vector_sum)
  vector_sum_grad(x, y)
  ```
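Note that `jax.grad` differentiates only with respect to the first positional argument by default, so `vector_sum_grad(x, y)` returns the gradient with respect to `x` alone. The sketch below shows how to get gradients for both inputs with `argnums`. To keep it runnable without Docker, it swaps `apply_tesseract` for a hypothetical pure-JAX stand-in, `fake_vectoradd`, which assumes the Tesseract returns the elementwise sum `a.v + b.v` under `vector_add.result`; the real `vectoradd_jax` Tesseract may compute more than this.

```python
import jax
import jax.numpy as jnp


def fake_vectoradd(inputs):
    # Hypothetical stand-in for apply_tesseract(t, inputs): same nested-dict
    # input/output structure, assumed to compute an elementwise sum.
    return {"vector_add": {"result": inputs["a"]["v"] + inputs["b"]["v"]}}


def vector_sum(x, y):
    res = fake_vectoradd({"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()


x = jnp.ones((1000,))
y = jnp.ones((1000,))

# Default (argnums=0): gradient with respect to x only.
grad_x = jax.grad(vector_sum)(x, y)

# argnums=(0, 1) returns a tuple with one gradient per argument.
grad_x, grad_y = jax.grad(vector_sum, argnums=(0, 1))(x, y)
```

With a served Tesseract, replacing `fake_vectoradd(...)` by `apply_tesseract(t, ...)` should leave the `argnums` usage unchanged.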
[!TIP] Now you're ready to jump into our examples for more ways to use Tesseract-JAX.
Sharp edges
- Arrays vs. array-like objects: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays; other array-likes such as Python floats or lists raise an error. You may therefore need to convert your inputs, including scalar values, to JAX arrays before passing them to Tesseract-JAX.

  ```python
  import jax.numpy as jnp
  from tesseract_core import Tesseract
  from tesseract_jax import apply_tesseract

  with Tesseract.from_image("vectoradd_jax") as tess:
      apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}})  # ❌ raises an error
      apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}})  # ✅ works
  ```
- Additional required endpoints: Tesseract-JAX requires the `abstract_eval` Tesseract endpoint to be defined for all operations, because JAX needs the output shapes and dtypes of every operation while tracing, before anything is executed. In addition, many gradient transformations, such as `jax.grad`, require `vector_jacobian_product` to be defined.
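For the arrays-vs-array-likes item, one way to convert a whole nested input dictionary at once is to map `jnp.asarray` over it with `jax.tree_util.tree_map`. This is a minimal sketch: the helper name `to_jax_arrays` is hypothetical, and it assumes every leaf is numeric. Lists and tuples are marked as leaves via `is_leaf`, because JAX otherwise treats a Python list as a container and would convert its elements one by one instead of turning the list into a single array.

```python
import jax
import jax.numpy as jnp


def to_jax_arrays(inputs):
    """Convert every numeric leaf of a nested input dict to a JAX array."""
    # Treat lists and tuples as leaves so that [1.0] becomes one array of
    # shape (1,) rather than being traversed element by element.
    # Assumes all leaves are numeric; jnp.asarray fails on e.g. strings.
    return jax.tree_util.tree_map(
        jnp.asarray, inputs, is_leaf=lambda node: isinstance(node, (list, tuple))
    )


raw_inputs = {"a": {"v": [1.0]}, "b": {"v": [2.0]}}
inputs = to_jax_arrays(raw_inputs)
```

The converted dictionary has the same structure as the ✅ call in the arrays item, so it can be passed to `apply_tesseract` directly. Plain Python scalars become 0-d arrays.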
[!TIP] When creating a new Tesseract based on a JAX function, use `tesseract init --recipe jax` to define all required endpoints automatically, including `abstract_eval` and `vector_jacobian_product`.
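To make the two required endpoints more concrete, the sketch below uses plain JAX (not Tesseract-JAX) on a hypothetical stand-in function, `vector_add`. `jax.eval_shape` performs abstract evaluation: it derives output shapes and dtypes from input shapes and dtypes without doing any arithmetic, which is the kind of information `abstract_eval` supplies. `jax.vjp` returns a function mapping an output cotangent to input cotangents, which is the role `vector_jacobian_product` plays for reverse-mode transformations like `jax.grad`. This illustrates the JAX concepts only; it is not how a Tesseract implements its endpoints.

```python
import jax
import jax.numpy as jnp


def vector_add(a, b):
    # Hypothetical stand-in for a Tesseract's computation.
    return {"result": a + b}


# Abstract evaluation: shapes and dtypes in, shapes and dtypes out.
# No arithmetic is performed on actual values.
spec = jax.ShapeDtypeStruct((1000,), jnp.float32)
out_spec = jax.eval_shape(vector_add, spec, spec)

# Vector-Jacobian product: given a cotangent for the output pytree,
# return one cotangent per input. Reverse-mode jax.grad is built on this.
a = jnp.arange(3.0)
b = jnp.ones(3)
out, vjp_fn = jax.vjp(vector_add, a, b)
ct_a, ct_b = vjp_fn({"result": jnp.array([1.0, 2.0, 3.0])})
```

For elementwise addition, the Jacobian with respect to each input is the identity, so both input cotangents equal the output cotangent.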
License
Tesseract-JAX is licensed under the Apache License 2.0 and is free to use, modify, and distribute (under the terms of the license).
Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.