Parallelization utilities for JAX
Parajax
Automagic parallelization of calls to JAX-based functions
Features
- 🚀 Device-parallel execution: run across multiple CPUs, GPUs, or TPUs automatically
- 🔄 Drop-in replacement for `jax.vmap`
- ⚡ JIT-compatible: works with `jax.jit` and its variants
- 🪄 Transparent padding when batch sizes aren't divisible by the number of devices
- 🎯 Simple interface: just decorate your function with `@pvmap`
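The padding arithmetic behind that last-but-one feature is simple. As a sketch (the helper `padding_plan` is hypothetical, and parajax's actual strategy may differ), a batch of `n` items spread over `d` devices needs `(-n) % d` extra rows so that every device receives the same number of items:

```python
def padding_plan(n: int, d: int) -> tuple[int, int, int]:
    """Hypothetical helper: return (pad, padded_size, per_device) for n items over d devices."""
    if d <= 0:
        raise ValueError("need at least one device")
    pad = (-n) % d  # smallest pad such that (n + pad) is divisible by d
    padded = n + pad
    return pad, padded, padded // d
```

For a 97-element batch on an assumed 8-device machine, this gives 7 padding rows, 104 rows in total, and 13 rows per device; the outputs for the 7 padding rows are then discarded.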
Installation
```shell
pip install parajax
```
Example
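```python
import multiprocessing

import jax
import jax.numpy as jnp
from parajax import pvmap

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())
# ^ Only needed on CPU: allow JAX to use all CPU cores.
#   Must run before JAX initializes its backends (i.e., before any computation).

@pvmap
def square(x: float) -> float:
    return x**2

xs = jnp.arange(97)
ys = square(xs)  # batch of 97 is split (and padded if needed) across devices
```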
That's it! Invocations of square will now be automatically parallelized across all available devices.
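For intuition about what such a decorator has to do, here is a minimal sketch built only from standard JAX APIs (`jax.vmap`, `jax.jit`, `Mesh`, `NamedSharding`). It is not parajax's implementation: the name `sharded_vmap` is hypothetical, it assumes a single array argument and a single array output, and it pads with zeros, so `f` is also evaluated on zero rows whose results are simply dropped.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharded_vmap(f):
    """Hypothetical helper: vmap `f` over axis 0, spreading the batch across all devices."""
    # One-dimensional mesh over every visible device, with a single axis named "batch".
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    sharding = NamedSharding(mesh, P("batch"))  # split the leading axis across devices
    batched = jax.jit(jax.vmap(f))
    d = mesh.devices.size

    def wrapped(xs):
        n = xs.shape[0]
        pad = (-n) % d  # rows needed to make the batch divisible by d
        # Zero-pad only the leading (batch) axis.
        widths = [(0, pad)] + [(0, 0)] * (xs.ndim - 1)
        xs_padded = jnp.pad(xs, widths)
        # Place each device's slice of the batch on that device.
        xs_sharded = jax.device_put(xs_padded, sharding)
        ys = batched(xs_sharded)  # jit partitions the computation to follow the input sharding
        return ys[:n]  # discard outputs computed on padding rows

    return wrapped
```

Usage mirrors the example above: `sharded_vmap(lambda x: x**2)(jnp.arange(97.0))` returns the 97 squares. Letting `jax.jit` follow the input sharding avoids writing per-device code by hand.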
Download files
Source Distribution
parajax-0.1.1.tar.gz (7.2 kB)
Built Distribution
parajax-0.1.1-py3-none-any.whl (8.1 kB)
File details
Details for the file parajax-0.1.1.tar.gz.
File metadata
- Download URL: parajax-0.1.1.tar.gz
- Upload date:
- Size: 7.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.8.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d636afe16f7a291be6131914844f99c38fb6d0e90c88793de915e606ced8924e |
| MD5 | b91d830e91519350f4a9920d83f72635 |
| BLAKE2b-256 | 129922c6daec12bbe9b14830aa57ec8ac780c9fc7971e7f8d1f1771c0cda920c |
File details
Details for the file parajax-0.1.1-py3-none-any.whl.
File metadata
- Download URL: parajax-0.1.1-py3-none-any.whl
- Upload date:
- Size: 8.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.8.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1cedebe10288ac8aebf3485d8184dd21ca84655ab438366ef61916c34c0a5bb6 |
| MD5 | 967bcb671d07d64efa9313181d796f45 |
| BLAKE2b-256 | 76fb4f164c696baf05c0efb854f5d42dd8e7653468b4d22cc6e4381434226a7c |
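To check a downloaded file against the SHA256 digests listed above, a short standard-library sketch is enough. The helper names `sha256_of` and `verify` are illustrative only and are not part of parajax:

```python
import hashlib
from pathlib import Path

# SHA256 digests as listed in the file details above.
EXPECTED_SHA256 = {
    "parajax-0.1.1.tar.gz": "d636afe16f7a291be6131914844f99c38fb6d0e90c88793de915e606ced8924e",
    "parajax-0.1.1-py3-none-any.whl": "1cedebe10288ac8aebf3485d8184dd21ca84655ab438366ef61916c34c0a5bb6",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so it never has to fit in memory at once."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path, expected: str) -> bool:
    """Compare against a hex digest, ignoring case and surrounding whitespace."""
    return sha256_of(path) == expected.strip().lower()
```

For example, `verify(Path("parajax-0.1.1.tar.gz"), EXPECTED_SHA256["parajax-0.1.1.tar.gz"])` should return `True` for an intact download.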