A simple runtime assert library for tensor-based frameworks.

tensor-shape-assert

Runtime tensor shape and dtype checking through type annotations.

tensor-shape-assert validates the shapes (and optionally dtypes) of array-like objects at function call time, based on annotations you already write. Shared dimension variables are automatically inferred and matched across all annotated parameters and return values. A mismatch in the arguments raises a clear error before the function body runs; the return value is checked as soon as the function returns.

Compatible with any array library that exposes a .shape property, including NumPy, PyTorch, JAX, and TensorFlow.

Features

  • Runtime shape validation via ShapedTensor["..."] type annotations
  • Shape variables inferred from — and matched across — multiple parameters and return values
  • Batch dimension support with named and unnamed ellipsis tokens (..., ...B)
  • Dtype annotations (bool, int8, float32, complex128, …)
  • Optional and nested annotations (tuples, lists, NamedTuple)
  • int parameters automatically promoted to shape variables
  • Per-function and global check modes (always, once, never), so checks can be limited to the first call or switched off in production
  • Compatible with static type checkers (MyPy, Pyright) via ShapedLiteral aliases

Installation

pip install git+https://github.com/leifvan/tensor-shape-assert

Quick Start

import numpy as np
from tensor_shape_assert import check_tensor_shapes, ShapedTensor

@check_tensor_shapes()
def matrix_multiply(
        x: ShapedTensor["batch m k"],
        y: ShapedTensor["batch k n"],
) -> ShapedTensor["batch m n"]:
    return x @ y

matrix_multiply(np.zeros((4, 5, 3)), np.zeros((4, 3, 7)))  # passes
matrix_multiply(np.zeros((4, 5, 3)), np.zeros((4, 2, 7)))  # raises TensorShapeAssertError

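The decorator infers batch=4, m=5 and k=3 from x and n=7 from y, checks that y agrees with x on batch and k, and checks the return value against batch, m and n. In the second call y has k=2, so the check fails before x @ y is evaluated.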
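
To make the inference step concrete, here is a minimal pure-Python sketch that binds variable names to sizes from left to right and rejects the first conflicting size. It is not the library's implementation: `bind_shapes` is a hypothetical helper that understands only variable names and integer sizes, and it takes plain shape tuples instead of arrays.

```python
def bind_shapes(descriptors, shapes):
    """Bind each variable name to a size, raising on the first conflict.

    Hypothetical helper: supports only integer tokens and variable names.
    """
    bindings = {}
    for desc, shape in zip(descriptors, shapes):
        tokens = desc.split()
        if len(tokens) != len(shape):
            raise ValueError(f"{desc!r} expects {len(tokens)} dims, got shape {shape}")
        for token, size in zip(tokens, shape):
            if token.isdigit():
                # Exact size: compare directly, nothing to bind.
                if int(token) != size:
                    raise ValueError(f"expected {token}, got {size} in {shape}")
            elif token in bindings and bindings[token] != size:
                raise ValueError(f"{token}={bindings[token]} already bound, got {size}")
            else:
                bindings[token] = size
    return bindings

# x, y and the return value from the quick start, as shape tuples.
print(bind_shapes(["batch m k", "batch k n", "batch m n"],
                  [(4, 5, 3), (4, 3, 7), (4, 5, 7)]))
# prints {'batch': 4, 'm': 5, 'k': 3, 'n': 7}
```

Binding in argument order means the first parameter that mentions a name fixes its value, and every later use is a pure comparison.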

Shape Descriptor Syntax

A shape descriptor is a whitespace-separated string (most punctuation is also treated as whitespace). Each token describes one dimension:

| Token | Meaning |
|---|---|
| `5` | Exact size 5 |
| `*` | Wildcard: any size |
| `n` | Variable: resolved and matched across all annotations that use the same name |
| `...` | Zero or more batch dimensions (may appear at most once) |
| `...B` | Named batch dimensions: must match across annotations sharing the same name `B` |
| `""` / `ScalarTensor` | Scalar (0-dimensional) tensor |

Dtype tokens can appear anywhere in the descriptor alongside dimension tokens (see Dtype Annotations).
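
The token rules above can be sketched as a small classifier. It assumes that commas, parentheses and square brackets are among the punctuation treated as whitespace (the exact set is not listed here) and that `...` and `...B` together may appear at most once per descriptor; `classify_tokens` is a hypothetical name, not part of tensor-shape-assert.

```python
import re

# The reserved dtype tokens listed under Dtype Annotations.
DTYPE_TOKENS = (
    {"bool", "int", "uint", "integral", "float", "complex", "numeric"}
    | {f"int{b}" for b in (8, 16, 32, 64)}
    | {f"uint{b}" for b in (8, 16, 32, 64)}
    | {f"float{b}" for b in (16, 32, 64)}
    | {f"complex{b}" for b in (64, 128)}
)

def classify_tokens(descriptor):
    # Assumption: commas, parentheses and brackets behave like whitespace.
    tokens = re.sub(r"[,()\[\]]", " ", descriptor).split()
    kinds = []
    for tok in tokens:
        if tok.isdigit():
            kinds.append(("exact", int(tok)))
        elif tok == "*":
            kinds.append(("wildcard", None))
        elif tok == "...":
            kinds.append(("batch", None))
        elif tok.startswith("...") and tok[3:].isidentifier():
            kinds.append(("named_batch", tok[3:]))
        elif tok in DTYPE_TOKENS:
            kinds.append(("dtype", tok))
        elif tok.isidentifier():
            kinds.append(("variable", tok))
        else:
            raise ValueError(f"unrecognized token {tok!r}")
    # Assumption: "..." and "...B" count together toward the one-per-descriptor limit.
    if sum(kind in ("batch", "named_batch") for kind, _ in kinds) > 1:
        raise ValueError("at most one batch token is allowed")
    return kinds

print(classify_tokens("float ...B (n, 3) *"))
# [('dtype', 'float'), ('named_batch', 'B'), ('variable', 'n'), ('exact', 3), ('wildcard', None)]
```

An empty descriptor yields an empty token list, which corresponds to the scalar case in the table.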

Core Concepts

Variables

When two parameters share a variable name, their sizes along that dimension must agree:

@check_tensor_shapes()
def add(x: ShapedTensor["n k"], y: ShapedTensor["n k"]) -> ShapedTensor["n k"]:
    return x + y

Variable names can be any identifier not reserved by other rules (integers, *, ..., dtype tokens).

Integers as Shape Variables

int parameters are automatically promoted to shape variables, enabling dynamic shape constraints:

@check_tensor_shapes()
def take_k(x: ShapedTensor["n k"], k: int) -> ShapedTensor["n k"]:
    return x[:, :k]

take_k(np.zeros((10, 4)), k=4)  # passes — k=4 matches x.shape[1]
take_k(np.zeros((10, 4)), k=3)  # raises TensorShapeAssertError

Disable this behaviour with @check_tensor_shapes(ints_to_variables=False).
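
A rough way to picture the promotion: bind every int argument under its parameter name before the tensor arguments are matched, so a conflicting size is reported as a mismatch. The sketch uses `inspect.signature(...).bind`; `seed_int_bindings` and `check_dims` are hypothetical helpers that work on shape tuples, and `take_k` here only mirrors the signature of the example above.

```python
import inspect

def seed_int_bindings(func, *args, **kwargs):
    """Return the variable bindings implied by the int arguments of func."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return {
        name: value
        for name, value in bound.arguments.items()
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(value, int) and not isinstance(value, bool)
    }

def check_dims(descriptor, shape, bindings):
    """Match variable-only descriptors; False on rank or size conflict."""
    tokens = descriptor.split()
    if len(tokens) != len(shape):
        return False
    for token, size in zip(tokens, shape):
        # setdefault binds new names and returns the existing value otherwise.
        if bindings.setdefault(token, size) != size:
            return False
    return True

def take_k(x_shape, k):  # hypothetical stand-in for take_k(x, k)
    ...

seeds = seed_int_bindings(take_k, (10, 4), k=4)
print(seeds, check_dims("n k", (10, 4), dict(seeds)))  # {'k': 4} True
print(check_dims("n k", (10, 4), seed_int_bindings(take_k, (10, 4), k=3)))  # False
```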

Batch Dimensions

Use ... for an arbitrary number of leading dimensions:

@check_tensor_shapes()
def normalize(x: ShapedTensor["... d"]) -> ShapedTensor["... d"]:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

Use a named batch dimension (...B) to enforce that multiple parameters share the same batch shape:

@check_tensor_shapes()
def bilinear(x: ShapedTensor["...B m k"], y: ShapedTensor["...B k n"]) -> ShapedTensor["...B m n"]:
    return x @ y
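
Matching a descriptor that contains `...` or `...B` comes down to letting the batch token absorb whatever dimensions the fixed tokens leave over; a named batch token then binds the whole tuple of absorbed sizes. A minimal sketch under those assumptions (`match_batch` is hypothetical, shapes are plain tuples, and dtype tokens are not handled):

```python
def match_batch(descriptor, shape, bindings):
    """Match a shape tuple against descriptor, updating bindings in place."""
    tokens = descriptor.split()
    batch_pos = next((i for i, t in enumerate(tokens) if t.startswith("...")), None)
    if batch_pos is None:
        if len(tokens) != len(shape):
            raise ValueError(f"rank mismatch: {descriptor!r} vs {shape}")
        pairs = list(zip(tokens, shape))
    else:
        if len(shape) < len(tokens) - 1:
            raise ValueError(f"{shape} has too few dims for {descriptor!r}")
        # The batch token absorbs everything the fixed tokens do not claim.
        batch_end = len(shape) - (len(tokens) - batch_pos - 1)
        pairs = list(zip(tokens[:batch_pos], shape[:batch_pos]))
        pairs += zip(tokens[batch_pos + 1:], shape[batch_end:])
        name = tokens[batch_pos][3:]
        if name:
            # Named batch: the whole tuple of batch sizes must agree.
            pairs.append(("..." + name, tuple(shape[batch_pos:batch_end])))
    for token, size in pairs:
        if token == "*":
            continue
        if token.isdigit():
            if int(token) != size:
                raise ValueError(f"expected {token}, got {size}")
        elif bindings.setdefault(token, size) != size:
            raise ValueError(f"{token} is {bindings[token]}, got {size}")
    return bindings

b = {}
match_batch("...B m k", (4, 2, 5, 3), b)  # x in bilinear
match_batch("...B k n", (4, 2, 3, 7), b)  # y in bilinear
print(b)  # {'m': 5, 'k': 3, '...B': (4, 2), 'n': 7}
```

An unnamed `...` binds nothing, so `match_batch("... d", (7,), {})` accepts zero batch dimensions and returns `{'d': 7}`.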

Dtype Annotations

Add a dtype kind — and optionally a bit width — anywhere in the descriptor:

@check_tensor_shapes()
def safe_mean(x: ShapedTensor["float n k"]) -> ShapedTensor["float n"]:
    return x.mean(axis=-1)

Supported dtype tokens:

| Token | Accepted dtypes |
|---|---|
| bool | boolean |
| int, int8, int16, int32, int64 | signed integer |
| uint, uint8, uint16, uint32, uint64 | unsigned integer |
| integral | any integer (signed or unsigned) |
| float, float16, float32, float64 | real floating-point |
| complex, complex64, complex128 | complex floating-point |
| numeric | any non-boolean numeric dtype |

These tokens are reserved and cannot be used as variable names.
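
The kind/width split in the table lines up with NumPy's one-letter `dtype.kind` codes ('b' boolean, 'i' signed integer, 'u' unsigned integer, 'f' floating point, 'c' complex); for a NumPy array the width in bits is `arr.dtype.itemsize * 8`. The sketch below checks a (kind, bits) pair against a token. `dtype_matches` is hypothetical, and representing dtypes as such pairs is an assumption made to keep the example dependency-free.

```python
import re

# NumPy-style dtype.kind codes accepted by each token family.
KIND_GROUPS = {
    "bool": {"b"},
    "int": {"i"},
    "uint": {"u"},
    "integral": {"i", "u"},
    "float": {"f"},
    "complex": {"c"},
    "numeric": {"i", "u", "f", "c"},
}
_TOKEN = re.compile(r"(bool|int|uint|integral|float|complex|numeric)(\d*)")

def dtype_matches(token, kind, bits):
    """True if a dtype with the given kind code and bit width satisfies token."""
    m = _TOKEN.fullmatch(token)
    if m is None:
        raise ValueError(f"not a dtype token: {token!r}")
    family, width = m.groups()
    if width and family in ("bool", "integral", "numeric"):
        raise ValueError(f"{family!r} takes no bit width")
    if kind not in KIND_GROUPS[family]:
        return False
    # A bare family name accepts any width of that kind.
    return width == "" or int(width) == bits

print(dtype_matches("float", "f", 32))      # True
print(dtype_matches("float64", "f", 32))    # False
print(dtype_matches("integral", "u", 8))    # True
print(dtype_matches("numeric", "b", 8))     # False
```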

Optional and Nested Annotations

Annotations can be arbitrarily nested in tuples or lists. Mark an optional tensor with | None:

@check_tensor_shapes()
def process(
        x: tuple[ShapedTensor["n k"], ShapedTensor["n"] | None],
        y: ShapedTensor["n 3"],
) -> ShapedTensor["n"]:
    a, b = x
    result = y.sum(axis=1)
    if b is not None:
        result = result + b
    return result

NamedTuple classes are also supported — apply the decorator to the class itself.
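
Nested and optional annotations can be handled by walking the annotation and the argument in parallel and collecting (descriptor, shape) pairs for the ordinary matching step. In this sketch a descriptor string stands for a ShapedTensor annotation, a Python tuple or list of specs stands for `tuple[...]` or `list[...]`, and the hypothetical `MaybeNone` wrapper stands for `| None`; `collect` is likewise hypothetical, and the values are shape tuples.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MaybeNone:
    """Hypothetical stand-in for an `X | None` annotation."""
    inner: object

def collect(spec, value, out):
    """Flatten (descriptor, shape) pairs from a nested spec/value pair into out."""
    if isinstance(spec, MaybeNone):
        # A None value is accepted and contributes nothing to check.
        if value is not None:
            collect(spec.inner, value, out)
    elif isinstance(spec, str):
        out.append((spec, value))
    elif isinstance(spec, (tuple, list)):
        if len(spec) != len(value):
            raise ValueError(f"expected {len(spec)} items, got {len(value)}")
        for sub_spec, sub_value in zip(spec, value):
            collect(sub_spec, sub_value, out)
    else:
        raise TypeError(f"unsupported spec {spec!r}")
    return out

# Parameters of `process`: x is (n k, optional n), y is (n 3).
spec = (("n k", MaybeNone("n")), "n 3")
print(collect(spec, (((10, 4), (10,)), (10, 3)), []))
# [('n k', (10, 4)), ('n', (10,)), ('n 3', (10, 3))]
print(collect(spec, (((10, 4), None), (10, 3)), []))
# [('n k', (10, 4)), ('n 3', (10, 3))]
```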

API Reference

check_tensor_shapes

@check_tensor_shapes(
    constraints=None,
    ints_to_variables=True,
    check_mode=None,
    include_outer_variables=None,
    disable_union_warning=False,
)

Decorator that enables shape checking for a function, class __init__, or NamedTuple class.

| Parameter | Type | Default | Description |
|---|---|---|---|
| constraints | list[str \| Callable] | None | Extra constraints on shape variables. String expressions are evaluated (e.g. "a == 2 * b"); callables receive the variable dict and must return bool. Checked before and after the wrapped call. |
| ints_to_variables | bool | True | Promote int parameters to shape variables. |
| check_mode | "always" \| "once" \| "never" \| None | None | Per-function check mode; overrides the global setting. |
| include_outer_variables | bool \| None | None | Inherit shape variables from an enclosing check_tensor_shapes scope. Defaults to False for functions and True for NamedTuple instances. |
| disable_union_warning | bool | False | Suppress the warning about partially unsupported union types. |
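
To illustrate the two forms accepted by `constraints`, here is a sketch that evaluates string expressions with the inferred variables as names and calls callables with the variable dict. Using `eval` with builtins removed is an assumption for illustration only (it is not a sandbox), and `check_constraints` is hypothetical.

```python
def check_constraints(constraints, variables):
    """Return the constraints that fail for the given variable values."""
    failed = []
    for c in constraints:
        if callable(c):
            ok = c(variables)
        else:
            # Variables become names in the expression; builtins are disabled.
            ok = eval(c, {"__builtins__": {}}, dict(variables))
        if not ok:
            failed.append(c)
    return failed

variables = {"a": 6, "b": 3, "n": 10}
print(check_constraints(["a == 2 * b", lambda v: v["n"] % 2 == 0], variables))  # []
print(check_constraints(["a == 3 * b"], variables))  # ['a == 3 * b']
```

Returning the failing constraints rather than a single bool makes it easy to name the offending expression in an error message.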

set_global_check_mode

set_global_check_mode(mode: Literal["always", "once", "never"])

Set the global check mode for all @check_tensor_shapes-decorated functions. Per-function check_mode takes precedence when specified.

| Mode | Behaviour |
|---|---|
| "always" | Check every call (default) |
| "once" | Check each decorated function only on its first call |
| "never" | Disable all shape checks globally |
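
The precedence rule (a per-function check_mode wins over the global mode) can be sketched as a decorator that resolves the effective mode on every call. `checked`, `set_mode` and the `check` callback are hypothetical stand-ins for the library's machinery, not its API.

```python
import functools

_GLOBAL_MODE = "always"

def set_mode(mode):
    global _GLOBAL_MODE
    if mode not in ("always", "once", "never"):
        raise ValueError(f"unknown mode {mode!r}")
    _GLOBAL_MODE = mode

def checked(check, check_mode=None):
    """Run check(*args, **kwargs) before func according to the effective mode."""
    def decorator(func):
        state = {"checked_once": False}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A per-function mode takes precedence over the global one.
            mode = check_mode or _GLOBAL_MODE
            if mode == "always" or (mode == "once" and not state["checked_once"]):
                check(*args, **kwargs)
                state["checked_once"] = True
            return func(*args, **kwargs)
        return wrapper
    return decorator

calls = []

@checked(lambda x: calls.append(x), check_mode="once")
def double(x):
    return 2 * x

double(1)
double(2)
print(calls)  # [1]
```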

get_shape_variables

get_shape_variables(names: str) -> tuple[int | tuple[int, ...] | None, ...]

Return the currently inferred values of the given shape variables. Must be called from inside a @check_tensor_shapes-wrapped function.

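import numpy as np
from tensor_shape_assert import check_tensor_shapes, get_shape_variables, ShapedTensor

@check_tensor_shapes()
def my_func(x: ShapedTensor["n k 3"]):
    n, k = get_shape_variables("n k")
    print(f"n={n}, k={k}")

my_func(np.zeros((10, 9, 3)))  # prints "n=10, k=9"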

assert_shape_here

assert_shape_here(obj_or_shape, descriptor: str) -> None

Validate a tensor or shape tuple against a descriptor from inside a @check_tensor_shapes-wrapped function. Any new variables in the descriptor are registered for subsequent checks, including the function's return annotation.

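import numpy as np
from tensor_shape_assert import assert_shape_here, check_tensor_shapes, ShapedTensor

@check_tensor_shapes()
def my_func(x: ShapedTensor["n k"]) -> ShapedTensor["n m"]:
    y = x @ np.ones((x.shape[1], 2))  # example operation producing shape (n, 2)
    assert_shape_here(y, "n m")  # registers m=2; the return annotation reuses it
    return y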
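
get_shape_variables and assert_shape_here both have to reach the bindings of the call that is currently being checked. One plausible mechanism, sketched below without any claim that the library works this way, is a `contextvars.ContextVar` holding a per-call dict: the wrapper fills it from the argument, the body reads or extends it, and the return check sees any names the body added. `shape_checked`, `current_vars` and `assert_here` are hypothetical, and only single-argument functions with variable-only descriptors are supported.

```python
import contextvars
import functools

# Bindings of the checked call that is currently running (hypothetical).
_scope = contextvars.ContextVar("shape_scope")

def _match(descriptor, shape, bindings):
    tokens = descriptor.split()
    if len(tokens) != len(shape):
        raise ValueError(f"rank mismatch: {descriptor!r} vs {shape}")
    for token, size in zip(tokens, shape):
        if bindings.setdefault(token, size) != size:
            raise ValueError(f"{token} is {bindings[token]}, got {size}")

def shape_checked(param_desc, return_desc):
    """Check a single-argument function's input and output shapes."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(shape):
            bindings = {}
            _match(param_desc, shape, bindings)
            ctx_token = _scope.set(bindings)
            try:
                result = func(shape)
                # Names the body registered are visible to the return check.
                _match(return_desc, result, bindings)
                return result
            finally:
                _scope.reset(ctx_token)
        return wrapper
    return decorator

def current_vars(names):
    # Raises LookupError when called outside a checked function.
    bindings = _scope.get()
    return tuple(bindings.get(name) for name in names.split())

def assert_here(shape, descriptor):
    _match(descriptor, shape, _scope.get())

@shape_checked("n k", "n m")
def project(x_shape):
    n, _ = current_vars("n k")
    y_shape = (n, 2)             # stand-in for the result of some operation
    assert_here(y_shape, "n m")  # binds m=2 before the return check runs
    return y_shape

print(project((10, 4)))  # (10, 2)
```

A context variable (rather than a plain global) keeps nested checked calls and concurrent tasks from seeing each other's bindings.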

label_tensor

label_tensor(tensor, label: str | Iterable[str], overwrite: bool = False) -> tensor

Attach one or more labels to a tensor. Labels registered with register_label can appear in shape descriptors and are matched against the tensor's labels at call time.

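from tensor_shape_assert import check_tensor_shapes, label_tensor, register_label, ShapedTensor

register_label("encoder_output")

z = label_tensor(encoder(x), "encoder_output")  # encoder and x are your own model and input

@check_tensor_shapes()
def decode(z: ShapedTensor["encoder_output n d"]) -> ShapedTensor["n vocab"]:
    ...  # decoder body; must return an (n, vocab) tensor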

register_label

register_label(label: str, constraint_fn: Callable[[array], bool] | None = None)

Register a custom label token for use in shape descriptors.

  • If constraint_fn is None, the label is unconstrained: tensors must be explicitly tagged with label_tensor before being passed to a checked function.
  • If constraint_fn is provided, the label behaves like a dtype annotation: any tensor whose descriptor contains this label is automatically checked by calling constraint_fn(tensor). Constrained labels cannot be assigned via label_tensor.
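
The difference between the two kinds of label can be sketched with a registry that maps each label to an optional constraint function plus a side table of explicit tags. `LabelRegistry` and its methods are hypothetical, lists stand in for tensors, and tags are keyed by `id()` purely for brevity (a real implementation has to account for object lifetimes).

```python
class LabelRegistry:
    def __init__(self):
        self._constraints = {}  # label -> constraint function or None
        self._tags = {}         # id(tensor) -> set of explicit labels

    def register(self, label, constraint_fn=None):
        self._constraints[label] = constraint_fn

    def tag(self, tensor, label):
        if label not in self._constraints:
            raise KeyError(f"unknown label {label!r}")
        if self._constraints[label] is not None:
            raise ValueError(f"{label!r} is constrained and cannot be assigned")
        self._tags.setdefault(id(tensor), set()).add(label)
        return tensor

    def satisfies(self, tensor, label):
        fn = self._constraints[label]
        if fn is None:
            # Unconstrained: the tensor must have been tagged explicitly.
            return label in self._tags.get(id(tensor), set())
        # Constrained: evaluated on the tensor itself, like a dtype check.
        return bool(fn(tensor))

labels = LabelRegistry()
labels.register("encoder_output")
labels.register("nonempty", lambda t: len(t) > 0)

z = labels.tag([0.5, 1.5], "encoder_output")
print(labels.satisfies(z, "encoder_output"), labels.satisfies([], "nonempty"))  # True False
```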

Trace Utilities

Use the trace API to inspect how shape variables are inferred — useful for debugging:

from tensor_shape_assert import start_trace_recording, stop_trace_recording, trace_records_to_string

start_trace_recording()
my_func(np.zeros((10, 9, 3)))
records = stop_trace_recording()
print(trace_records_to_string(records))

| Function | Description |
|---|---|
| start_trace_recording() | Begin capturing per-parameter variable assignments |
| stop_trace_recording() | Stop capturing and return the list of TraceRecord objects |
| trace_records_to_string(records) | Format records as an indented, human-readable string |
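
Conceptually, trace recording is a global switch plus a list that the checker appends to while the switch is on. The sketch below stores (function, parameter, bindings) tuples; `start_recording`, `stop_recording`, `record` and `records_to_string` are hypothetical, and their record layout differs from the library's TraceRecord objects.

```python
_records = None  # None means recording is off

def start_recording():
    global _records
    _records = []

def stop_recording():
    global _records
    records, _records = _records or [], None
    return records

def record(func_name, param, bindings):
    # Called by a (hypothetical) checker after each parameter is matched.
    if _records is not None:
        _records.append((func_name, param, dict(bindings)))

def records_to_string(records):
    lines = []
    for func_name, param, bindings in records:
        assigned = ", ".join(f"{k}={v}" for k, v in bindings.items())
        lines.append(f"{func_name}\n  {param}: {assigned}")
    return "\n".join(lines)

start_recording()
record("my_func", "x", {"n": 10, "k": 9})
print(records_to_string(stop_recording()))
# my_func
#   x: n=10, k=9
```

Copying the bindings in `record` keeps each entry a snapshot, so later inference steps do not rewrite earlier records.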

Type-Safe Literal Syntax

For full static type-checker (MyPy, Pyright) compatibility, use ShapedLiteral and the pre-built framework aliases:

import torch
from typing import Literal as L
from tensor_shape_assert import check_tensor_shapes, ShapedTorchLiteral, ShapedLiteral

@check_tensor_shapes()
def my_func(
        x: ShapedTorchLiteral[L["n k"]],
        y: ShapedTorchLiteral[L["k m"]],
) -> ShapedLiteral[torch.Tensor, L["n m"]]:
    return x @ y

| Alias | Array type |
|---|---|
| ShapedTorchLiteral[L["..."]] | torch.Tensor |
| ShapedNumpyLiteral[L["..."]] | numpy.ndarray |
| ShapedLiteral[T, L["..."]] | Any type T |

Extended Example

Tuple inputs, optional parameters, and batch dimensions together:

import torch
from tensor_shape_assert import check_tensor_shapes, ShapedTensor

@check_tensor_shapes()
def attention(
        query: ShapedTensor["...B heads seq_q d"],
        key_value: tuple[
            ShapedTensor["...B heads seq_kv d"],
            ShapedTensor["...B heads seq_kv d"],
        ],
        mask: ShapedTensor["...B 1 seq_q seq_kv"] | None = None,
) -> ShapedTensor["...B heads seq_q d"]:
    keys, values = key_value
    scores = query @ keys.transpose(-2, -1)  # (...B, heads, seq_q, seq_kv)
    if mask is not None:
        scores = scores + mask
    weights = scores.softmax(dim=-1)
    return weights @ values

# All of the following pass:
attention(
    query=torch.zeros(2, 8, 16, 64),
    key_value=(torch.zeros(2, 8, 32, 64), torch.zeros(2, 8, 32, 64)),
)

attention(
    query=torch.zeros(4, 2, 8, 16, 64),  # extra batch dim
    key_value=(torch.zeros(4, 2, 8, 32, 64), torch.zeros(4, 2, 8, 32, 64)),
    mask=torch.zeros(4, 2, 1, 16, 32),
)

Compatibility

tensor-shape-assert works with any array library whose objects expose a .shape property, including NumPy, PyTorch, JAX, and TensorFlow.

License

This project is released under the MIT License. See LICENSE for details.


Download files

Download the file for your platform.

Source Distribution

tensor_shape_assert-0.4.2.tar.gz (23.3 kB)

Uploaded Source

Built Distribution

tensor_shape_assert-0.4.2-py3-none-any.whl (21.7 kB)

Uploaded Python 3

File details

Details for the file tensor_shape_assert-0.4.2.tar.gz.

File metadata

  • Download URL: tensor_shape_assert-0.4.2.tar.gz
  • Size: 23.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for tensor_shape_assert-0.4.2.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | aa4fbdf6fbe864266ae5a801f8ecf4272a699dd5b071560646fcef3bc5447e3e |
| MD5 | cca1f1191f8362bff2bf979d8661d5e5 |
| BLAKE2b-256 | 0825bb56d5ba8dafc7a48f4046f607bdd35eaf887a01173762b5c19b307d9870 |

Provenance

The following attestation bundles were made for tensor_shape_assert-0.4.2.tar.gz:

Publisher: release.yml on leifvan/tensor-shape-assert

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file tensor_shape_assert-0.4.2-py3-none-any.whl.

File hashes

Hashes for tensor_shape_assert-0.4.2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 080d88877009ca437b05ff2074e4c8a9cba0cb4f5a2f3f5401dd9872d65c0fff |
| MD5 | 60fff83c6b02b64523f5b17c0b594b34 |
| BLAKE2b-256 | 23929930eb69f55e9f7f7d3971c4301110fe023ccbb95d1d346eeed41711e137 |

Provenance

The following attestation bundles were made for tensor_shape_assert-0.4.2-py3-none-any.whl:

Publisher: release.yml on leifvan/tensor-shape-assert

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
