Parallelization utilities for JAX

Parajax

Automagic parallelization of calls to JAX-based functions

Features

  • 🚀 Device-parallel execution: run across multiple CPUs, GPUs or TPUs automatically
  • 🔄 Drop-in replacement for jax.vmap
  • JIT-compatible: works with jax.jit and variants
  • 🪄 Transparent padding when the batch size isn’t divisible by the number of devices
  • 🎯 Simple interface: just decorate your function with pvmap
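
The padding feature listed above can be pictured with a small sketch in plain JAX. This is not Parajax's implementation: `padded_vmap` is a hypothetical helper, `n_devices` is passed in by hand rather than detected, and a nested `jax.vmap` stands in for the step where a real implementation would place each shard on its own device.

```python
import jax
import jax.numpy as jnp


def padded_vmap(fn, xs, n_devices):
    """Hypothetical helper: pad, map per shard, then trim. Not Parajax's code."""
    batch = xs.shape[0]
    # Round the batch size up to the next multiple of n_devices.
    padded = -(-batch // n_devices) * n_devices
    # Pad by repeating the last element so fn only ever sees valid inputs.
    pad_width = [(0, padded - batch)] + [(0, 0)] * (xs.ndim - 1)
    xs_padded = jnp.pad(xs, pad_width, mode="edge")
    # One shard per device: shape (n_devices, per_device, ...).
    shards = xs_padded.reshape(n_devices, padded // n_devices, *xs.shape[1:])
    # Assumption: a nested vmap stands in for real per-device execution.
    ys = jax.vmap(jax.vmap(fn))(shards)
    # Flatten the shards and drop the padded tail.
    return ys.reshape(padded, *ys.shape[2:])[:batch]


xs = jnp.arange(97)
ys = padded_vmap(lambda x: x**2, xs, n_devices=8)
```

With 97 inputs and 8 shards, the batch is padded to 104 elements (13 per shard) and the 7 extra results are discarded, so the caller sees an output of length 97.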

Installation

pip install parajax

Example

import multiprocessing

import jax
import jax.numpy as jnp
from parajax import pvmap

# Only needed on CPU: expose every CPU core to JAX as a separate device.
# Set this before running any JAX computation.
jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())

@pvmap
def square(x: float) -> float:
    return x**2

# 97 is prime, so the batch rarely divides evenly across devices
xs = jnp.arange(97)
ys = square(xs)

That's it! Invocations of square will now be automatically parallelized across all available devices.
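
To check what "all available devices" means on a given machine, and to have a reference result to compare against, the following sketch uses only standard JAX calls. The name `square_reference` is made up for illustration. Because pvmap is described as a drop-in replacement for jax.vmap, the output of square(xs) is expected to match the `jax.vmap` result computed here.

```python
import jax
import jax.numpy as jnp

# Devices JAX can see. On CPU this is 1 unless jax_num_cpu_devices was raised.
n_devices = jax.device_count()
print(f"{n_devices} device(s): {jax.devices()}")


def square_reference(x):
    # Same per-element function as in the example, without pvmap.
    return x**2


xs = jnp.arange(97)
expected = jax.vmap(square_reference)(xs)

# With Parajax installed, the decorated function can be compared directly:
# assert jnp.array_equal(square(xs), expected)
```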
