Parallelization utilities for JAX
Parajax
Automagic parallelization of calls to JAX-based functions
Features
- 🚀 Device-parallel execution: run across multiple CPUs, GPUs or TPUs automatically
- ⚡ JIT-compatible: works with `jax.jit` and its variants
- 🪄 Automatic handling of input shapes not divisible by the number of devices
- 🎯 Simple interface: just decorate your function with `@pvmap`
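To make the first and third features concrete, here is a minimal sketch of how a `pvmap`-style decorator *could* work. It is not parajax's implementation: the name `naive_pvmap` is hypothetical, it assumes every positional argument is an array sharing the same leading (batch) length, and it uses `jax.pmap` over `jax.vmap` with pad-then-trim handling of uneven batch sizes. With a single device it still runs, just without any parallel speedup.

```python
import functools

import jax
import jax.numpy as jnp


def naive_pvmap(f):
    """Hypothetical stand-in for a pvmap-style decorator (not parajax's code).

    Maps ``f`` over the leading axis of every positional array argument and
    spreads the batch across all local devices with ``jax.pmap``.
    """
    n_dev = jax.local_device_count()
    # Outer axis is split across devices; inner axis is vectorized per device.
    batched = jax.pmap(jax.vmap(f))

    @functools.wraps(f)
    def wrapper(*args):
        n = args[0].shape[0]
        per_dev = -(-n // n_dev)  # ceil(n / n_dev) without floats
        pad = per_dev * n_dev - n

        def split(a):
            # Pad by repeating the last row so f only sees in-domain values,
            # then reshape to (devices, rows per device, ...).
            a = jnp.pad(a, [(0, pad)] + [(0, 0)] * (a.ndim - 1), mode="edge")
            return a.reshape((n_dev, per_dev) + a.shape[1:])

        out = batched(*(split(a) for a in args))

        def merge(o):
            # Fold the device axis back in and drop the padded rows.
            return o.reshape((n_dev * per_dev,) + o.shape[2:])[:n]

        return jax.tree_util.tree_map(merge, out)

    return wrapper


@naive_pvmap
def square(x):
    return x**2


ys = square(jnp.arange(97))
```

Padding with the edge value rather than zeros is a deliberate choice here: the extra rows are discarded anyway, but a function like `1 / x` would otherwise produce spurious infinities on them.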
Installation
```
pip install parajax
```
Example
```python
import multiprocessing

import jax
import jax.numpy as jnp
from parajax import pvmap

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())
# ^ Only needed on CPU: allow JAX to use all CPU cores

@pvmap
def square(x):
    return x**2

xs = jnp.arange(97)
ys = square(xs)
```
That's it! Invocations of square will now be automatically parallelized across all available devices.
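Because 97 is prime, it cannot be split evenly across any device count between 2 and 96. The arithmetic below shows how the work would divide under a pad-to-a-multiple strategy; that strategy is an assumption for illustration, since the description above says only that such shapes are handled automatically, not how. The helper `shard_layout` is hypothetical.

```python
def shard_layout(n, n_devices):
    """Hypothetical helper: (rows per device, padding rows) for n rows."""
    per_device = -(-n // n_devices)  # ceiling division
    padding = per_device * n_devices - n
    return per_device, padding


for n_devices in (1, 2, 4, 8, 16):
    per_device, padding = shard_layout(97, n_devices)
    print(f"{n_devices:>2} devices: {per_device} rows each, {padding} padding rows")
```

For example, on 8 devices each one would process 13 rows (104 in total), 7 of which are padding and would be trimmed from the result.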
Documentation
For more details, check out the documentation.
Download files
Source Distribution
parajax-0.1.3.tar.gz (8.2 kB)
Built Distribution
parajax-0.1.3-py3-none-any.whl (9.2 kB)
File details
Details for the file parajax-0.1.3.tar.gz.
File metadata
- Download URL: parajax-0.1.3.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.8.22
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 13825846f9254720ce5200980da6efa97301dcc585359bc7a46d1e83adeaa4a8 |
| MD5 | f9a20121cc3caeda362e31d54fa31ccd |
| BLAKE2b-256 | c423f49ea6f91f6b8ef8bc7e5b33682f2738151eb7db3b0f451bb7d29fee7403 |
File details
Details for the file parajax-0.1.3-py3-none-any.whl.
File metadata
- Download URL: parajax-0.1.3-py3-none-any.whl
- Upload date:
- Size: 9.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.8.22
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | bf7a547df8a823dab68bfcc2ccc43546a01ccfa3c99b186b6c7d05bd52851775 |
| MD5 | 09d7c605d1d933a79502f014ea51b3bf |
| BLAKE2b-256 | 4ea8270d72f1a97c4ed658481009e7c2d9d7933124709c6c7c08c43d57108453 |