
JAX SparseCore API

Project description

JAX TPU Embedding (JAX SparseCore)



What is JAX TPU Embedding?

JAX TPU Embedding (also called JAX SparseCore) lets JAX programs use the SparseCore accelerators present in TPU generations starting with TPU v5.

SparseCores are specialized tiled processors engineered for high-performance acceleration of workloads that involve irregular, sparse memory access and computation. While they excel at tasks like embedding lookups (common in deep learning recommendation models), their capabilities extend to accelerating a variety of other dynamic and sparse workloads on large datasets stored in High Bandwidth Memory (HBM).
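To make "irregular, sparse memory access" concrete, the sketch below is a plain-Python reference for the core of an embedding lookup. Each example in a batch carries a variable-length list of row ids, the matching rows are gathered from a large table, and the gathered rows are combined by an optionally weighted sum. It illustrates the computation only and is not the jax_tpu_embedding API. The function name `embedding_lookup_sum` and the ragged list-of-lists input layout are assumptions made for illustration.

```python
from typing import List, Optional, Sequence


def embedding_lookup_sum(
    table: Sequence[Sequence[float]],
    ids: Sequence[Sequence[int]],
    weights: Optional[Sequence[Sequence[float]]] = None,
) -> List[List[float]]:
    """Illustrative reference (hypothetical helper, not the library API).

    For each example, gather the table rows named by its ids and combine
    them with a weighted sum (weight 1.0 when `weights` is None).
    """
    dim = len(table[0])
    out: List[List[float]] = []
    for i, example_ids in enumerate(ids):
        acc = [0.0] * dim  # an example with no ids yields a zero vector
        for j, row_id in enumerate(example_ids):
            w = 1.0 if weights is None else weights[i][j]
            row = table[row_id]  # data-dependent, irregular memory read
            for k in range(dim):
                acc[k] += w * row[k]
        out.append(acc)
    return out


if __name__ == "__main__":
    table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
    print(embedding_lookup_sum(table, [[0, 2], [1], []]))
```

With the three-row table above, the ragged batch `[[0, 2], [1], []]` produces one combined vector per example, and the empty example produces zeros. The `table[row_id]` reads depend on the input data, which is what makes this kind of workload a poor fit for dense matrix units and a natural target for SparseCores.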

This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

Installation

You can install JAX TPU Embedding from PyPI:

pip install jax-tpu-embedding

Note: To use TPU acceleration, you must run in an environment with access to TPU v5+ hardware and have the appropriate jax and jaxlib TPU releases installed (see the JAX installation guide).
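A quick way to confirm the environment is to ask JAX which TPU devices it can see. The sketch below relies only on `jax.devices` and each device's `device_kind` attribute. It does not check whether a given TPU generation actually includes SparseCores, so compare the reported kinds against the TPU v5+ requirement yourself. The helper name `visible_tpu_device_kinds` is hypothetical.

```python
import importlib.util
from typing import List


def visible_tpu_device_kinds() -> List[str]:
    """Return the device_kind of every TPU device JAX can see, or [] if none.

    Hypothetical helper: it only reports what JAX sees and does not
    verify SparseCore availability.
    """
    if importlib.util.find_spec("jax") is None:
        return []  # JAX itself is not installed
    import jax

    try:
        devices = jax.devices("tpu")
    except RuntimeError:
        return []  # JAX is installed, but no TPU backend could be initialized
    return [d.device_kind for d in devices]


if __name__ == "__main__":
    kinds = visible_tpu_device_kinds()
    print(kinds if kinds else "No TPU devices visible to JAX")
```

An empty result on a TPU VM usually means the TPU-enabled jax/jaxlib releases are not installed.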

Documentation

For detailed guides, specifications, and tutorials, see the JAX TPU Embedding Documentation.

Download files

Download the file for your platform. If you're not sure which to choose, see the Python Packaging User Guide tutorial on installing packages.

Source Distribution

jax_tpu_embedding-0.1.0.tar.gz (4.1 MB)

Built Distributions

If you're not sure about the file name format, see the wheel file name convention in the Python Packaging binary distribution format specification.

jax_tpu_embedding-0.1.0-cp314-cp314-manylinux_2_31_x86_64.whl (4.1 MB)

CPython 3.14, manylinux (glibc 2.31+), x86-64

jax_tpu_embedding-0.1.0-cp313-cp313-manylinux_2_31_x86_64.whl (4.1 MB)

CPython 3.13, manylinux (glibc 2.31+), x86-64

jax_tpu_embedding-0.1.0-cp312-cp312-manylinux_2_31_x86_64.whl (4.1 MB)

CPython 3.12, manylinux (glibc 2.31+), x86-64

jax_tpu_embedding-0.1.0-cp311-cp311-manylinux_2_31_x86_64.whl (4.1 MB)

CPython 3.11, manylinux (glibc 2.31+), x86-64
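The wheel names above follow the standard `{distribution}-{version}(-{build})?-{python tag}-{abi tag}-{platform tag}.whl` convention, so choosing a file is mechanical. Match the CPython tag to your interpreter, then check that the host is Linux on x86-64 with glibc 2.31 or newer. The sketch below does this with the standard library only. The helper names are hypothetical, and it understands only the `manylinux_X_Y_arch` platform-tag form used in this listing. In practice, `pip install jax-tpu-embedding` performs this selection for you.

```python
import platform
import re
import sys
from typing import Iterable, NamedTuple, Optional, Tuple

# File names copied from the listing above.
WHEELS = [
    "jax_tpu_embedding-0.1.0-cp314-cp314-manylinux_2_31_x86_64.whl",
    "jax_tpu_embedding-0.1.0-cp313-cp313-manylinux_2_31_x86_64.whl",
    "jax_tpu_embedding-0.1.0-cp312-cp312-manylinux_2_31_x86_64.whl",
    "jax_tpu_embedding-0.1.0-cp311-cp311-manylinux_2_31_x86_64.whl",
]


class WheelTags(NamedTuple):
    name: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split '{name}-{version}[-{build}]-{python}-{abi}-{platform}.whl' into tags."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file name: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 6:
        del parts[2]  # drop the optional build tag
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel file name: {filename!r}")
    return WheelTags(*parts)


def host_glibc() -> Optional[Tuple[int, int]]:
    """Return the host glibc version as (major, minor), or None on non-glibc systems."""
    lib, ver = platform.libc_ver()
    if lib != "glibc" or not ver:
        return None
    major, minor = (ver.split(".") + ["0"])[:2]
    return int(major), int(minor)


def manylinux_ok(tag: str, glibc: Optional[Tuple[int, int]], machine: str) -> bool:
    """Accept only 'manylinux_X_Y_arch' tags whose glibc floor and arch match the host."""
    m = re.fullmatch(r"manylinux_(\d+)_(\d+)_(\w+)", tag)
    if m is None or glibc is None:
        return False
    return glibc >= (int(m.group(1)), int(m.group(2))) and m.group(3) == machine


def pick_wheel(
    filenames: Iterable[str],
    py: Tuple[int, int],
    glibc: Optional[Tuple[int, int]],
    machine: str,
) -> Optional[str]:
    """Return the first wheel built for CPython `py` that the host can run, else None."""
    want = f"cp{py[0]}{py[1]}"
    for name in filenames:
        tags = parse_wheel_filename(name)
        if want in tags.python_tag.split(".") and manylinux_ok(tags.platform_tag, glibc, machine):
            return name
    return None


if __name__ == "__main__":
    print(pick_wheel(WHEELS, sys.version_info[:2], host_glibc(), platform.machine()))
```

On a CPython 3.10 interpreter, on an aarch64 host, or with glibc older than 2.31, `pick_wheel` returns `None`. In those cases pip would fall back to building from the source distribution.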

File details

Details for the file jax_tpu_embedding-0.1.0.tar.gz.

File metadata

  • Download URL: jax_tpu_embedding-0.1.0.tar.gz
  • Size: 4.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.13

File hashes

Hashes for jax_tpu_embedding-0.1.0.tar.gz
Algorithm Hash digest
SHA256 71292d1a036a42af6a6179dbc754171e01e51bf2472d073dfce01e2e296ecfa9
MD5 8ad7e15e8c39d2292ba8f91e378a7ff2
BLAKE2b-256 19415de4c65015323b7972f27cce700dbe66e4c4e91ad791999b5d9d8509b0ec

See the pip documentation on hash-checking mode for more details on using hashes.
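Each digest listed here can be checked locally before installing. The sketch below streams a downloaded file through `hashlib.sha256` and compares the result with the SHA256 values copied from the hash tables on this page. The helper names are hypothetical. pip can enforce the same check itself in hash-checking mode (`--require-hashes`).

```python
import hashlib
import sys
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables on this page.
EXPECTED_SHA256 = {
    "jax_tpu_embedding-0.1.0.tar.gz":
        "71292d1a036a42af6a6179dbc754171e01e51bf2472d073dfce01e2e296ecfa9",
    "jax_tpu_embedding-0.1.0-cp314-cp314-manylinux_2_31_x86_64.whl":
        "8bbcfaf0ddd28640b6403fde2feedcd544c876d33fe3ce10619b7a049345758f",
    "jax_tpu_embedding-0.1.0-cp313-cp313-manylinux_2_31_x86_64.whl":
        "ac70e71df821b0af1ea1d4eed73164fa6a565723fe78a8d5e408fd8c1007951a",
    "jax_tpu_embedding-0.1.0-cp312-cp312-manylinux_2_31_x86_64.whl":
        "35dc5141304669aadb0f73065eb3579ba3684e10ef1d0cb59b7b88aa783e03e1",
    "jax_tpu_embedding-0.1.0-cp311-cp311-manylinux_2_31_x86_64.whl":
        "9843dd6b094bebc0e50ad9291b2f7851f2ff801d710c215ef3d75b0c7b96333e",
}


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in fixed-size chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(filename: str, hexdigest: str) -> bool:
    """True if `hexdigest` equals the published SHA256 for `filename`."""
    expected = EXPECTED_SHA256.get(filename)
    if expected is None:
        raise KeyError(f"no published digest recorded for {filename!r}")
    return hexdigest.lower() == expected


def verify_download(path: Path) -> bool:
    """Hash a downloaded file and compare it with the digest published for its name."""
    return matches_expected(path.name, sha256_of(path))


if __name__ == "__main__":
    # Usage: python verify.py jax_tpu_embedding-0.1.0.tar.gz [more files...]
    for arg in sys.argv[1:]:
        p = Path(arg)
        print(p.name, "OK" if verify_download(p) else "MISMATCH")
```

A mismatch means the file was corrupted or altered in transit and should not be installed.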

File details

Details for the file jax_tpu_embedding-0.1.0-cp314-cp314-manylinux_2_31_x86_64.whl.

File hashes

Hashes for jax_tpu_embedding-0.1.0-cp314-cp314-manylinux_2_31_x86_64.whl
Algorithm Hash digest
SHA256 8bbcfaf0ddd28640b6403fde2feedcd544c876d33fe3ce10619b7a049345758f
MD5 4edc26dcea03cd0e6b24a58bf3ef1632
BLAKE2b-256 df201e29c74de6c5ef1ee11ab452f19b1f73dcf5280359e1e118c700fb8e188d

File details

Details for the file jax_tpu_embedding-0.1.0-cp313-cp313-manylinux_2_31_x86_64.whl.

File hashes

Hashes for jax_tpu_embedding-0.1.0-cp313-cp313-manylinux_2_31_x86_64.whl
Algorithm Hash digest
SHA256 ac70e71df821b0af1ea1d4eed73164fa6a565723fe78a8d5e408fd8c1007951a
MD5 8d53fef7605e88756537bda30895d300
BLAKE2b-256 9742ae29a32cb6e870128fe851f3b6326deab5710ce9816bbdb961613e087097

File details

Details for the file jax_tpu_embedding-0.1.0-cp312-cp312-manylinux_2_31_x86_64.whl.

File hashes

Hashes for jax_tpu_embedding-0.1.0-cp312-cp312-manylinux_2_31_x86_64.whl
Algorithm Hash digest
SHA256 35dc5141304669aadb0f73065eb3579ba3684e10ef1d0cb59b7b88aa783e03e1
MD5 ceb15a88262d5948cb5a2c4a7803856e
BLAKE2b-256 06f25175116112eeca3ae97ced5268fa4239795b5ff06a2f35a437cb6e5ce326

File details

Details for the file jax_tpu_embedding-0.1.0-cp311-cp311-manylinux_2_31_x86_64.whl.

File hashes

Hashes for jax_tpu_embedding-0.1.0-cp311-cp311-manylinux_2_31_x86_64.whl
Algorithm Hash digest
SHA256 9843dd6b094bebc0e50ad9291b2f7851f2ff801d710c215ef3d75b0c7b96333e
MD5 dab330e8ea85ebd401e0508bf6769d91
BLAKE2b-256 2312b3be0194d417494283d5c49cad7701d577aa78abdb9b9f54edfb481e07b8
