
Tesseract-JAX executes Tesseracts as part of JAX programs, with full support for function transformations such as JIT, `grad`, and more.


Tesseract-JAX

Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable, differentiable, and composable.

Read the docs | Explore the examples | Report an issue | Talk to the community | Contribute


The API of Tesseract-JAX consists of a single function, apply_tesseract(tesseract_client, inputs), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:

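import jax
from tesseract_jax import apply_tesseract

# vectoradd_tesseract is a served Tesseract and x, y are JAX arrays (see Quick start)
@jax.jit
def vector_sum(x, y):
    res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()

jax.grad(vector_sum)(x, y) # 🎉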
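What "fully traceable" demands of an external computation can be seen in a pure-JAX sketch. This is not how Tesseract-JAX is implemented; it is a minimal illustration, assuming the external computation is a plain Python function, of one standard way to make an opaque call work under `jax.jit` and `jax.grad`. `jax.pure_callback` declares the output shape and dtype up front (roughly the role `abstract_eval` plays for a Tesseract), and `jax.custom_vjp` supplies the gradient rule (roughly the role of `vector_jacobian_product`). `external_vector_add` and `external_vector_add_vjp` are hypothetical stand-ins for a Tesseract's endpoints.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for a computation that runs outside of JAX
# (for example inside a container). JAX cannot trace into it.
def external_vector_add(a, b):
    return np.asarray(a) + np.asarray(b)


# Hypothetical stand-in for a VJP endpoint: for c = a + b, the output
# cotangent is passed through unchanged to both inputs.
def external_vector_add_vjp(cotangent):
    return np.asarray(cotangent), np.asarray(cotangent)


@jax.custom_vjp
def vector_add(a, b):
    # The output shape/dtype must be declared, since JAX cannot infer it.
    out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(external_vector_add, out_spec, a, b)


def vector_add_fwd(a, b):
    # Addition needs no residuals for the backward pass.
    return vector_add(a, b), None


def vector_add_bwd(_residuals, cotangent):
    spec = jax.ShapeDtypeStruct(cotangent.shape, cotangent.dtype)
    return jax.pure_callback(external_vector_add_vjp, (spec, spec), cotangent)


vector_add.defvjp(vector_add_fwd, vector_add_bwd)


@jax.jit
def vector_sum(x, y):
    return vector_add(x, y).sum()


x = jnp.arange(4.0)
y = jnp.ones(4)
print(vector_sum(x, y))
print(jax.grad(vector_sum, argnums=(0, 1))(x, y))
```

Both gradients are arrays of ones here: the cotangent of `sum` is all ones, and addition passes it through to each input unchanged.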

Quick start

[!NOTE] Before proceeding, make sure you have a working installation of Docker and Python 3.10 or newer.

[!IMPORTANT] For more detailed installation instructions, please refer to the Tesseract Core documentation.

  1. Install Tesseract-JAX:

    $ pip install tesseract-jax
    
  2. Build an example Tesseract:

    $ git clone https://github.com/pasteurlabs/tesseract-jax
    $ tesseract build tesseract-jax/examples/simple/vectoradd_jax
    
  3. Use it as part of a JAX program via the JAX-native apply_tesseract function:

    import jax
    import jax.numpy as jnp
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract
    
    # Load the Tesseract
    t = Tesseract.from_image("vectoradd_jax")
    t.serve()
    
    # Run it with JAX
    x = jnp.ones((1000,))
    y = jnp.ones((1000,))
    
    def vector_sum(x, y):
        res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}}, vmap_method="sequential")
        return res["vector_add"]["result"].sum()
    
    vector_sum(x, y) # success!
    
    # You can also use it with JAX transformations like JIT and grad
    vector_sum_jit = jax.jit(vector_sum)
    vector_sum_jit(x, y)
    
    vector_sum_grad = jax.grad(vector_sum)
    vector_sum_grad(x, y)
    
    # vmap requires an explicit vmap_method (vector_sum above passes "sequential").
    # "sequential" is safe but slow, while "auto_experimental" or "expand_dims" is
    # more efficient for Tesseracts that support batching.
    # See https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/vmap-methods.html
    vector_sum_vmap = jax.vmap(vector_sum)
    vector_sum_vmap(x.reshape(10, 100), y.reshape(10, 100))
    

[!TIP] Now you're ready to jump into our examples for more ways to use Tesseract-JAX.

Sharp edges

  • Additional required endpoints: Tesseract-JAX requires the abstract_eval Tesseract endpoint to be defined when it is used with automatic differentiation or other JAX transformations such as jit. In these cases, JAX must abstractly evaluate every operation (determine its output shapes and dtypes) before executing it. In addition, gradient transformations such as jax.grad require the vector_jacobian_product endpoint to be defined.
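Abstract evaluation means computing output shapes and dtypes from input shapes and dtypes alone, without running anything. The sketch below shows this with `jax.eval_shape` on plain JAX code, and with `jax.pure_callback`, where the caller has to declare the output shape because JAX cannot look inside the external function. `external_double` is a hypothetical stand-in for an opaque computation such as a Tesseract call; the empty `calls` list confirms it never ran.

```python
import jax
import jax.numpy as jnp
import numpy as np


def vector_sum(x, y):
    return (x + y).sum()


# Abstract inputs: shape and dtype only, no data.
x_spec = jax.ShapeDtypeStruct((1000,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((1000,), jnp.float32)

# eval_shape traces the function with abstract values and never executes it.
out_spec = jax.eval_shape(vector_sum, x_spec, y_spec)
grad_spec = jax.eval_shape(jax.grad(vector_sum), x_spec, y_spec)
print(out_spec.shape, grad_spec.shape)

# For an opaque external function, the output spec must be supplied explicitly.
calls = []


def external_double(a):
    calls.append(np.shape(a))
    return np.asarray(a) * 2


def wrapped(a):
    return jax.pure_callback(external_double, jax.ShapeDtypeStruct(a.shape, a.dtype), a)


wrapped_spec = jax.eval_shape(wrapped, x_spec)
print(wrapped_spec.shape, calls)
```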

[!TIP] When creating a new Tesseract based on a JAX function, use tesseract init --recipe jax to define all required endpoints automatically, including abstract_eval and vector_jacobian_product.
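For a function written in JAX, the quantity a vector_jacobian_product endpoint has to return (a cotangent for each input, given a cotangent for the output) is what `jax.vjp` computes. The sketch below shows that computation for a hypothetical elementwise product; it does not reproduce the files generated by `tesseract init --recipe jax`, whose contents are not shown in this README.

```python
import jax
import jax.numpy as jnp


# Hypothetical function standing in for a Tesseract's apply endpoint.
def elementwise_product(a, b):
    return a * b


a = jnp.arange(3.0)
b = jnp.ones(3)

# jax.vjp returns the primal output and a function mapping an output
# cotangent v to the input cotangents v^T J (one per input).
result, vjp_fn = jax.vjp(elementwise_product, a, b)
cotangent = jnp.array([1.0, 2.0, 3.0])
grad_a, grad_b = vjp_fn(cotangent)
print(result, grad_a, grad_b)
```

Because the Jacobian of an elementwise product is diagonal, the cotangents reduce to `cotangent * b` and `cotangent * a`.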

License

Tesseract-JAX is licensed under the Apache License 2.0 and is free to use, modify, and distribute (under the terms of the license).

Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.
