tensor-shape-assert

Runtime tensor shape and dtype checking through type annotations.
tensor-shape-assert validates the shapes (and optionally dtypes) of array-like objects at function call time, based on annotations you already write. Shared dimension variables are automatically inferred and matched across all annotated parameters and return values — a mismatch raises a clear error before your computation runs.
Compatible with any array library that exposes a .shape property, including NumPy, PyTorch, JAX, and TensorFlow.
Features
- Runtime shape validation via `ShapedTensor["..."]` type annotations
- Shape variables inferred from, and matched across, multiple parameters and return values
- Batch dimension support with named and unnamed ellipsis tokens (`...`, `...B`)
- Dtype annotations (`bool`, `int8`, `float32`, `complex128`, …)
- Optional and nested annotations (tuples, lists, `NamedTuple`)
- `int` parameters automatically promoted to shape variables
- Per-function and global check modes (`always`, `once`, `never`) for zero-overhead production deploys
- Compatible with static type checkers (MyPy, Pyright) via `ShapedLiteral` aliases
Installation
```shell
pip install git+https://github.com/leifvan/tensor-shape-assert
```
Quick Start
```python
import numpy as np
from tensor_shape_assert import check_tensor_shapes, ShapedTensor

@check_tensor_shapes()
def matrix_multiply(
    x: ShapedTensor["batch m k"],
    y: ShapedTensor["batch k n"],
) -> ShapedTensor["batch m n"]:
    return x @ y

matrix_multiply(np.zeros((4, 5, 3)), np.zeros((4, 3, 7)))  # passes
matrix_multiply(np.zeros((4, 5, 3)), np.zeros((4, 2, 7)))  # raises TensorShapeAssertError
```
The decorator infers `batch=4`, `m=5`, `k=3` from `x` and checks that `y` and the return value are consistent with those values.
Shape Descriptor Syntax
A shape descriptor is a whitespace-separated string (most punctuation is also treated as whitespace). Each token describes one dimension:
| Token | Meaning |
|---|---|
| `5` | Exact size 5 |
| `*` | Wildcard: any size |
| `n` | Variable: resolved and matched across all annotations that use the same name |
| `...` | Zero or more batch dimensions (may appear at most once) |
| `...B` | Named batch dimensions; must match across annotations sharing the same name `B` |
| `""` / `ScalarTensor` | Scalar (0-dimensional) tensor |
Dtype tokens can appear anywhere in the descriptor alongside dimension tokens (see Dtype Annotations).
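To make these rules concrete, here is a minimal pure-Python sketch of how a descriptor is matched against a shape, with variables bound on first use. This is illustrative only, not the library's implementation, and it handles only integer, `*`, and variable tokens:

```python
def match_descriptor(shape, descriptor, variables):
    """Match a shape tuple against a descriptor of integer, '*' and
    variable tokens; variables bind on first use, then must agree."""
    tokens = descriptor.split()
    if len(tokens) != len(shape):
        return False
    for token, size in zip(tokens, shape):
        if token == "*":
            continue                      # wildcard: any size
        if token.isdigit():
            if int(token) != size:        # exact size
                return False
        elif variables.setdefault(token, size) != size:
            return False                  # variable bound to a different size
    return True

variables = {}
assert match_descriptor((4, 5, 3), "batch m k", variables)      # binds batch=4, m=5, k=3
assert match_descriptor((4, 3, 7), "batch k n", variables)      # batch and k must agree
assert not match_descriptor((4, 2, 7), "batch k n", variables)  # k=2 conflicts with k=3
```

The key point is that a variable is not a constraint on its own; it becomes one as soon as the same name is seen a second time.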
Core Concepts
Variables
When two parameters share a variable name, their sizes along that dimension must agree:
```python
@check_tensor_shapes()
def add(x: ShapedTensor["n k"], y: ShapedTensor["n k"]) -> ShapedTensor["n k"]:
    return x + y
```
Variable names can be any identifier not reserved by other rules (integers, *, ..., dtype tokens).
Integers as Shape Variables
int parameters are automatically promoted to shape variables, enabling dynamic shape constraints:
```python
@check_tensor_shapes()
def take_k(x: ShapedTensor["n k"], k: int) -> ShapedTensor["n k"]:
    return x[:, :k]

take_k(np.zeros((10, 4)), k=4)  # passes: k=4 matches x.shape[1]
take_k(np.zeros((10, 4)), k=3)  # raises TensorShapeAssertError
```
Disable this behaviour with @check_tensor_shapes(ints_to_variables=False).
Batch Dimensions
Use ... for an arbitrary number of leading dimensions:
```python
@check_tensor_shapes()
def normalize(x: ShapedTensor["... d"]) -> ShapedTensor["... d"]:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)
```
Use a named batch dimension (...B) to enforce that multiple parameters share the same batch shape:
```python
@check_tensor_shapes()
def bilinear(x: ShapedTensor["...B m k"], y: ShapedTensor["...B k n"]) -> ShapedTensor["...B m n"]:
    return x @ y
```
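Conceptually, a batch token absorbs all leading dimensions not claimed by the trailing core tokens; for a named token, the inferred batch shape must then agree everywhere the name reappears. A rough sketch of that inference (illustrative, not library code):

```python
def split_batch(shape, n_core):
    """Split a shape into (batch dims, core dims) for a descriptor
    like '...B m k', where n_core trailing tokens form the core."""
    return shape[:len(shape) - n_core], shape[len(shape) - n_core:]

# x: "...B m k" with shape (4, 2, 5, 3) -> batch (4, 2), core (5, 3)
batch_x, core_x = split_batch((4, 2, 5, 3), 2)
# y: "...B k n" with shape (4, 2, 3, 7) -> same named batch, so it must match
batch_y, core_y = split_batch((4, 2, 3, 7), 2)
assert batch_x == batch_y == (4, 2)   # ...B agrees across x and y
assert core_x[1] == core_y[0]         # shared variable k
```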
Dtype Annotations
Add a dtype kind — and optionally a bit width — anywhere in the descriptor:
```python
@check_tensor_shapes()
def safe_mean(x: ShapedTensor["float n k"]) -> ShapedTensor["float n"]:
    return x.mean(axis=-1)
```
Supported dtype tokens:
| Token | Accepted dtypes |
|---|---|
| `bool` | boolean |
| `int`, `int8`, `int16`, `int32`, `int64` | signed integer |
| `uint`, `uint8`, `uint16`, `uint32`, `uint64` | unsigned integer |
| `integral` | any integer (signed or unsigned) |
| `float`, `float16`, `float32`, `float64` | real floating-point |
| `complex`, `complex64`, `complex128` | complex floating-point |
| `numeric` | any non-boolean numeric dtype |
These tokens are reserved and cannot be used as variable names.
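Expressed in NumPy's dtype hierarchy, the kind tokens correspond roughly to the following checks (an illustrative mapping; the library applies analogous rules to other frameworks' dtypes):

```python
import numpy as np

# Rough NumPy equivalents of the dtype kind tokens (illustrative only).
KIND_CHECKS = {
    "bool": lambda dt: dt == np.bool_,
    "int": lambda dt: np.issubdtype(dt, np.signedinteger),
    "uint": lambda dt: np.issubdtype(dt, np.unsignedinteger),
    "integral": lambda dt: np.issubdtype(dt, np.integer),
    "float": lambda dt: np.issubdtype(dt, np.floating),
    "complex": lambda dt: np.issubdtype(dt, np.complexfloating),
    "numeric": lambda dt: np.issubdtype(dt, np.number),
}

assert KIND_CHECKS["float"](np.dtype(np.float32))
assert KIND_CHECKS["integral"](np.dtype(np.uint8))
assert not KIND_CHECKS["numeric"](np.dtype(bool))  # bool is not numeric
```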
Optional and Nested Annotations
Annotations can be arbitrarily nested in tuples or lists. Mark an optional tensor with | None:
```python
@check_tensor_shapes()
def process(
    x: tuple[ShapedTensor["n k"], ShapedTensor["n"] | None],
    y: ShapedTensor["n 3"],
) -> ShapedTensor["n"]:
    a, b = x
    result = y.sum(axis=1)
    if b is not None:
        result = result + b
    return result
```
NamedTuple classes are also supported — apply the decorator to the class itself.
API Reference
check_tensor_shapes
```python
@check_tensor_shapes(
    constraints=None,
    ints_to_variables=True,
    check_mode=None,
    include_outer_variables=None,
    disable_union_warning=False,
)
```
Decorator that enables shape checking for a function, class __init__, or NamedTuple class.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `constraints` | `list[str \| Callable]` | `None` | Extra constraints on shape variables. String expressions are evaluated (e.g. `"a == 2 * b"`); callables receive the variable dict and must return `bool`. Checked before and after the wrapped call. |
| `ints_to_variables` | `bool` | `True` | Promote `int` parameters to shape variables. |
| `check_mode` | `"always" \| "once" \| "never" \| None` | `None` | Per-function check mode; overrides the global setting. |
| `include_outer_variables` | `bool \| None` | `None` | Inherit shape variables from an enclosing `check_tensor_shapes` scope. Defaults to `False` for functions and `True` for `NamedTuple` instances. |
| `disable_union_warning` | `bool` | `False` | Suppress the warning about partially unsupported union types. |
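As a sketch of the `constraints` semantics described above, string expressions can be evaluated against the inferred variable dict and callables invoked on it. This is illustrative only; the library's actual evaluation may differ:

```python
def check_constraints(constraints, variables):
    """Evaluate string expressions and callables against a variable dict."""
    for c in constraints:
        if isinstance(c, str):
            # e.g. "a == 2 * b" with variables as the evaluation namespace
            ok = bool(eval(c, {"__builtins__": {}}, dict(variables)))
        else:
            ok = bool(c(variables))
        if not ok:
            return False
    return True

assert check_constraints(["a == 2 * b", lambda v: v["a"] > 0], {"a": 6, "b": 3})
assert not check_constraints(["a == 2 * b"], {"a": 5, "b": 3})
```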
set_global_check_mode
```python
set_global_check_mode(mode: Literal["always", "once", "never"])
```
Set the global check mode for all @check_tensor_shapes-decorated functions. Per-function check_mode takes precedence when specified.
| Mode | Behaviour |
|---|---|
| `"always"` | Check every call (default) |
| `"once"` | Check each decorated function only on its first call |
| `"never"` | Disable all shape checks globally |
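The mode semantics can be sketched with a simple gating decorator: `once` records that the first call was checked and skips the work afterwards, which is how `once` and `never` approach zero overhead. This is an illustration, not the library's internals:

```python
import functools

mode = "once"  # stand-in for the global check mode

def checked(fn):
    """Run a (stand-in) check according to the current mode."""
    state = {"done": False, "n_checks": 0}
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if mode == "always" or (mode == "once" and not state["done"]):
            state["done"] = True
            state["n_checks"] += 1  # a real shape check would run here
        return fn(*args, **kwargs)
    wrapper.state = state
    return wrapper

@checked
def f(x):
    return x + 1

f(1); f(2); f(3)
assert f.state["n_checks"] == 1  # "once": only the first call was checked
```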
get_shape_variables
```python
get_shape_variables(names: str) -> tuple[int | tuple[int, ...] | None, ...]
```
Return the current inferred values of shape variables. Must be called from inside a @check_tensor_shapes-wrapped function.
```python
@check_tensor_shapes()
def my_func(x: ShapedTensor["n k 3"]):
    n, k = get_shape_variables("n k")
    print(f"n={n}, k={k}")

my_func(np.zeros((10, 9, 3)))  # prints "n=10, k=9"
```
assert_shape_here
```python
assert_shape_here(obj_or_shape, descriptor: str) -> None
```
Validate a tensor or shape tuple against a descriptor from inside a @check_tensor_shapes-wrapped function. Any new variables in the descriptor are registered for subsequent checks, including the function's return annotation.
```python
@check_tensor_shapes()
def my_func(x: ShapedTensor["n k"]) -> ShapedTensor["n m"]:
    y = some_operation(x)
    assert_shape_here(y, "n m")  # registers m; return annotation reuses it
    return y
```
label_tensor
```python
label_tensor(tensor, label: str | Iterable[str], overwrite: bool = False) -> tensor
```
Attach one or more labels to a tensor. Labels registered with register_label can appear in shape descriptors and are matched against the tensor's labels at call time.
```python
from tensor_shape_assert import register_label, label_tensor

register_label("encoder_output")
z = label_tensor(encoder(x), "encoder_output")

@check_tensor_shapes()
def decode(z: ShapedTensor["encoder_output n d"]) -> ShapedTensor["n vocab"]:
    ...
```
register_label
```python
register_label(label: str, constraint_fn: Callable[[array], bool] | None = None)
```
Register a custom label token for use in shape descriptors.
- If `constraint_fn` is `None`, the label is unconstrained: tensors must be explicitly tagged with `label_tensor` before being passed to a checked function.
- If `constraint_fn` is provided, the label behaves like a dtype annotation: any tensor whose descriptor contains this label is automatically checked by calling `constraint_fn(tensor)`. Constrained labels cannot be assigned via `label_tensor`.
Trace Utilities
Use the trace API to inspect how shape variables are inferred — useful for debugging:
```python
from tensor_shape_assert import start_trace_recording, stop_trace_recording, trace_records_to_string

start_trace_recording()
my_func(np.zeros((10, 9, 3)))
records = stop_trace_recording()
print(trace_records_to_string(records))
```
| Function | Description |
|---|---|
| `start_trace_recording()` | Begin capturing per-parameter variable assignments |
| `stop_trace_recording()` | Stop capturing and return the list of `TraceRecord` objects |
| `trace_records_to_string(records)` | Format records as an indented, human-readable string |
Type-Safe Literal Syntax
For full static type-checker (MyPy, Pyright) compatibility, use ShapedLiteral and the pre-built framework aliases:
```python
import torch
from typing import Literal as L
from tensor_shape_assert import check_tensor_shapes, ShapedTorchLiteral, ShapedLiteral

@check_tensor_shapes()
def my_func(
    x: ShapedTorchLiteral[L["n k"]],
    y: ShapedTorchLiteral[L["k m"]],
) -> ShapedLiteral[torch.Tensor, L["n m"]]:
    return x @ y
```
| Alias | Array type |
|---|---|
| `ShapedTorchLiteral[L["..."]]` | `torch.Tensor` |
| `ShapedNumpyLiteral[L["..."]]` | `numpy.ndarray` |
| `ShapedLiteral[T, L["..."]]` | Any type `T` |
Extended Example
Tuple inputs, optional parameters, and batch dimensions together:
```python
import torch
from tensor_shape_assert import check_tensor_shapes, ShapedTensor

@check_tensor_shapes()
def attention(
    query: ShapedTensor["...B heads seq_q d"],
    key_value: tuple[
        ShapedTensor["...B heads seq_kv d"],
        ShapedTensor["...B heads seq_kv d"],
    ],
    mask: ShapedTensor["...B 1 seq_q seq_kv"] | None = None,
) -> ShapedTensor["...B heads seq_q d"]:
    keys, values = key_value
    scores = query @ keys.transpose(-2, -1)  # (...B, heads, seq_q, seq_kv)
    if mask is not None:
        scores = scores + mask
    weights = scores.softmax(dim=-1)
    return weights @ values

# All of the following pass:
attention(
    query=torch.zeros(2, 8, 16, 64),
    key_value=(torch.zeros(2, 8, 32, 64), torch.zeros(2, 8, 32, 64)),
)
attention(
    query=torch.zeros(4, 2, 8, 16, 64),  # extra batch dim
    key_value=(torch.zeros(4, 2, 8, 32, 64), torch.zeros(4, 2, 8, 32, 64)),
    mask=torch.zeros(4, 2, 1, 16, 32),
)
```
Compatibility
tensor-shape-assert works with any array library whose objects expose a .shape property:
- PyTorch
- NumPy
- JAX
- TensorFlow
- Any library conforming to the Python Array API standard
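Because the checker only reads the `.shape` attribute (and, for dtype tokens, presumably `.dtype`), even a minimal custom wrapper can participate. A toy illustration of that duck typing:

```python
class ShapeOnly:
    """A minimal duck-typed 'tensor': all a shape checker reads is .shape."""
    def __init__(self, *shape):
        self.shape = shape

t = ShapeOnly(4, 5, 3)
assert t.shape == (4, 5, 3)  # the tuple a checker would inspect
assert len(t.shape) == 3     # rank, as a checker would see it
```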
License
This project is released under the MIT License. See LICENSE for details.