
A Python library for annotating and validating shape transformations in JAX arrays.


Annotate Transform


Overview

annotate_transform provides a decorator that lets you declare the expected input and output shapes of a JAX array transformation. A mismatch raises a ValueError, which catches shape-related bugs early and makes code more self-documenting. We recommend using it only inside jitted functions, so that the check runs only during tracing rather than on every call.
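
Annotations such as "((a, b), (b, c)) -> (a, c)" use symbolic dimensions: each name is bound to a size the first time it is seen and must keep that size everywhere else in the annotation. The following is a minimal pure-Python sketch of that rule, not the library's implementation. The names bind_dims and check_transform are hypothetical, shapes are plain tuples, and annotations are passed pre-parsed (tuples of names and integers) instead of as strings. Mathematical expressions and wildcards are not handled.

```python
from typing import Dict, Sequence, Tuple, Union

# A dimension in a pre-parsed annotation: an integer literal or a symbolic name.
Dim = Union[int, str]


def bind_dims(
    spec: Sequence[Dim], shape: Tuple[int, ...], bindings: Dict[str, int]
) -> None:
    """Match one annotated shape against a concrete shape (hypothetical helper)."""
    if len(spec) != len(shape):
        raise ValueError(f"expected rank {len(spec)}, got shape {shape}")
    for dim, size in zip(spec, shape):
        if isinstance(dim, int):
            # Integer literals must match the actual size exactly.
            if dim != size:
                raise ValueError(f"expected size {dim}, got {size} in {shape}")
        elif dim in bindings:
            # A symbol seen earlier must keep the size it was first bound to.
            if bindings[dim] != size:
                raise ValueError(
                    f"{dim!r} is bound to {bindings[dim]}, but got {size}"
                )
        else:
            # First occurrence: bind the symbol to this size.
            bindings[dim] = size


def check_transform(
    in_specs: Sequence[Sequence[Dim]],
    out_spec: Sequence[Dim],
    in_shapes: Sequence[Tuple[int, ...]],
    out_shape: Tuple[int, ...],
) -> Dict[str, int]:
    """Check inputs, then the output, against one shared table of bindings."""
    if len(in_specs) != len(in_shapes):
        raise ValueError(f"expected {len(in_specs)} inputs, got {len(in_shapes)}")
    bindings: Dict[str, int] = {}
    for spec, shape in zip(in_specs, in_shapes):
        bind_dims(spec, shape, bindings)
    bind_dims(out_spec, out_shape, bindings)
    return bindings


if __name__ == "__main__":
    # "((a, b), (b, c)) -> (a, c)" for a (5, 10) @ (10, 5) matmul.
    print(check_transform([("a", "b"), ("b", "c")], ("a", "c"), [(5, 10), (10, 5)], (5, 5)))
    # prints {'a': 5, 'b': 10, 'c': 5}

    # "((a, a), (b, c)) -> (a, c)" fails: 'a' is bound to 5, then meets 10.
    try:
        check_transform([("a", "a"), ("b", "c")], ("a", "c"), [(5, 10), (10, 5)], (5, 5))
    except ValueError as err:
        print("ValueError:", err)
```

Because inputs and the output share one binding table, a name may also be bound for the first time in the output shape, as in "(5, 3, 24, 24) -> (b, h, 24)".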

Installation

From PyPI

pip install annotate-transform

From source

uv sync

Running tests

uv run --extra test pytest tests/test_annotate_transform.py

Usage

def annotate_transform(
    transform: Callable[_inputs_type, _outputs_type], annotation: str
) -> Callable[_inputs_type, _outputs_type]:
    """Annotates and checks transformations to jax.Arrays when the returned
    function is invoked.
    Raises ValueError if the annotation does not match the actual transformation.

    Example:
        >>> in_shape = (5, 3, 24, 24)
        >>> a = jnp.ones(in_shape)
        >>> b = annotate_transform(jnp.sum, "(b, c, h, w) -> (b, h, w)")(a, axis=1)
        >>> c = annotate_transform(jnp.sum, "(5, 3, 24, 24) -> (b, h, 24)")(a, axis=1)
        >>> assert (b == c).all()

    Matmul example:
        >>> A = jnp.ones((5, 10))
        >>> B = jnp.ones((10, 5))
        >>> C = annotate_transform(jnp.matmul, "((a, b), (b, c)) -> (a, c)")(A, B)

        >>> # The following also works, although it is less readable than the above,
        >>> # since for matmul 'd' must always equal 'b'
        >>> works = annotate_transform(jnp.matmul, "((a, b), (d, c)) -> (a, c)")(A, B)

        >>> # The following raises ValueError, since 'a' is already bound to 5 and we are
        >>> # trying to reuse it in place of 10
        >>> error = annotate_transform(jnp.matmul, "((a, a), (b, c)) -> (a, c)")(A, B)

    Symbolic dim convention:
        >>> A = jnp.ones((5, 10))
        >>> B = jnp.ones((10, 5))
        >>> # Symbolic dims can be multi-character
        >>> C = annotate_transform(jnp.matmul, "((bananaMan, b), (b, potatoMan)) -> (bananaMan, potatoMan)")(A, B)
        >>> # Snake case will work as well
        >>> C = annotate_transform(jnp.matmul, "((b_man, b), (b, p_man)) -> (b_man, p_man)")(A, B)

    Functions with keyword arguments:
        >>> # Input shapes are validated in the order in which positional and keyword arguments
        >>> # are passed at the call site. For example, consider the following function:
        >>> def fn(a: jax.Array, *, b: jax.Array, _: int, c: jax.Array):
        ...   if a.shape == b.shape:
        ...     return jnp.max(jnp.concat([a, b]))
        ...   return c
        ...
        >>> # This function has 1 positional argument and 3 keyword-only arguments, one of which
        >>> # is an integer rather than a jax.Array.
        >>> # Now consider these invocations
        >>> x = jnp.array([0])
        >>> y = jnp.array([1])
        >>> z = jnp.array([1, 2])
        >>> # To validate the inputs, annotate_transform iterates over all positional and keyword
        >>> # arguments and collects only the jax.Array values. That is why the annotation
        >>> # lists 3 shapes even though 4 arguments are passed in total.
        >>> result = annotate_transform(fn, "(1,),(1,),(2,) -> ()")(x, b=y, _=0, c=z)
        >>> # If we change the order of the keyword arguments, we also need to change the shape
        >>> # annotation to reflect that new ordering
        >>> result = annotate_transform(fn, "(1,),(2,),(1,) -> ()")(x, c=z, _=0, b=y)
        >>> # For readability, we suggest passing all non-array keyword arguments last,
        >>> # like so:
        >>> result = annotate_transform(fn, "(1,),(2,),(1,) -> ()")(x, c=z, b=y, _=0)

    Mathematical expressions in shape annotation:
        >>> # Some shape transformations depend on other dimensions. For example, reshape
        >>> # must preserve the total number of elements of the input array, so we support
        >>> # mathematical expressions in the shape annotation.
        >>> result = annotate_transform(jnp.reshape, "(b, h, w) -> b * h * w,")(jnp.ones((5, 3, 24)), -1)
        >>> # One subtlety: every dim used in a mathematical expression must be bound somewhere
        >>> # in the shape annotation; otherwise, the check raises an error. The following works
        >>> # because 'b', 'h', and 'w' are bound later, in the output shape.
        >>> result = annotate_transform(jnp.reshape, "(b * h * w), -> b, h, w")(jnp.ones((5 * 3 * 24,)), (5, 3, 24))
        >>> # However, the following raises an error, because 'b' is never bound.
        >>> result = annotate_transform(jnp.reshape, "(b * h * w), -> c, h, w")(jnp.ones((5 * 3 * 24,)), (5, 3, 24))
        >>> # Currently, the only supported operators are multiplication, division, addition,
        >>> # and subtraction. Division is written as '/' rather than the usual '//', but it is
        >>> # evaluated as floor division.

    Wildcard support in shape annotation:
        >>> batch_size = 2
        >>> sequence_length = 3
        >>> pytree = {
        ...     "a": jnp.ones((batch_size, sequence_length, 1, 1)),
        ...     "b": jnp.ones((batch_size, sequence_length, 2, 2, 2)),
        ... }
        >>> # Just like with symbolic dimensions, wildcards are bound to the same shape
        >>> # for the duration of the transform annotation check.
        >>> @partial(annotate_transform, annotation="batch, seq, *feat -> *feat,")
        ... def transform(arr: jax.Array) -> jax.Array:
        ...     return jnp.sum(arr, axis=(0, 1))
        >>> transformed_pytree = jax.tree.map(transform, pytree)
        >>> # Note that each shape may contain at most one wildcard, and there must be no
        >>> # space after the asterisk. Additionally, a wildcard cannot share its name with
        >>> # a non-wildcard dimension. The following examples illustrate these rules:
        >>> @partial(annotate_transform, annotation="batch, seq, * feat -> * feat,")
        ... def transform_with_space_after_wildcard(arr: jax.Array) -> jax.Array:
        ...     # This will error out since there is a space after the asterisk
        ...     return jnp.sum(arr, axis=(0, 1))
        >>> @partial(annotate_transform, annotation="a, b, *b -> *b,")
        ... def transform_with_wildcard_same_name_as_concrete_dim(arr: jax.Array) -> jax.Array:
        ...     # This will error out since a wildcard variable cannot share its name
        ...     # with a non-wildcard dimension
        ...     return jnp.sum(arr, axis=(0, 1))
        >>> @partial(annotate_transform, annotation="((a, *b), (*c,)) -> ((a, *b), (*c,))")
        ... def transform_with_one_wildcard_per_shape(arr: jax.Array, arr2: jax.Array) -> tuple[jax.Array, jax.Array]:
        ...     # This is fine because we have only one wildcard per shape
        ...     return arr, arr2
        >>> @partial(annotate_transform, annotation="((a, *b), (*c, *d)) -> ((a, *b), (*c, *d))")
        ... def transform_with_multiple_wildcards_in_same_shape(arr: jax.Array, arr2: jax.Array) -> tuple[jax.Array, jax.Array]:
        ...     # This will error out since we have multiple wildcards in the same shape
        ...     return arr, arr2
        >>> # Finally, wildcards cannot be applied to concrete (integer) dimensions. Here is an example that will fail:
        >>> @partial(annotate_transform, annotation="(a, *1) -> (a, *1)")
        ... def transform_with_wildcard_on_concrete_dim(arr: jax.Array) -> jax.Array:
        ...     # This will error out since we cannot use wildcards on concrete dimensions
        ...     return arr

    Note: This check is expensive at runtime, so jit the calling function; the check then runs only during tracing.

    """
    return partial(_transform_and_check, transform, annotation)
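
The keyword-argument rule in the docstring amounts to flattening the call's arguments in the order they were written and keeping only the arrays. The sketch below shows that ordering in plain Python. It is hypothetical: collect_array_shapes is not part of the library, FakeArray stands in for jax.Array so the example runs without JAX, and arrays are detected by the presence of a .shape attribute rather than with isinstance(value, jax.Array).

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass(frozen=True)
class FakeArray:
    """Stand-in for jax.Array; only the .shape attribute is used."""

    shape: Tuple[int, ...]


def collect_array_shapes(*args: Any, **kwargs: Any) -> List[Tuple[int, ...]]:
    """Return the shapes of array arguments in call order (hypothetical helper).

    Positional arguments come first, followed by keyword arguments in the order
    the caller wrote them; Python preserves that order in **kwargs. Values
    without a .shape attribute, such as plain ints, are skipped.
    """
    values = list(args) + list(kwargs.values())
    return [tuple(v.shape) for v in values if hasattr(v, "shape")]


x = FakeArray((1,))
y = FakeArray((1,))
z = FakeArray((2,))

# Mirrors fn(x, b=y, _=0, c=z): three shapes, the int is skipped.
print(collect_array_shapes(x, b=y, _=0, c=z))  # [(1,), (1,), (2,)]
# Reordering the keywords reorders the shapes the annotation must list.
print(collect_array_shapes(x, c=z, _=0, b=y))  # [(1,), (2,), (1,)]
```

Under this model, where the non-array keyword argument appears does not change the collected shapes, which is why moving it to the end is purely a readability choice.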



Download files


Source Distribution

annotate_transform-0.1.3.tar.gz (10.9 kB)

Uploaded Source

Built Distribution


annotate_transform-0.1.3-py3-none-any.whl (10.0 kB)

Uploaded Python 3

File details

Details for the file annotate_transform-0.1.3.tar.gz.

File metadata

  • Download URL: annotate_transform-0.1.3.tar.gz
  • Upload date:
  • Size: 10.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.0rc1

File hashes

Hashes for annotate_transform-0.1.3.tar.gz
Algorithm Hash digest
SHA256 77b8f1065dcf286ba3515188a8d2856582bab15bb84d30628e723f581633bcb5
MD5 f1a6dd7360018f1a270f04862f04a65d
BLAKE2b-256 e12adac8ecca4d4d7b8d97702bc978a58cc355fde3b8360cbcc4f1ba9decdbb8


File details

Details for the file annotate_transform-0.1.3-py3-none-any.whl.


File hashes

Hashes for annotate_transform-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 87214e8388ec47186e7e6676732d36a89120b3c66a54946ba04979d510173a14
MD5 c153cfc9c2a8e49542c3f765d9e31171
BLAKE2b-256 feef25a3dcf1b2652740bac7dd32936856a2f2873cd86ebb3a3a4950634a4fec

