jaxmore

There's more to JAX.

jaxmore provides functionality that is missing from base JAX. The major features are:

  • vmap — a drop-in replacement for jax.vmap with static-arg/kwarg support and per-kwarg axis control.
  • bounded_while_loop — a reverse-mode-friendly, bounded while_loop implemented via lax.scan.

Installation

pip install jaxmore

Examples

vmap — static arguments and per-kwarg axis mapping

jaxmore.vmap is a drop-in replacement for jax.vmap. By default it behaves identically:

import jax.numpy as jnp
from jaxmore import vmap


def f(x, *, scale):
    return x * scale


vf = vmap(f)
vf(jnp.arange(3.0), scale=jnp.ones(3))  # Array([0., 1., 2.], dtype=float32)
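For comparison, here is the same call with plain jax.vmap. It should produce the same values, because jax.vmap also maps keyword arguments along axis 0 by default. The snippet depends only on JAX, not on jaxmore.

```python
import jax
import jax.numpy as jnp


def f(x, *, scale):
    return x * scale


# Plain jax.vmap maps keyword arguments along axis 0, just like positional ones.
vf_ref = jax.vmap(f)
out = vf_ref(jnp.arange(3.0), scale=jnp.ones(3))
print(out)  # [0. 1. 2.]
```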

Static args & kwargs — bake constants into a closure so they never cross the jax.jit boundary, reducing dispatch overhead:

import jax.numpy as jnp
from jaxmore import vmap


def mul(factor, x, *, offset):
    return factor * x + offset


vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0})
print(vmul(jnp.arange(4.0)))  # Array([ 1.,  4.,  7., 10.], dtype=float32)
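The static-folding idea can be approximated in plain JAX with functools.partial. This is a sketch, not jaxmore's implementation. It assumes static_args are bound as the leading positional arguments and static_kw as fixed keywords, which is consistent with the output shown above.

```python
from functools import partial

import jax
import jax.numpy as jnp


def mul(factor, x, *, offset):
    return factor * x + offset


# Bind the constants first (factor=3.0, offset=1.0), so vmap only sees x as an input.
folded = partial(mul, 3.0, offset=1.0)
vmul_ref = jax.vmap(folded)
out = vmul_ref(jnp.arange(4.0))
print(out)  # [ 1.  4.  7. 10.]
```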

Per-kwarg axis control — map, broadcast, or ignore individual keyword arguments independently (not possible with jax.vmap):

import jax.numpy as jnp
from jaxmore import vmap


def h(x, *, a, b):
    return x * a + b


# 'a' is mapped along axis 0, 'b' is broadcast (not mapped)
vh = vmap(h, in_kw={"a": 0, "b": None})
print(vh(jnp.ones(3), a=jnp.array([1.0, 2.0, 3.0]), b=10.0))
# Array([11., 12., 13.], dtype=float32)
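With plain jax.vmap, in_axes only applies to positional arguments, so a common workaround is a positional wrapper. In the sketch below, h_positional is a hypothetical helper, not part of jaxmore; it shows the kind of boilerplate that per-kwarg axis control is meant to remove.

```python
import jax
import jax.numpy as jnp


def h(x, *, a, b):
    return x * a + b


# Hypothetical wrapper: expose the keywords positionally so in_axes can address them.
def h_positional(x, a, b):
    return h(x, a=a, b=b)


# x and a are mapped along axis 0; b is broadcast (in_axes=None).
vh_ref = jax.vmap(h_positional, in_axes=(0, 0, None))
out = vh_ref(jnp.ones(3), jnp.array([1.0, 2.0, 3.0]), 10.0)
print(out)  # [11. 12. 13.]
```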

Broadcast a kwarg while mapping positional args:

import jax.numpy as jnp
from jaxmore import vmap


def f(x, *, scale):
    return x * scale


vf = vmap(f, in_kw=True, default_kw_axis=None)
print(vf(jnp.arange(3.0), scale=2.0))  # Array([0., 2., 4.], dtype=float32)
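Plain jax.vmap maps every keyword argument along axis 0, so a scalar kwarg such as scale=2.0 has no axis to map over and is rejected. A minimal check, assuming a recent JAX release that reports this as a ValueError, together with one plain-JAX workaround (closing over the scalar):

```python
import jax
import jax.numpy as jnp


def f(x, *, scale):
    return x * scale


# jax.vmap tries to map the 0-d keyword argument along axis 0 and fails.
try:
    jax.vmap(f)(jnp.arange(3.0), scale=2.0)
    raised = False
except ValueError:
    raised = True

print(raised)  # True

# Workaround: close over the scalar so vmap never sees it as a mapped input.
out = jax.vmap(lambda x: f(x, scale=2.0))(jnp.arange(3.0))
print(out)  # [0. 2. 4.]
```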

Optional JIT — JIT-compile the static-folded function before vmapping:

import jax.numpy as jnp
from jaxmore import vmap


def mul(factor, x, *, offset):
    return factor * x + offset


vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0}, jit=True)
print(vmul(jnp.arange(4.0)))  # Array([ 1.,  4.,  7., 10.], dtype=float32)
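A plain-JAX sketch of the ordering described above (jit on the folded single-example function, then vmap), assuming that is essentially what jit=True amounts to. Applying jax.jit to the outer vmapped function is the more common pattern; both orderings give the same values here.

```python
from functools import partial

import jax
import jax.numpy as jnp


def mul(factor, x, *, offset):
    return factor * x + offset


folded = partial(mul, 3.0, offset=1.0)

# Compile the single-example function first, then vectorize it.
inner = jax.vmap(jax.jit(folded))
# Common alternative: vectorize first, then compile the batched function.
outer = jax.jit(jax.vmap(folded))

x = jnp.arange(4.0)
out_inner = inner(x)
out_outer = outer(x)
print(out_inner)  # [ 1.  4.  7. 10.]
```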

bounded_while_loop

Simple loop over a scalar:

import jax.numpy as jnp
from jaxmore import bounded_while_loop


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # Array(5, dtype=int32)
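To make "implemented via lax.scan" concrete, here is a minimal sketch of the general technique, not jaxmore's source: scan for exactly max_steps iterations and freeze the carry once cond_fn becomes false. The name scan_while_loop is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_while_loop(cond_fn, body_fn, init_val, max_steps):
    """Hypothetical sketch: apply body_fn while cond_fn holds, at most max_steps times."""

    def step(carry, _):
        # Check the condition on the current state, as lax.while_loop does.
        keep_going = cond_fn(carry)
        new = body_fn(carry)
        # Keep the updated state only while the condition holds; otherwise freeze it.
        carry = jax.tree_util.tree_map(
            lambda n, o: jnp.where(keep_going, n, o), new, carry
        )
        return carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


result = scan_while_loop(lambda x: x < 5, lambda x: x + 1, jnp.asarray(0), max_steps=10)
print(result)  # 5
```

Because the scan always runs max_steps iterations, body_fn is still evaluated after the condition fails and its result is discarded. This trades extra compute for a fixed trip count that reverse-mode autodiff can handle. One caveat of masking with jnp.where: if body_fn produces inf or NaN on the frozen steps, gradients can pick up NaNs.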

PyTree carry (tuple):

import jax.numpy as jnp
from jaxmore import bounded_while_loop


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))
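The main reason to prefer a scan-based loop is reverse-mode differentiation, which lax.while_loop does not support. The self-contained sketch below repeats the hypothetical scan_while_loop helper (not jaxmore's implementation) and compares jax.grad through it with jax.grad through lax.while_loop, using a tuple carry like the example above.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_while_loop(cond_fn, body_fn, init_val, max_steps):
    # Hypothetical scan-based sketch: freeze the carry once cond_fn is False.
    def step(carry, _):
        keep_going = cond_fn(carry)
        new = body_fn(carry)
        carry = jax.tree_util.tree_map(lambda n, o: jnp.where(keep_going, n, o), new, carry)
        return carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def cond_fn(state):
    i, _ = state
    return i < 3


def body_fn(state):
    i, y = state
    return i + 1, y * 1.5


def run_scan(x0):
    _, y = scan_while_loop(cond_fn, body_fn, (jnp.asarray(0), x0), max_steps=5)
    return y  # equals 1.5**3 * x0


def run_while(x0):
    _, y = lax.while_loop(cond_fn, body_fn, (jnp.asarray(0), x0))
    return y


grad_scan = jax.grad(run_scan)(2.0)  # d(3.375 * x0)/dx0 = 3.375
print(grad_scan)

# Reverse-mode through lax.while_loop is rejected.
try:
    jax.grad(run_while)(2.0)
    while_grad_ok = True
except ValueError:
    while_grad_ok = False
print(while_grad_ok)  # False
```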

Download files


Source Distribution

jaxmore-0.2.0.tar.gz (158.5 kB)

Uploaded Source

Built Distribution


jaxmore-0.2.0-py3-none-any.whl (13.1 kB)

Uploaded Python 3

File details

Details for the file jaxmore-0.2.0.tar.gz.

File metadata

  • Download URL: jaxmore-0.2.0.tar.gz
  • Upload date:
  • Size: 158.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jaxmore-0.2.0.tar.gz
  • SHA256: f8dc5999622fdf7f6241875d1b43ffc9c615b9d809171efe297ed8ad40bddf4a
  • MD5: 0c0ed8e5d9b0ef233a02627f6931fc39
  • BLAKE2b-256: 49fb487e7d0b23827dcd37e28bee8324936d4c5fd974055d1e1615f5fa468f72


Provenance

The following attestation bundles were made for jaxmore-0.2.0.tar.gz:

Publisher: cd.yml on GalacticDynamics/jaxmore

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxmore-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: jaxmore-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 13.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jaxmore-0.2.0-py3-none-any.whl
  • SHA256: b03d43dd98ecdfba637defa4d9597a077af1d91e34aa48a58c3a5159b25f7e9b
  • MD5: bf37481a129826583b187e7cd7bad19d
  • BLAKE2b-256: c87cdc6861979b70f79683a7482a0e5e5cee27d0e657fd8a635a27f7f190da8c


Provenance

The following attestation bundles were made for jaxmore-0.2.0-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/jaxmore

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
