jaxmore
There's more to JAX.
This package provides useful functionality that is missing from base JAX.
Major features include:
- vmap: a drop-in replacement for jax.vmap with static-arg/kwarg support and per-kwarg axis control.
- bounded_while_loop: a bounded while_loop, implemented via lax.scan, that supports reverse-mode differentiation.
- structured: a decorator that applies per-argument and per-return-value transformations at call time, in a structured, declarative way.
Installation
```shell
pip install jaxmore
```
Examples
vmap — static arguments and per-kwarg axis mapping
jaxmore.vmap is a drop-in replacement for jax.vmap. By default it behaves
identically:
```python
import jax.numpy as jnp
from jaxmore import vmap

def f(x, *, scale):
    return x * scale

vf = vmap(f)
vf(jnp.arange(3.0), scale=jnp.ones(3))  # Array([0., 1., 2.], dtype=float32)
```
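For comparison, plain jax.vmap already accepts keyword arguments, but it always maps every one of them along axis 0. The sketch below uses only the public jax API (nothing from jaxmore) to show that behaviour, including the error a scalar keyword argument triggers:

```python
import jax
import jax.numpy as jnp

def f(x, *, scale):
    return x * scale

# Plain jax.vmap maps keyword arguments along axis 0,
# so `scale` must carry a batch dimension matching `x`.
vf = jax.vmap(f)
result = vf(jnp.arange(3.0), scale=jnp.ones(3))
print(result)  # [0. 1. 2.]

# A scalar keyword argument has no axis 0 to map over, so jax.vmap rejects it.
try:
    vf(jnp.arange(3.0), scale=2.0)
except ValueError as err:
    print("jax.vmap rejected the scalar kwarg:", type(err).__name__)
```

This is the gap the per-kwarg axis options of jaxmore.vmap fill.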
Static args and kwargs: bake constants into a closure so that they never cross the
jax.jit boundary, which reduces dispatch overhead:
```python
import jax.numpy as jnp
from jaxmore import vmap

def mul(factor, x, *, offset):
    return factor * x + offset

vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0})
print(vmul(jnp.arange(4.0)))  # Array([ 1., 4., 7., 10.], dtype=float32)
```
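Conceptually, folding static values resembles binding them with functools.partial before calling jax.vmap, so that only the remaining argument is traced and batched. The following is a rough plain-JAX analogue under that assumption, not jaxmore's implementation:

```python
import functools

import jax
import jax.numpy as jnp

def mul(factor, x, *, offset):
    return factor * x + offset

# Bind the constants first; jax.vmap then only sees (and maps over) `x`.
# `factor` and `offset` stay ordinary Python floats captured by the partial.
folded = functools.partial(mul, 3.0, offset=1.0)
vmul = jax.vmap(folded)

result = vmul(jnp.arange(4.0))
print(result)  # [ 1.  4.  7. 10.]
```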
Per-kwarg axis control: map, broadcast, or ignore individual keyword arguments
independently. This is not possible with jax.vmap, which always maps keyword arguments along axis 0:
```python
import jax.numpy as jnp
from jaxmore import vmap

def h(x, *, a, b):
    return x * a + b

# 'a' is mapped along axis 0, 'b' is broadcast (not mapped)
vh = vmap(h, in_kw={"a": 0, "b": None})
print(vh(jnp.ones(3), a=jnp.array([1.0, 2.0, 3.0]), b=10.0))
# Array([11., 12., 13.], dtype=float32)
```
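With plain jax.vmap, a common workaround is to route keyword arguments through positional parameters, where in_axes can set a separate axis for each one. The wrapper h_positional below is a hypothetical helper written only for this illustration:

```python
import jax
import jax.numpy as jnp

def h(x, *, a, b):
    return x * a + b

# Hypothetical adapter: expose a and b positionally so in_axes can target them.
def h_positional(x, a, b):
    return h(x, a=a, b=b)

# Map x and a along axis 0; broadcast b (in_axes entry None).
vh = jax.vmap(h_positional, in_axes=(0, 0, None))
result = vh(jnp.ones(3), jnp.array([1.0, 2.0, 3.0]), 10.0)
print(result)  # [11. 12. 13.]
```

The adapter has to be rewritten for every function signature, which is the boilerplate that in_kw avoids.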
Broadcast a kwarg while mapping positional args:
```python
import jax.numpy as jnp
from jaxmore import vmap

def f(x, *, scale):
    return x * scale

vf = vmap(f, in_kw=True, default_kw_axis=None)
print(vf(jnp.arange(3.0), scale=2.0))  # Array([0., 2., 4.], dtype=float32)
```
Optional JIT — JIT-compile the static-folded function before vmapping:
```python
import jax.numpy as jnp
from jaxmore import vmap

def mul(factor, x, *, offset):
    return factor * x + offset

vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0}, jit=True)
print(vmul(jnp.arange(4.0)))  # Array([ 1., 4., 7., 10.], dtype=float32)
```
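Assuming jit=True amounts to applying jax.jit to the folded function before batching, a plain-JAX analogue looks like this. The alternative ordering, jax.jit(jax.vmap(...)), compiles the whole batched computation instead; both orderings should produce the same values:

```python
import functools

import jax
import jax.numpy as jnp

def mul(factor, x, *, offset):
    return factor * x + offset

folded = functools.partial(mul, 3.0, offset=1.0)

# Compile the folded per-example function, then batch it.
vmul_inner_jit = jax.vmap(jax.jit(folded))

# Alternative: batch first, then compile the batched computation as a whole.
vmul_outer_jit = jax.jit(jax.vmap(folded))

xs = jnp.arange(4.0)
print(vmul_inner_jit(xs))
print(vmul_outer_jit(xs))
```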
bounded_while_loop
Simple loop over a scalar:
```python
import jax.numpy as jnp
from jaxmore import bounded_while_loop

def cond_fn(x):
    return x < 5

def body_fn(x):
    return x + 1

result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # Array(5, dtype=int32)
```
PyTree carry (tuple):
```python
import jax.numpy as jnp
from jaxmore import bounded_while_loop

def cond_fn(state):
    x, _ = state
    return x < 3

def body_fn(state):
    x, y = state
    return x + 1, y * 2

result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))
```
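lax.while_loop supports forward-mode but not reverse-mode differentiation, which is why a scan-based loop is useful. The sketch below shows one way to build a bounded loop on lax.scan: it always runs max_steps iterations and uses lax.cond to stop updating the carry once cond_fn turns false. The name bounded_while_loop_sketch is hypothetical, and the exact semantics of jaxmore's version (for example, what happens when max_steps is reached) are assumptions here:

```python
import jax
import jax.numpy as jnp
from jax import lax

def bounded_while_loop_sketch(cond_fn, body_fn, init_val, max_steps):
    """Hypothetical stand-in, not jaxmore's code: apply body_fn while cond_fn
    holds, for at most max_steps iterations, via a fixed-length lax.scan."""

    def step(carry, _):
        # Run the body only while the condition holds; afterwards the carry
        # passes through unchanged for the remaining iterations.
        new_carry = lax.cond(cond_fn(carry), body_fn, lambda c: c, carry)
        return new_carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final

# Same loop as the scalar example above.
print(bounded_while_loop_sketch(lambda x: x < 5, lambda x: x + 1, jnp.int32(0), 10))  # 5

# Because the loop is a scan, reverse-mode differentiation works:
# 1.5 -> 3.0 -> 6.0 -> 12.0, so the result is 8 * x0 and its derivative is 8.
def doubled(x0):
    return bounded_while_loop_sketch(lambda x: x < 10.0, lambda x: 2.0 * x, x0, 10)

print(jax.grad(doubled)(1.5))  # 8.0
```

The trade-off of this design is that the loop always pays for max_steps iterations, so max_steps should be a reasonably tight upper bound.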
structured — per-argument and per-return-value transformations
structured is a decorator factory that applies user-supplied callables to
function arguments and return values at call time. It is useful for converting
between raw JAX arrays and richer Python objects (e.g. dataclasses or dicts) at
the boundary of a jax.jit-compiled region.
The examples below use trivial processors (dicts, tuples, etc.) to illustrate
the decorator's mechanics. In practice, structured is most useful for converting
between rich domain objects and flat arrays at a JIT boundary.
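To see what structured automates, here is that boundary pattern written out by hand with plain JAX. The Orbit dataclass and the to_orbit/from_orbit helpers are hypothetical. A plain (unregistered) dataclass is not a valid argument to a jax.jit-compiled function, so this version passes a flat array across the boundary and converts it inside, which is the step that ins and outs processors take over:

```python
import dataclasses

import jax
import jax.numpy as jnp

@dataclasses.dataclass
class Orbit:  # hypothetical domain object; not registered as a JAX pytree
    radius: jax.Array
    speed: jax.Array

def to_orbit(arr):
    # Inbound conversion: flat array of shape (2,) -> rich object.
    return Orbit(radius=arr[0], speed=arr[1])

def from_orbit(orbit):
    # Outbound conversion: rich object -> flat array of shape (2,).
    return jnp.stack([orbit.radius, orbit.speed])

@jax.jit
def double_speed(arr):
    orbit = to_orbit(arr)  # the job an `ins` processor would do
    orbit = dataclasses.replace(orbit, speed=2 * orbit.speed)
    return from_orbit(orbit)  # the job an `outs` processor would do

result = double_speed(jnp.array([1.0, 3.0]))
print(result)  # [1. 6.]
```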
Bare callable shorthand — process the first positional argument. ins=f is
sugar for ins=((f,),):
```python
from jaxmore import structured

@structured(ins=lambda x: {"value": x})
def increment(obj):
    return obj["value"] + 1

print(increment(3))  # 4
```
Multiple positional processors — one callable per positional param, matched
left-to-right. None skips the corresponding argument:
```python
from jaxmore import structured

to_point = lambda xy: {"x": xy[0], "y": xy[1]}
to_vec = lambda xy: {"dx": xy[0], "dy": xy[1]}

@structured(ins=((to_point, to_vec),))
def translate(pt, v):
    return {"x": pt["x"] + v["dx"], "y": pt["y"] + v["dy"]}

print(translate((1, 2), (10, 20)))  # {'x': 11, 'y': 22}
```
VAR_POSITIONAL (*args): a single processor, given in the second ins slot, is
applied element-wise to every value passed via *args:
```python
from jaxmore import structured

@structured(ins=((), lambda v: {"val": v}))
def collect(*args):
    return tuple(a["val"] for a in args)

print(collect(1, 2, 4))  # (1, 2, 4)
```
Keyword-only parameters — matched by name via the third ins slot:
```python
from jaxmore import structured

@structured(ins=((), None, {"cfg": lambda d: {**d, "ready": True}}))
def init(x, *, cfg):
    return cfg["ready"], x

print(init(5, cfg={"name": "test"}))  # (True, 5)
```
VAR_KEYWORD (**kwargs): a single processor, given in the fourth ins slot, is applied to every value in **kwargs:
```python
from jaxmore import structured

@structured(ins=((), None, {}, lambda v: {"val": v}))
def wrap_kw(**kwargs):
    return {k: obj["val"] for k, obj in kwargs.items()}

print(wrap_kw(a=1, b=4))  # {'a': 1, 'b': 4}
```
Output processing — outs=f applies f to the whole return value. A tuple
applies each processor element-wise; None entries pass through:
```python
from jaxmore import structured

@structured(outs=lambda d: d["result"])
def compute(x):
    return {"result": x + 1, "debug": "ok"}

print(compute(4))  # 5

@structured(outs=(lambda d: d["val"], None, lambda d: d["val"]))
def multi_out():
    return ({"val": 10}, 2, {"val": 103})

print(multi_out())  # (10, 2, 103)
```
Combined with JAX / JIT — processors run inside the JIT boundary when
@jax.jit is applied outside @structured. Default parameter values are
filled before processors run:
```python
import jax
import jax.numpy as jnp
from jaxmore import structured

@jax.jit
@structured(
    ins=(lambda x: {"val": x},),
    outs=lambda d: d["val"],
)
def jit_func(obj):
    return {"val": obj["val"] + jnp.asarray(1)}

print(int(jit_func(jnp.asarray(4))))  # 5
```
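The ins layout used in these examples is a four-slot tuple: positional processors, a *args processor, a keyword-only mapping, and a **kwargs processor, with a bare callable standing for ((callable,),). A short pure-Python sketch built on inspect.signature can mirror those semantics. The name structured_sketch is hypothetical; it is not jaxmore's source, it only handles the bare-callable and nested-tuple forms, and it skips validation:

```python
import functools
import inspect

def structured_sketch(ins=(), outs=None):
    """Hypothetical, simplified stand-in for jaxmore.structured."""
    if callable(ins):
        ins = ((ins,),)
    # Pad missing trailing slots with "no processing" defaults.
    pos_procs, var_pos, kw_procs, var_kw = tuple(ins) + ((), None, {}, None)[len(ins):]

    def decorator(fn):
        sig = inspect.signature(fn)
        positional = [p.name for p in sig.parameters.values()
                      if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()  # defaults are filled in before processors run
            arguments = bound.arguments
            # Positional processors are matched left to right; None skips.
            for name, proc in zip(positional, pos_procs):
                if proc is not None:
                    arguments[name] = proc(arguments[name])
            for name, param in sig.parameters.items():
                if param.kind is param.VAR_POSITIONAL and var_pos is not None:
                    arguments[name] = tuple(var_pos(v) for v in arguments[name])
                elif param.kind is param.KEYWORD_ONLY and name in kw_procs:
                    arguments[name] = kw_procs[name](arguments[name])
                elif param.kind is param.VAR_KEYWORD and var_kw is not None:
                    arguments[name] = {k: var_kw(v) for k, v in arguments[name].items()}
            result = fn(*bound.args, **bound.kwargs)
            if outs is None:
                return result
            if isinstance(outs, tuple):
                # Element-wise output processing; None entries pass through.
                return tuple(r if p is None else p(r) for p, r in zip(outs, result))
            return outs(result)

        return wrapper

    return decorator

@structured_sketch(ins=((), None, {"cfg": lambda d: {**d, "ready": True}}))
def init(x, *, cfg):
    return cfg["ready"], x

@structured_sketch(ins=((), lambda v: {"val": v}))
def collect(*args):
    return tuple(a["val"] for a in args)

@structured_sketch(outs=(lambda d: d["val"], None, lambda d: d["val"]))
def multi_out():
    return ({"val": 10}, 2, {"val": 103})

print(init(5, cfg={"name": "test"}))  # (True, 5)
print(collect(1, 2, 4))  # (1, 2, 4)
print(multi_out())  # (10, 2, 103)
```

Binding through inspect.signature is what lets defaults be filled in before any processor runs, matching the behaviour described for the JIT example.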
File details
Details for the file jaxmore-0.3.0.tar.gz.
File metadata
- Download URL: jaxmore-0.3.0.tar.gz
- Size: 171.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f8c635a41d1b1917dbf7f554de613c646ed9df51a55a0a4e86a9ff33891ca87f |
| MD5 | 9a8f9116b88d183f72732494f075a449 |
| BLAKE2b-256 | 1fca54bac0488e80f55a43065133dfe2ed19af98c6394163d2e47a9b75b7d377 |
Provenance
The following attestation bundles were made for jaxmore-0.3.0.tar.gz:
- Publisher: cd.yml on GalacticDynamics/jaxmore
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxmore-0.3.0.tar.gz
- Subject digest: f8c635a41d1b1917dbf7f554de613c646ed9df51a55a0a4e86a9ff33891ca87f
- Sigstore transparency entry: 1244359478
- Permalink: GalacticDynamics/jaxmore@30acacfcbf6239d7a1ae3cf87e3f701b3e91c459
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd.yml@30acacfcbf6239d7a1ae3cf87e3f701b3e91c459
- Trigger Event: release
File details
Details for the file jaxmore-0.3.0-py3-none-any.whl.
File metadata
- Download URL: jaxmore-0.3.0-py3-none-any.whl
- Size: 22.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1d5c82ad83d9203f80d97cc9b6da554692fc0d36aee64cd2236bee2ecfad5a6a |
| MD5 | 78e1297f4f6eb292e61e9e5cff35c47c |
| BLAKE2b-256 | 01529db18c38ddc8491e3edab107c266f2fe572ce624d964e01612f9dc993a9c |
Provenance
The following attestation bundles were made for jaxmore-0.3.0-py3-none-any.whl:
- Publisher: cd.yml on GalacticDynamics/jaxmore
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxmore-0.3.0-py3-none-any.whl
- Subject digest: 1d5c82ad83d9203f80d97cc9b6da554692fc0d36aee64cd2236bee2ecfad5a6a
- Sigstore transparency entry: 1244359486
- Permalink: GalacticDynamics/jaxmore@30acacfcbf6239d7a1ae3cf87e3f701b3e91c459
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd.yml@30acacfcbf6239d7a1ae3cf87e3f701b3e91c459
- Trigger Event: release