jaxmore
There's more to JAX.
This package provides some useful functionality that is missing in base JAX.
Major features include:

- `vmap`: a drop-in replacement for `jax.vmap` with static-arg/kwarg support and per-kwarg axis control.
- `bounded_while_loop`: a reverse-mode-friendly, bounded `while_loop` implemented via `lax.scan`.
Installation
```bash
pip install jaxmore
```
Examples
vmap — static arguments and per-kwarg axis mapping
jaxmore.vmap is a drop-in replacement for jax.vmap. By default it behaves
identically:
```python
import jax.numpy as jnp
from jaxmore import vmap

def f(x, *, scale):
    return x * scale

vf = vmap(f)
vf(jnp.arange(3.0), scale=jnp.ones(3))  # Array([0., 1., 2.], dtype=float32)
```
Static args & kwargs — bake constants into a closure so they never cross the
jax.jit boundary, reducing dispatch overhead:
```python
import jax.numpy as jnp
from jaxmore import vmap

def mul(factor, x, *, offset):
    return factor * x + offset

# factor=3.0 and offset=1.0 are folded in; only x is mapped.
vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0})
print(vmul(jnp.arange(4.0)))  # [ 1.  4.  7. 10.]
```
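For comparison, a similar effect can be approximated in base JAX by binding the constants with `functools.partial` before calling `jax.vmap`. This is a minimal sketch of the closure idea, not jaxmore's implementation; `mul_folded` and `vmul_base` are illustrative names.

```python
import functools

import jax
import jax.numpy as jnp

def mul(factor, x, *, offset):
    return factor * x + offset

# Bind the constants into a partial so jax.vmap only sees (and maps) `x`.
mul_folded = functools.partial(mul, 3.0, offset=1.0)
vmul_base = jax.vmap(mul_folded)

print(vmul_base(jnp.arange(4.0)))  # [ 1.  4.  7. 10.]
```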
Per-kwarg axis control — map, broadcast, or ignore individual keyword
arguments independently (not possible with jax.vmap):
```python
import jax.numpy as jnp
from jaxmore import vmap

def h(x, *, a, b):
    return x * a + b

# 'a' is mapped along axis 0, 'b' is broadcast (not mapped)
vh = vmap(h, in_kw={"a": 0, "b": None})
print(vh(jnp.ones(3), a=jnp.array([1.0, 2.0, 3.0]), b=10.0))
# [11. 12. 13.]
```
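In base JAX, keyword arguments are always mapped over their leading axis, so one workaround is to route the keywords through positional parameters and control them with `in_axes`. The sketch below assumes that approach; `h_positional` and `vh_base` are hypothetical helper names, not part of jaxmore.

```python
import jax
import jax.numpy as jnp

def h(x, *, a, b):
    return x * a + b

# Expose the keywords positionally so in_axes can address each one.
def h_positional(x, a, b):
    return h(x, a=a, b=b)

# Map x and a along axis 0; broadcast b (in_axes entry None).
vh_base = jax.vmap(h_positional, in_axes=(0, 0, None))

out = vh_base(jnp.ones(3), jnp.array([1.0, 2.0, 3.0]), 10.0)
print(out)  # [11. 12. 13.]
```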
Broadcast a kwarg while mapping positional args:
```python
import jax.numpy as jnp
from jaxmore import vmap

def f(x, *, scale):
    return x * scale

vf = vmap(f, in_kw=True, default_kw_axis=None)
print(vf(jnp.arange(3.0), scale=2.0))  # [0. 2. 4.]
```
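To see what `in_kw=True, default_kw_axis=None` saves you from, here is a sketch of the same call with plain `jax.vmap`: a scalar keyword cannot be mapped over axis 0, so the call fails, and the usual fallback is to broadcast the value by hand. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x, *, scale):
    return x * scale

x = jnp.arange(3.0)

# Plain jax.vmap maps every keyword argument over axis 0, so a rank-0
# keyword cannot be mapped and vmap raises a ValueError.
try:
    jax.vmap(f)(x, scale=2.0)
    kwarg_broadcast_failed = False
except ValueError:
    kwarg_broadcast_failed = True

# Workaround: materialize the broadcast so the keyword has a mapped axis.
out = jax.vmap(f)(x, scale=jnp.full(x.shape, 2.0))
print(out)  # [0. 2. 4.]
```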
Optional JIT — JIT-compile the static-folded function before vmapping:
```python
import jax.numpy as jnp
from jaxmore import vmap

def mul(factor, x, *, offset):
    return factor * x + offset

vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0}, jit=True)
print(vmul(jnp.arange(4.0)))  # [ 1.  4.  7. 10.]
```
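A rough base-JAX analogue of the `jit=True` option, assuming it amounts to jitting the folded function and then vmapping it, is sketched below alongside the more common vmap-then-jit ordering. Both give the same values here; `folded`, `vmul_jit`, and `jit_vmul` are illustrative names.

```python
import functools

import jax
import jax.numpy as jnp

def mul(factor, x, *, offset):
    return factor * x + offset

# Bind the constants first (factor=3.0, offset=1.0), leaving only `x` free.
folded = functools.partial(mul, 3.0, offset=1.0)

# Ordering described for jaxmore: jit the folded function, then vmap it.
vmul_jit = jax.vmap(jax.jit(folded))

# Common base-JAX ordering: vmap first, then jit the batched function.
jit_vmul = jax.jit(jax.vmap(folded))

x = jnp.arange(4.0)
print(vmul_jit(x))  # [ 1.  4.  7. 10.]
print(jit_vmul(x))  # [ 1.  4.  7. 10.]
```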
bounded_while_loop
Simple loop over a scalar:
```python
import jax.numpy as jnp
from jaxmore import bounded_while_loop

def cond_fn(x):
    return x < 5

def body_fn(x):
    return x + 1

result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
```
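The general technique behind a scan-based bounded loop can be sketched as follows: run `lax.scan` for a fixed `max_steps` and, at each step, apply `body_fn` only while `cond_fn` still holds. This is an assumption about the approach rather than jaxmore's actual source; `bounded_while_loop_sketch` is a hypothetical name, and in this sketch the loop simply stops after `max_steps` iterations even if the condition is still true.

```python
import jax.numpy as jnp
from jax import lax

def bounded_while_loop_sketch(cond_fn, body_fn, init_val, max_steps):
    """Hypothetical sketch: a while loop capped at max_steps, built on lax.scan."""

    def step(carry, _):
        # Apply body_fn while the condition holds; otherwise pass carry through.
        new_carry = lax.cond(cond_fn(carry), body_fn, lambda c: c, carry)
        return new_carry, None

    # Fixed trip count: scan always runs max_steps steps.
    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final

result = bounded_while_loop_sketch(
    lambda x: x < 5, lambda x: x + 1, jnp.asarray(0), max_steps=10
)
print(result)  # 5
```

Because the scan length is static, the loop has a known upper bound on its iterations, which is what makes it compatible with reverse-mode differentiation.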
PyTree carry (tuple):
```python
import jax.numpy as jnp
from jaxmore import bounded_while_loop

def cond_fn(state):
    x, _ = state
    return x < 3

def body_fn(state):
    x, y = state
    return x + 1, y * 2

result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))
```
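`lax.while_loop` supports forward-mode but not reverse-mode differentiation, which is why a fixed-length scan is useful. The sketch below reuses the hypothetical scan-based helper (repeated so the block runs on its own) and takes `jax.grad` through a loop with a tuple carry like the one above; `loss` is an illustrative name.

```python
import jax
import jax.numpy as jnp
from jax import lax

def bounded_while_loop_sketch(cond_fn, body_fn, init_val, max_steps):
    # Hypothetical scan-based bounded loop (not jaxmore's source).
    def step(carry, _):
        return lax.cond(cond_fn(carry), body_fn, lambda c: c, carry), None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final

def loss(y0):
    # Carry is (counter, value); double the value while counter < 3.
    def cond_fn(state):
        x, _ = state
        return x < 3

    def body_fn(state):
        x, y = state
        return x + 1, y * 2.0

    _, y = bounded_while_loop_sketch(
        cond_fn, body_fn, (jnp.asarray(0), y0), max_steps=5
    )
    return y

# loss(y0) == 8 * y0, so the reverse-mode gradient is 8.
print(jax.grad(loss)(1.0))  # 8.0
```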
File details
Details for the file jaxmore-0.2.0.tar.gz.
File metadata
- Download URL: jaxmore-0.2.0.tar.gz
- Upload date:
- Size: 158.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f8dc5999622fdf7f6241875d1b43ffc9c615b9d809171efe297ed8ad40bddf4a |
| MD5 | 0c0ed8e5d9b0ef233a02627f6931fc39 |
| BLAKE2b-256 | 49fb487e7d0b23827dcd37e28bee8324936d4c5fd974055d1e1615f5fa468f72 |
Provenance
The following attestation bundles were made for jaxmore-0.2.0.tar.gz:
Publisher: cd.yml on GalacticDynamics/jaxmore
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxmore-0.2.0.tar.gz
- Subject digest: f8dc5999622fdf7f6241875d1b43ffc9c615b9d809171efe297ed8ad40bddf4a
- Sigstore transparency entry: 999969635
- Sigstore integration time:
- Permalink: GalacticDynamics/jaxmore@b778ada5fb822371f6cc9a36e6623e5b3a84525f
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd.yml@b778ada5fb822371f6cc9a36e6623e5b3a84525f
- Trigger Event: release
File details
Details for the file jaxmore-0.2.0-py3-none-any.whl.
File metadata
- Download URL: jaxmore-0.2.0-py3-none-any.whl
- Upload date:
- Size: 13.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b03d43dd98ecdfba637defa4d9597a077af1d91e34aa48a58c3a5159b25f7e9b |
| MD5 | bf37481a129826583b187e7cd7bad19d |
| BLAKE2b-256 | c87cdc6861979b70f79683a7482a0e5e5cee27d0e657fd8a635a27f7f190da8c |
Provenance
The following attestation bundles were made for jaxmore-0.2.0-py3-none-any.whl:
Publisher: cd.yml on GalacticDynamics/jaxmore
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxmore-0.2.0-py3-none-any.whl
- Subject digest: b03d43dd98ecdfba637defa4d9597a077af1d91e34aa48a58c3a5159b25f7e9b
- Sigstore transparency entry: 999969654
- Sigstore integration time:
- Permalink: GalacticDynamics/jaxmore@b778ada5fb822371f6cc9a36e6623e5b3a84525f
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd.yml@b778ada5fb822371f6cc9a36e6623e5b3a84525f
- Trigger Event: release