
JAX AI Stack


JAX is a Python package for array-oriented computation and program transformation. Built around it is a growing ecosystem of packages for specialized numerical computing across a range of domains; an up-to-date list of such projects can be found at Awesome JAX.
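Here is a minimal sketch of what "program transformation" means in practice. The function `f` and its inputs are hypothetical and chosen only for illustration. Each of `jax.grad`, `jax.jit`, and `jax.vmap` takes an ordinary Python function written with `jax.numpy` and returns a new function:

```python
import jax
import jax.numpy as jnp

def f(x):
    # A plain numerical function written with the NumPy-like jax.numpy API.
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)      # new function computing the gradient df/dx
fast_f = jax.jit(f)       # new function compiled with XLA on first call
batched_f = jax.vmap(f)   # new function mapping f over a leading batch axis

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))                          # -> [2. 4. 6.]
print(fast_f(x))                          # -> 14.0
print(batched_f(jnp.stack([x, 2 * x])))   # -> [14. 56.]
```

Because each transformation returns a plain function, they compose. For example, `jax.jit(jax.grad(f))` compiles the gradient function.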

Though JAX is often compared to neural network libraries like PyTorch, the JAX core package itself contains very little that is specific to neural network models. Instead, JAX encourages modularity, where domain-specific libraries are developed separately from the core package: this helps drive innovation as researchers and other users explore what is possible.

Within this larger, distributed ecosystem, there are a number of projects that Google researchers and engineers have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more. The JAX AI stack serves as a single point-of-entry for this suite of libraries, so you can install and begin using many of the same open source packages that Google developers are using in their everyday work.

To get started with the JAX AI stack, see Getting started with JAX. This project is still a work in progress; please check back in the coming weeks for more documentation and tutorials.

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins specific versions of the component projects that the integration tests in this repository have verified to work correctly together. Packages include:

  • JAX: the core JAX package, which includes array operations and program transformations such as jit, vmap, and grad.
  • flax: building neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.
  • chex: utilities for writing reliable JAX code.
  • grain: data loading.
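A small standard-library sketch can show which versions the pins resolved to in your environment. It uses `importlib.metadata`. The distribution names in `COMPONENTS` are assumptions about how each package is published on PyPI. This applies especially to `orbax-checkpoint` for orbax. The helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, requires, version

# Assumed PyPI distribution names; adjust if a package is published differently.
COMPONENTS = ["jax", "flax", "ml_dtypes", "optax", "orbax-checkpoint", "chex", "grain"]

def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result

def stack_requirements(dist="jax-ai-stack"):
    """Return the requirement strings the installed distribution declares."""
    try:
        return requires(dist) or []
    except PackageNotFoundError:
        return []

if __name__ == "__main__":
    for name, ver in installed_versions(COMPONENTS).items():
        print(f"{name:18} {ver or 'not installed'}")
    for req in stack_requirements():
        print(req)
```

Run as a script, this prints one line per component. It then prints the requirement strings declared by the installed `jax-ai-stack` metapackage, which is where its version pins are recorded.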

Optional packages

Additionally, there are optional packages you can install with pip extras.

The following command:

pip install "jax-ai-stack[tfds]"

will install a compatible version of tensorflow and tensorflow-datasets.
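Because these extras are optional, code that relies on them may want to check for them at runtime. The following is a minimal sketch. It assumes the extra provides the importable modules `tensorflow` and `tensorflow_datasets`. `has_tfds_extra` and `load_as_numpy` are hypothetical helper names, while `tfds.load` and `tfds.as_numpy` are part of the TensorFlow Datasets API:

```python
import importlib.util

def has_tfds_extra() -> bool:
    """True if the modules installed by the [tfds] extra can be found."""
    # find_spec locates a top-level module without importing it.
    return all(
        importlib.util.find_spec(mod) is not None
        for mod in ("tensorflow", "tensorflow_datasets")
    )

def load_as_numpy(name: str, split: str = "train"):
    """Load a TFDS split as an iterable of NumPy (input, label) pairs."""
    if not has_tfds_extra():
        raise ImportError('optional extra missing: pip install "jax-ai-stack[tfds]"')
    import tensorflow_datasets as tfds  # imported lazily: heavy dependency
    ds = tfds.load(name, split=split, as_supervised=True)
    return tfds.as_numpy(ds)
```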

Hardware support

To install jax-ai-stack with hardware-specific JAX support, add the hardware-specific JAX requirement to the same pip install invocation. For example:

pip install jax-ai-stack "jax[cuda]"  # JAX + AI stack with GPU/CUDA support
pip install jax-ai-stack "jax[tpu]"  # JAX + AI stack with TPU support

For more information on available options for hardware-specific JAX installation, refer to JAX installation.
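After installing, you can confirm which backend JAX actually picked up. The following minimal sketch uses `jax.default_backend()` and `jax.devices()`. The exact backend string depends on your hardware and installation:

```python
import jax

# Typically "cpu" on a CPU-only install, "gpu" with a CUDA-enabled jax on a
# supported NVIDIA GPU, and "tpu" on a TPU VM.
backend = jax.default_backend()
devices = jax.devices()  # devices for the default backend

print(f"backend={backend}, device_count={len(devices)}")
for d in devices:
    print(d.platform, d)
```

If a GPU or TPU install unexpectedly reports `cpu`, the accelerator-specific JAX requirement was likely not installed or not picked up.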


Download files


Source Distribution

jax_ai_stack-2025.9.3.tar.gz (8.6 kB)

Built Distribution

jax_ai_stack-2025.9.3-py3-none-any.whl (11.2 kB, Python 3)

File details

Details for the file jax_ai_stack-2025.9.3.tar.gz.

File metadata

  • Download URL: jax_ai_stack-2025.9.3.tar.gz
  • Size: 8.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jax_ai_stack-2025.9.3.tar.gz
  • SHA256: ade4b8fa8fa00398edcd9706cfa1a6cbb069cd90eb619bdf6a945860e8c701af
  • MD5: 00e6ff172ac8dda00875b864e71a7b1d
  • BLAKE2b-256: 1db9f823ab7ca54a6590d104bdd4730d21bf0e90806fcc9a910c3f94efe33bac
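To check a downloaded archive against the published SHA256 digest, here is a minimal sketch using Python's `hashlib`. The file location and the helper names are assumptions for illustration:

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for the source distribution.
EXPECTED_SHA256 = "ade4b8fa8fa00398edcd9706cfa1a6cbb069cd90eb619bdf6a945860e8c701af"

def sha256_of(path, chunk_size=1 << 16):
    """Stream a file through SHA-256 in chunks and return the hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path, expected=EXPECTED_SHA256):
    """Compare the file's digest with the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()

if __name__ == "__main__":
    archive = Path("jax_ai_stack-2025.9.3.tar.gz")  # assumed download location
    if archive.exists():
        print("OK" if verify(archive) else "MISMATCH")
    else:
        print(f"{archive} not found")
```

pip can also enforce digests at install time in hash-checking mode. List `jax-ai-stack==2025.9.3 --hash=sha256:<digest>` in a requirements file and run `pip install --require-hashes -r requirements.txt`. In that mode pip expects every requirement, including dependencies, to be pinned with a hash.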


Provenance

The following attestation bundles were made for jax_ai_stack-2025.9.3.tar.gz:

Publisher: release.yml on jax-ml/jax-ai-stack

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jax_ai_stack-2025.9.3-py3-none-any.whl.

File hashes

Hashes for jax_ai_stack-2025.9.3-py3-none-any.whl
  • SHA256: 2ace322fe490ace64fedc2a8004f1ba693b483f58d7a0b5e6023b30b22f45a72
  • MD5: 6e816e90d5471d708b1ee4880c60cb53
  • BLAKE2b-256: e2c8024c46a2011e710e96a148e78bf6181c18b52e8e892ccbc7bcd21c12cfd3


Provenance

The following attestation bundles were made for jax_ai_stack-2025.9.3-py3-none-any.whl:

Publisher: release.yml on jax-ml/jax-ai-stack

