Annotate Transform
A Python library for annotating and validating shape transformations in JAX arrays.
Overview
annotate_transform provides a decorator that lets you specify the expected input and output shapes of JAX array transformations. This catches shape-related bugs early and makes code more self-documenting. We recommend using it only inside jitted functions, so that the check runs only while JAX traces the function rather than on every call.
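To see why jit matters, here is a minimal sketch in plain JAX. It does not use annotate_transform; `scaled_sum` and `trace_count` are hypothetical names. Python-side code in a jitted function, such as a shape check, runs only when JAX traces the function. Tracing happens on the first call and again for every new input shape or dtype. Cached calls skip the Python body entirely.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter of how often the Python body runs


@jax.jit
def scaled_sum(x):
    global trace_count
    # This line and the shape check below execute only while JAX traces.
    trace_count += 1
    if x.ndim != 2:  # shapes are static during tracing, so this is plain Python
        raise ValueError(f"expected a 2-D array, got shape {x.shape}")
    return 2.0 * jnp.sum(x, axis=1)


scaled_sum(jnp.ones((3, 4)))
scaled_sum(jnp.zeros((3, 4)))  # same shape and dtype: reuses the cached trace
print(trace_count)  # 1
scaled_sum(jnp.ones((5, 4)))  # new shape: JAX retraces, so the check runs again
print(trace_count)  # 2
```

The same reasoning applies to an annotated transform called inside a jitted function: its shape check costs something only when a new shape is traced.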
Installation
From PyPI
```bash
pip install annotate-transform
```
From source
```bash
uv sync
```
Running tests
```bash
uv run --extra test pytest tests/test_annotate_transform.py
```
Usage
```python
def annotate_transform(
    transform: Callable[_inputs_type, _outputs_type], annotation: str
) -> Callable[_inputs_type, _outputs_type]:
    """Annotates and checks transformations of jax.Arrays when the returned
    function is invoked.

    If the annotation does not match the actual transform, raises ValueError.

    Example:
        >>> in_shape = (5, 3, 24, 24)
        >>> a = jnp.ones(in_shape)
        >>> b = annotate_transform(jnp.sum, "(b, c, h, w) -> (b, h, w)")(a, axis=1)
        >>> c = annotate_transform(jnp.sum, "(5, 3, 24, 24) -> (b, h, 24)")(a, axis=1)
        >>> assert (b == c).all()

    Matmul example:
        >>> A = jnp.ones((5, 10))
        >>> B = jnp.ones((10, 5))
        >>> C = annotate_transform(jnp.matmul, "((a, b), (b, c)) -> (a, c)")(A, B)
        >>> # The following works as well, but it is less readable than the above,
        >>> # since for a matmul 'd' must always equal 'b'.
        >>> works = annotate_transform(jnp.matmul, "((a, b), (d, c)) -> (a, c)")(A, B)
        >>> # The following raises ValueError: 'a' is already bound to 5, and we
        >>> # try to reuse it for the dimension of size 10.
        >>> error = annotate_transform(jnp.matmul, "((a, a), (b, c)) -> (a, c)")(A, B)

    Symbolic dim convention:
        >>> A = jnp.ones((5, 10))
        >>> B = jnp.ones((10, 5))
        >>> # Symbolic dims can be multi-character names
        >>> C = annotate_transform(jnp.matmul, "((bananaMan, b), (b, potatoMan)) -> (bananaMan, potatoMan)")(A, B)
        >>> # Snake case works as well
        >>> C = annotate_transform(jnp.matmul, "((b_man, b), (b, p_man)) -> (b_man, p_man)")(A, B)

    Functions with keyword arguments:
        >>> # The order in which positional and keyword arguments are passed determines
        >>> # how the shape transformation is validated. For example, consider this function:
        >>> def fn(a: jax.Array, *, b: jax.Array, _: int, c: jax.Array):
        ...     if a.shape == b.shape:
        ...         return jnp.max(jnp.concat([a, b]))
        ...     return c
        ...
        >>> # This function has 1 positional argument and 3 keyword-only arguments, one of
        >>> # which is an integer rather than a jax.Array.
        >>> # Now consider these invocations:
        >>> x = jnp.array([0])
        >>> y = jnp.array([1])
        >>> z = jnp.array([1, 2])
        >>> # To validate the inputs, this function iterates over all positional and keyword
        >>> # arguments in the order they are passed and collects every jax.Array.
        >>> # Notice that we annotate 3 shapes, even though 4 arguments are passed in total.
        >>> result = annotate_transform(fn, "(1,),(1,),(2,) -> ()")(x, b=y, _=0, c=z)
        >>> # If we change the order of the keyword arguments, the shape annotation
        >>> # must change to reflect the new ordering.
        >>> result = annotate_transform(fn, "(1,),(2,),(1,) -> ()")(x, c=z, _=0, b=y)
        >>> # For readability, we suggest placing all non-array keyword arguments
        >>> # at the end, like so:
        >>> result = annotate_transform(fn, "(1,),(2,),(1,) -> ()")(x, c=z, b=y, _=0)

    Mathematical expressions in shape annotation:
        >>> # Some shape transformations make one dimension a function of others. For
        >>> # example, reshape must preserve the total number of elements (the product
        >>> # of all dims), so mathematical expressions are supported in the annotation.
        >>> result = annotate_transform(jnp.reshape, "(b, h, w) -> b * h * w,")(jnp.ones((5, 3, 24)), -1)
        >>> # One subtlety: every dim used in a mathematical expression must be bound
        >>> # somewhere in the shape annotation; otherwise this errors out.
        >>> # Here is a case that works, because the dims in the expression are bound
        >>> # later on, in the output shape.
        >>> result = annotate_transform(jnp.reshape, "(b * h * w), -> b, h, w")(jnp.ones((5 * 3 * 24,)), (5, 3, 24))
        >>> # However, this errors out, since 'b' is never bound.
        >>> result = annotate_transform(jnp.reshape, "(b * h * w), -> c, h, w")(jnp.ones((5 * 3 * 24,)), (5, 3, 24))
        >>> # Currently the only supported operations are multiplication, division,
        >>> # addition, and subtraction. Note that division uses the symbol '/' rather
        >>> # than Python's '//', but floor division is performed under the hood.

    Wildcard support in shape annotation:
        >>> batch_size = 2
        >>> sequence_length = 3
        >>> pytree = {
        ...     "a": jnp.ones((batch_size, sequence_length, 1, 1)),
        ...     "b": jnp.ones((batch_size, sequence_length, 2, 2, 2)),
        ... }
        >>> # Just like symbolic dimensions, wildcards are bound to the same shape
        >>> # for the duration of the transform annotation check.
        >>> @partial(annotate_transform, annotation="batch, seq, *feat -> *feat,")
        ... def transform(arr: jax.Array) -> jax.Array:
        ...     return jnp.sum(arr, axis=(0, 1))
        >>> transformed_pytree = jax.tree.map(transform, pytree)
        >>> # Note that at most one wildcard is supported per shape, and there
        >>> # cannot be a space after the asterisk. Additionally, a wildcard cannot
        >>> # share its name with a named (non-wildcard) dimension.
        >>> # Here are examples that fail:
        >>> @partial(annotate_transform, annotation="batch, seq, * feat -> * feat,")
        ... def transform_with_space_after_wildcard(arr: jax.Array) -> jax.Array:
        ...     # This errors out since there is a space after the asterisk
        ...     return jnp.sum(arr, axis=(0, 1))
        >>> @partial(annotate_transform, annotation="a, b, *b -> *b,")
        ... def transform_with_wildcard_same_name_as_concrete_dim(arr: jax.Array) -> jax.Array:
        ...     # This errors out since the wildcard '*b' shares its name with
        ...     # the named dimension 'b'
        ...     return jnp.sum(arr, axis=(0, 1))
        >>> @partial(annotate_transform, annotation="((a, *b), (*c,)) -> ((a, *b), (*c,))")
        ... def transform_with_one_wildcard_per_shape(arr: jax.Array, arr2: jax.Array) -> tuple[jax.Array, jax.Array]:
        ...     # This is fine because each shape has only one wildcard
        ...     return arr, arr2
        >>> @partial(annotate_transform, annotation="((a, *b), (*c, *d)) -> ((a, *b), (*c, *d))")
        ... def transform_with_multiple_wildcards_in_same_shape(arr: jax.Array, arr2: jax.Array) -> tuple[jax.Array, jax.Array]:
        ...     # This errors out since the second shape has two wildcards
        ...     return arr, arr2
        >>> # Finally, wildcards cannot be applied to numeric (concrete) dimensions.
        >>> # Here is an example that fails:
        >>> @partial(annotate_transform, annotation="(a, *1) -> (a, *1)")
        ... def transform_with_wildcard_on_concrete_dim(arr: jax.Array) -> jax.Array:
        ...     # This errors out since wildcards cannot be used on concrete dimensions
        ...     return arr

    Note: Shape checking is expensive at runtime, so when using this function,
    make sure to jit the caller; the check then runs only while tracing.
    """
    return partial(_transform_and_check, transform, annotation)
```
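The docstring describes several binding rules, and a toy re-implementation can make them concrete. It is a minimal sketch under stated assumptions, not the library's actual implementation. The names `checked_call`, `check_shapes`, `pair_dims` and `eval_dim` are hypothetical. Annotations are written as Python tuples rather than parsed from strings:
- ints are concrete dims;
- identifiers are symbolic dims, bound on first use;
- `"*name"` is a wildcard;
- any other string is an arithmetic expression.

Array arguments are collected in call order, and expressions are checked only after every symbol has been bound. Some library rules are omitted, such as rejecting a wildcard that shares a name with another dim, or a numeric wildcard.

```python
import ast
import operator

import jax
import jax.numpy as jnp

# Arithmetic allowed in dim expressions; '/' is mapped to floor division.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.floordiv,
}


def eval_dim(expr, bindings):
    """Evaluate a dim expression such as 'b * h * w' using bound symbols."""

    def walk(node):
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.Name):
            if node.id not in bindings:
                raise ValueError(f"dim '{node.id}' is never bound")
            return bindings[node.id]
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        raise ValueError(f"unsupported dim expression: {expr!r}")

    return walk(ast.parse(expr, mode="eval").body)


def pair_dims(spec, shape, bindings):
    """Pair annotated dims with actual dims, expanding at most one '*name' wildcard."""
    stars = [i for i, s in enumerate(spec) if isinstance(s, str) and s.startswith("*")]
    if len(stars) > 1:
        raise ValueError(f"more than one wildcard in {spec}")
    if not stars:
        if len(spec) != len(shape):
            raise ValueError(f"rank mismatch: {spec} vs {shape}")
        return list(zip(spec, shape))
    i = stars[0]
    n_tail = len(spec) - i - 1
    if len(shape) < len(spec) - 1:
        raise ValueError(f"rank mismatch: {spec} vs {shape}")
    middle = tuple(shape[i : len(shape) - n_tail])
    # A wildcard binds to a tuple of dims and must match it on every later use.
    if bindings.setdefault(spec[i], middle) != middle:
        raise ValueError(f"{spec[i]} is {bindings[spec[i]]}, got {middle}")
    head = list(zip(spec[:i], shape[:i]))
    return head + list(zip(spec[i + 1 :], shape[len(shape) - n_tail :]))


def check_shapes(specs, shapes):
    """Bind symbols left to right; check expressions only after all binding."""
    if len(specs) != len(shapes):
        raise ValueError(f"expected {len(specs)} shapes, got {len(shapes)}")
    bindings, deferred = {}, []
    for spec, shape in zip(specs, shapes):
        for dim_spec, dim in pair_dims(spec, shape, bindings):
            if isinstance(dim_spec, int):
                if dim_spec != dim:
                    raise ValueError(f"expected {dim_spec}, got {dim}")
            elif dim_spec.isidentifier():
                # Symbolic dim: bound on first use, compared on every later use.
                if bindings.setdefault(dim_spec, dim) != dim:
                    raise ValueError(f"'{dim_spec}' is {bindings[dim_spec]}, got {dim}")
            else:
                deferred.append((dim_spec, dim))
    for expr, dim in deferred:
        if eval_dim(expr, bindings) != dim:
            raise ValueError(f"'{expr}' does not evaluate to {dim}")
    return bindings


def checked_call(fn, in_specs, out_spec, *args, **kwargs):
    """Call fn, then validate its array arguments (in call order) and its output."""
    arrays = [v for v in (*args, *kwargs.values()) if isinstance(v, jax.Array)]
    out = fn(*args, **kwargs)
    check_shapes([*in_specs, out_spec], [a.shape for a in arrays] + [out.shape])
    return out


A, B = jnp.ones((5, 10)), jnp.ones((10, 5))
C = checked_call(jnp.matmul, [("a", "b"), ("b", "c")], ("a", "c"), A, B)
```

Deferring the expression checks until every spec has been processed is what lets an input written as `b * h * w` pass when `b`, `h` and `w` first appear in the output shape. Checking eagerly would reject it.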
Download files
File details
Details for the file annotate_transform-0.1.3.tar.gz.
File metadata
- Download URL: annotate_transform-0.1.3.tar.gz
- Upload date:
- Size: 10.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.0rc1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 77b8f1065dcf286ba3515188a8d2856582bab15bb84d30628e723f581633bcb5 |
| MD5 | f1a6dd7360018f1a270f04862f04a65d |
| BLAKE2b-256 | e12adac8ecca4d4d7b8d97702bc978a58cc355fde3b8360cbcc4f1ba9decdbb8 |
File details
Details for the file annotate_transform-0.1.3-py3-none-any.whl.
File metadata
- Download URL: annotate_transform-0.1.3-py3-none-any.whl
- Upload date:
- Size: 10.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.0rc1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 87214e8388ec47186e7e6676732d36a89120b3c66a54946ba04979d510173a14 |
| MD5 | c153cfc9c2a8e49542c3f765d9e31171 |
| BLAKE2b-256 | feef25a3dcf1b2652740bac7dd32936856a2f2873cd86ebb3a3a4950634a4fec |
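You can check a downloaded file against the SHA256 digests listed for these files with a minimal sketch that uses only the standard library's `hashlib`. The helpers `sha256_of` and `verify` are hypothetical. The expected value is the wheel's SHA256 digest copied from its table. The path in the usage line assumes the wheel was downloaded into the current directory.

```python
import hashlib

# SHA256 digest published for annotate_transform-0.1.3-py3-none-any.whl.
EXPECTED_WHEEL_SHA256 = "87214e8388ec47186e7e6676732d36a89120b3c66a54946ba04979d510173a14"


def sha256_of(path, chunk_size=1 << 16):
    """Stream a file through SHA-256 and return its hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in fixed-size chunks so large files are not loaded into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """Raise ValueError if the file's SHA256 digest differs from `expected`."""
    actual = sha256_of(path)
    if actual != expected:
        raise ValueError(f"{path}: SHA256 mismatch ({actual} != {expected})")


# Usage, assuming the wheel sits in the current directory:
# verify("annotate_transform-0.1.3-py3-none-any.whl", EXPECTED_WHEEL_SHA256)
```

pip can also enforce digests during installation: list `annotate-transform==0.1.3 --hash=sha256:<digest>` in a requirements file and install it with `--require-hashes`.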