

Einshard

High-level array sharding API for JAX

Installation

This library requires at least Python 3.12.

pip install einshard

Usage

# initialising JAX CPU backend with 16 devices
n_devices = 16
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + f' --xla_force_host_platform_device_count={n_devices}'

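# set the environment variables above before importing JAX
from einshard import einshard
import jax
import jax.numpy as jnp

a = jnp.zeros((4, 8))
# shard `a` according to the einshard expression
a = einshard(a, 'a b -> * a* b2*')
# print how the array is laid out across the devices
jax.debug.visualize_array_sharding(a)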
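For comparison, the sketch below shards the same (4, 8) array with JAX's lower-level jax.sharding API (Mesh, NamedSharding, PartitionSpec) instead of einshard. It is not einshard's implementation and does not reproduce the expression 'a b -> * a* b2*'; the 4 x 4 mesh and the axis names 'x' and 'y' are illustrative assumptions. It shows the mesh and partition-spec bookkeeping that a single einshard expression is meant to replace.

```python
import os

# Simulate 16 CPU devices; set these before JAX is imported.
n_devices = 16
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '')
    + f' --xla_force_host_platform_device_count={n_devices}'
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the devices into a 4 x 4 mesh with named axes 'x' and 'y'
# (this mesh shape is an illustrative choice, not einshard's).
devices = np.array(jax.devices()).reshape(4, 4)
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.zeros((4, 8))

# Split dimension 0 of `a` over mesh axis 'x' and dimension 1 over 'y',
# so each device holds a (4 / 4, 8 / 4) = (1, 2) block.
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))

# Report which device holds which slice of `a`.
for shard in a.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

Every device ends up with a (1, 2) block. With the manual API you have to choose the mesh shape and write the PartitionSpec yourself.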

Development

python3.12 -m venv venv
. venv/bin/activate
pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install -r requirements.txt

Run the tests:

python tests/test_einshard.py

Build the package:

python -m build

Download files


Source Distribution

einshard-0.0.1.tar.gz (9.8 kB)

Built Distribution

einshard-0.0.1-py3-none-any.whl (9.8 kB)
