Tensor-like types – with variadic shapes – that support both static and runtime type checking, and convenient parsing.
Phantom Tensors
Tensor types with variadic shapes, for any array-based library, that work with both static and runtime type checkers
This project is currently just a rough prototype! Inspired by: phantom-types
The goal of this project is to let users write tensor-like types with variadic shapes (via PEP 646) that are:
- Amenable to static type checking (without mypy plugins).
  E.g., pyright can tell the difference between Tensor[Batch, Channel] and Tensor[Batch, Feature].
- Useful for performing runtime checks of tensor types and shapes.
  E.g., can validate -- at runtime -- that arrays of types NDArray[A, B] and NDArray[B, A] indeed have transposed shapes with respect to each other.
- Compatible with any array-based library (numpy, pytorch, xarray, cupy, mygrad, etc.)
  E.g., a function annotated with x: torch.Tensor can be passed phantom_tensors.torch.Tensor[N, B, D]. It is trivial to write custom phantom-tensor flavored types for any array-based library.
phantom_tensors.parse makes it easy to declare shaped tensor types in a way that static type checkers understand, and that are validated at runtime:
from typing import NewType
import numpy as np
from phantom_tensors import parse
from phantom_tensors.numpy import NDArray
A = NewType("A", int)
B = NewType("B", int)
# static: declare that x is of type NDArray[A, B]
# declare that y is of type NDArray[B, A]
# runtime: check that shapes (2, 3) and (3, 2)
# match (A, B) and (B, A) pattern across
# tensors
x, y = parse(
    (np.ones((2, 3)), NDArray[A, B]),
    (np.ones((3, 2)), NDArray[B, A]),
)
x # static type checker sees: NDArray[A, B]
y # static type checker sees: NDArray[B, A]
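The dimension names A and B above are plain typing.NewType wrappers around int. The short sketch below uses only the standard library (it does not import phantom_tensors) to show why they cost nothing at runtime while still being distinct to a static type checker:

```python
from typing import NewType

# Dimension labels: distinct types for a static checker, plain ints at runtime.
A = NewType("A", int)
B = NewType("B", int)

rows = A(2)  # a static checker sees type A
cols = B(3)  # a static checker sees type B

# NewType does no runtime wrapping: the values are ordinary ints.
print(type(rows) is int)
print(A.__supertype__ is int)
print(A.__name__, B.__name__)
```

A static checker would reject passing `cols` where an `A` is expected, even though both values are ints when the program runs.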
Passing inconsistent types to parse will result in a runtime validation error.
# Runtime: raises `ParseError` because A=10 and A=2 do not match
z, w = parse(
    (np.ones((10, 3)), NDArray[A, B]),
    (np.ones((3, 2)), NDArray[B, A]),
)
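To make the cross-tensor check concrete, here is a minimal pure-Python sketch of the idea: the first occurrence of a dimension name binds its size, and later occurrences must agree. This is not the library's implementation; check_shapes and ShapeMismatch are hypothetical names, and shapes are plain tuples rather than arrays.

```python
from typing import Dict, Sequence, Tuple


class ShapeMismatch(ValueError):
    """Hypothetical error type; stands in for the library's ParseError."""


def check_shapes(*pairs: Tuple[Sequence[int], Sequence[str]]) -> Dict[str, int]:
    """Check (shape, names) pairs against one shared set of bindings."""
    bindings: Dict[str, int] = {}
    for shape, names in pairs:
        if len(shape) != len(names):
            raise ShapeMismatch(
                f"shape-{tuple(shape)} does not have {len(names)} dimensions"
            )
        for size, name in zip(shape, names):
            # The first occurrence binds the name; later ones must agree.
            bound = bindings.setdefault(name, size)
            if bound != size:
                raise ShapeMismatch(f"{name}={bound} and {name}={size} do not match")
    return bindings


# Shapes (2, 3) and (3, 2) satisfy the (A, B) / (B, A) pattern.
ok = check_shapes(((2, 3), ("A", "B")), ((3, 2), ("B", "A")))

# Shapes (10, 3) and (3, 2) do not: A is bound to 10, then seen as 2.
try:
    check_shapes(((10, 3), ("A", "B")), ((3, 2), ("B", "A")))
    failed = False
except ShapeMismatch:
    failed = True
```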
These shaped tensor types are amenable to static type checking:
from typing import Any
import numpy as np
from phantom_tensors import parse
from phantom_tensors.numpy import NDArray
from phantom_tensors.alphabet import A, B # these are just NewType(..., int) types
def func_on_2d(x: NDArray[Any, Any]): ...
def func_on_3d(x: NDArray[Any, Any, Any]): ...
def func_on_any_arr(x: np.ndarray): ...
# runtime: ensures shape of arr_3d matches the (A, B, A) pattern
arr_3d = parse(np.ones((3, 5, 3)), NDArray[A, B, A])
func_on_2d(arr_3d) # static type checker: Error! Expects 2D arr, got 3D
func_on_3d(arr_3d) # static type checker: OK
func_on_any_arr(arr_3d) # static type checker: OK
Write easy-to-understand interfaces using common dimension names (or make up your own):
from phantom_tensors.torch import Tensor
from phantom_tensors.words import Batch, Embed, Vocab
def embedder(x: Tensor[Batch, Vocab]) -> Tensor[Batch, Embed]:
    ...
Using a runtime type checker, such as beartype or typeguard, in conjunction with phantom_tensors means that the typed shape information will be validated at runtime across a function's inputs and outputs, whenever that function is called.
from typing import TypeVar, cast
from typing_extensions import assert_type
import torch as tr
from beartype import beartype
from phantom_tensors import dim_binding_scope, parse
from phantom_tensors.torch import Tensor
from phantom_tensors.alphabet import A, B, C
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
@dim_binding_scope
@beartype # <- adds runtime type checking to function's interfaces
def buggy_matmul(x: Tensor[T1, T2], y: Tensor[T2, T3]) -> Tensor[T1, T3]:
    # This is the wrong operation!
    # Will return shape-(T1, T1) tensor, not (T1, T3)
    out = x @ x.T
    # We lie to the static type checker to try to get away with it
    return cast(Tensor[T1, T3], out)
x, y = parse(
    (tr.ones(3, 4), Tensor[A, B]),
    (tr.ones(4, 5), Tensor[B, C]),
)
# At runtime beartype raises:
# Function should return shape-(A, C) but returned shape-(A, A)
z = buggy_matmul(x, y) # Runtime validation error!
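The effect of combining a binding scope with a runtime checker can be imitated in a few lines of plain Python. The sketch below is only an illustration under stated assumptions: shape_checked, Arr, and _bind are hypothetical names, the "tensor" is a stand-in that only carries a shape, and neither beartype nor phantom_tensors is involved.

```python
import functools
from typing import Callable, Dict, Sequence


class Arr:
    """Stand-in for a tensor: it only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


def _bind(shape: Sequence[int], names: Sequence[str], bindings: Dict[str, int]) -> None:
    if len(shape) != len(names):
        raise TypeError(f"shape-{tuple(shape)} does not have {len(names)} dimensions")
    for size, name in zip(shape, names):
        if bindings.setdefault(name, size) != size:
            raise TypeError(f"{name}={bindings[name]} but got {name}={size}")


def shape_checked(args: Sequence[Sequence[str]], returns: Sequence[str]) -> Callable:
    """Hypothetical decorator: validate input and output shapes per call."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*arrays: Arr) -> Arr:
            bindings: Dict[str, int] = {}  # fresh bindings for every call
            for arr, names in zip(arrays, args):
                _bind(arr.shape, names, bindings)
            out = fn(*arrays)
            _bind(out.shape, returns, bindings)  # output shares the inputs' bindings
            return out

        return wrapper

    return decorator


SIGNATURE = dict(args=[("T1", "T2"), ("T2", "T3")], returns=("T1", "T3"))


@shape_checked(**SIGNATURE)
def good_matmul(x: Arr, y: Arr) -> Arr:
    return Arr(x.shape[0], y.shape[1])


@shape_checked(**SIGNATURE)
def buggy_matmul(x: Arr, y: Arr) -> Arr:
    return Arr(x.shape[0], x.shape[0])  # wrong on purpose: shape-(T1, T1)


print(good_matmul(Arr(3, 4), Arr(4, 5)).shape)
try:
    buggy_matmul(Arr(3, 4), Arr(4, 5))
except TypeError as err:
    print("rejected:", err)
```

Because the bindings dictionary is created inside the wrapper, each call is checked independently, which mirrors the per-call scope described above.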
Installation
pip install phantom-tensors
typing-extensions is the only strict dependency. Using features from phantom_tensors.torch (or phantom_tensors.numpy) requires that torch (or numpy) is installed too.
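Since the array backends are optional, it can be handy to check which ones are importable before reaching for the matching submodule. A small standard-library sketch (available_backends is a hypothetical helper, not part of phantom_tensors):

```python
from importlib.util import find_spec
from typing import List, Sequence


def available_backends(candidates: Sequence[str] = ("numpy", "torch")) -> List[str]:
    """Return the candidate packages that can be imported in this environment."""
    return [name for name in candidates if find_spec(name) is not None]


# The result depends on what is installed locally.
print(available_backends())
```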
Some Lower-Level Details and Features
Everything on display here is achieved using relatively minimal hacks (no mypy plugin necessary, no monkeypatching). Presently, torch.Tensor and numpy.ndarray are explicitly supported by phantom-tensors, but it is trivial to add support for other array-like classes.
Note that mypy does not support PEP 646 yet, but pyright does. You can run pyright on the following examples to see that they do, indeed, type-check as expected!
Dimension-Binding Contexts
phantom_tensors.parse validates inputs against types-with-shapes and performs type narrowing so that static type checkers are privy to the newly proven type information about those inputs. It performs inter-tensor shape consistency checks within a "dimension-binding context". Tensor-likes that are parsed simultaneously are automatically checked within a common dimension-binding context.
import numpy as np
import torch as tr
from phantom_tensors import parse
from phantom_tensors.alphabet import A, B, C
from phantom_tensors.numpy import NDArray
from phantom_tensors.torch import Tensor
t1, arr, t2 = parse(
    # <- Runtime: enter dimension-binding context
    (tr.rand(9, 2, 9), Tensor[B, A, B]), # <- binds A=2 & B=9
    (np.ones((2,)), NDArray[A]), # <- checks A==2
    (tr.rand(9), Tensor[B]), # <- checks B==9
) # <- Runtime: exit dimension-binding context
# Statically: casts t1, arr, t2 to shape-typed Tensors
# static type checkers now see
# t1: Tensor[B, A, B]
# arr: NDArray[A]
# t2: Tensor[B]
w = parse(tr.rand(78), Tensor[A]) # <- binds A=78 within this context
As indicated above, the type-checker sees the shaped-tensor/array types. Additionally, these are subclasses of their rightful parents, so we can pass these to functions typed with vanilla torch.Tensor and numpy.ndarray annotations, and type checkers will be a-ok with that.
def vanilla_numpy(x: np.ndarray): ...
def vanilla_torch(x: tr.Tensor): ...
vanilla_numpy(arr) # type checker: OK
vanilla_torch(arr) # type checker: Error!
vanilla_torch(t1) # type checker: OK
Basic forms of runtime validation performed by parse
# runtime type checking
>>> parse(1, Tensor[A])
---------------------------------------------------------------------------
ParseError: Expected <class 'torch.Tensor'>, got: <class 'int'>
# dimensionality mismatch
>>> parse(tr.ones(3), Tensor[A, A, A])
---------------------------------------------------------------------------
ParseError: shape-(3,) doesn't match shape-type (A=?, A=?, A=?)
# unsatisfied shape pattern
>>> parse(tr.ones(1, 2), Tensor[A, A])
---------------------------------------------------------------------------
ParseError: shape-(1, 2) doesn't match shape-type (A=1, A=1)
# inconsistent dimension sizes across tensors
>>> x, y = parse(
...     (tr.ones(1, 2), Tensor[A, B]),
...     (tr.ones(4, 1), Tensor[B, A]),
... )
---------------------------------------------------------------------------
ParseError: shape-(4, 1) doesn't match shape-type (B=2, A=1)
To reiterate, parse is able to compare shapes across multiple tensors by entering into a "dimension-binding scope".
One can enter into this context explicitly:
>>> from phantom_tensors import dim_binding_scope
>>> x = parse(np.zeros((2,)), NDArray[B]) # binds B=2
>>> y = parse(np.zeros((3,)), NDArray[B]) # binds B=3
>>> with dim_binding_scope:
...     x = parse(np.zeros((2,)), NDArray[B]) # binds B=2
...     y = parse(np.zeros((3,)), NDArray[B]) # raises!
---------------------------------------------------------------------------
ParseError: shape-(3,) doesn't match shape-type (B=2,)
Support for Literal dimensions:
from typing import Literal as L
from phantom_tensors import parse
from phantom_tensors.torch import Tensor
import torch as tr
parse(tr.zeros(1, 3), Tensor[L[1], L[3]]) # static + runtime: OK
parse(tr.zeros(2, 3), Tensor[L[1], L[3]]) # Runtime: ParseError - mismatch at dim 0
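At runtime, a Literal dimension can be inspected with typing.get_origin and typing.get_args. The sketch below shows one way to compare a shape against a mix of Literal and int entries; dim_matches and shape_matches are hypothetical helpers and do not reflect how phantom_tensors is implemented.

```python
from typing import Literal, Sequence, get_args, get_origin


def dim_matches(size: int, spec: object) -> bool:
    """Check one dimension size against `int` or a `Literal[...]` spec."""
    if get_origin(spec) is Literal:
        return size in get_args(spec)  # e.g. Literal[1] -> (1,)
    if spec is int:
        return True  # any size is acceptable
    raise TypeError(f"unsupported dimension spec: {spec!r}")


def shape_matches(shape: Sequence[int], specs: Sequence[object]) -> bool:
    return len(shape) == len(specs) and all(
        dim_matches(size, spec) for size, spec in zip(shape, specs)
    )


matches = shape_matches((1, 3), (Literal[1], Literal[3]))
mismatch = shape_matches((2, 3), (Literal[1], Literal[3]))
```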
Support for Literal dimensions and variadic shapes:
In Python 3.11 you can write shape types like Tensor[int, *Ts, int], where *Ts represents 0 or more optional entries between two required dimensions. phantom-tensors supports this "unpack" dimension. In this README we opt for typing_extensions.Unpack[Ts] instead of *Ts for the sake of backwards compatibility.
from phantom_tensors import parse
from phantom_tensors.torch import Tensor
import torch as tr
from typing_extensions import Unpack as U, TypeVarTuple
Ts = TypeVarTuple("Ts")
# U[Ts] represents an arbitrary number of entries
parse(tr.ones(1, 3), Tensor[int, U[Ts], int]) # static + runtime: OK
parse(tr.ones(1, 0, 0, 0, 3), Tensor[int, U[Ts], int]) # static + runtime: OK
parse(tr.ones(1), Tensor[int, U[Ts], int]) # Runtime: Not enough dimensions
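The runtime side of a variadic pattern amounts to splitting a shape into required leading dimensions, an arbitrary middle, and required trailing dimensions. A minimal sketch of that split, using plain tuples and a hypothetical REST marker in place of Unpack[Ts] (split_variadic is likewise an illustrative name):

```python
from typing import Sequence, Tuple

REST = "*Ts"  # hypothetical marker: zero or more dimensions

Dims = Tuple[int, ...]


def split_variadic(shape: Sequence[int], pattern: Sequence[str]) -> Tuple[Dims, Dims, Dims]:
    """Split `shape` into (head, middle, tail) around the single REST entry."""
    if list(pattern).count(REST) != 1:
        raise ValueError("pattern must contain exactly one variadic entry")
    n_head = list(pattern).index(REST)
    n_tail = len(pattern) - n_head - 1
    if len(shape) < n_head + n_tail:
        raise ValueError("not enough dimensions")
    stop = len(shape) - n_tail  # avoids the shape[-0:] pitfall when n_tail == 0
    return tuple(shape[:n_head]), tuple(shape[n_head:stop]), tuple(shape[stop:])


pattern = ("int", REST, "int")
two_dims = split_variadic((1, 3), pattern)
five_dims = split_variadic((1, 0, 0, 0, 3), pattern)
```

With this pattern a shape needs at least two dimensions, so a shape such as (1,) would raise ValueError.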
Support for phantom types:
Supports phantom type dimensions (i.e. int subclasses whose isinstance checks are overridden via __instancecheck__):
from phantom_tensors import parse
from phantom_tensors.torch import Tensor
import torch as tr
from phantom import Phantom
class EvenOnly(int, Phantom, predicate=lambda x: x%2 == 0): ...
parse(tr.ones(1, 0), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly]
parse(tr.ones(1, 2), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly]
parse(tr.ones(1, 4), Tensor[int, EvenOnly]) # static return type: Tensor[int, EvenOnly]
parse(tr.ones(1, 3), Tensor[int, EvenOnly]) # runtime: ParseError (3 is not an even number)
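The core trick behind such predicate-based dimensions can be reproduced with a metaclass that overrides __instancecheck__. The sketch below is a simplified stand-in for what the phantom-types package offers; PredicateMeta and EvenInt are hypothetical names used only for illustration.

```python
class PredicateMeta(type):
    """Metaclass whose isinstance check is a predicate on the value."""

    def __instancecheck__(cls, value: object) -> bool:
        return isinstance(value, int) and cls.predicate(value)


class EvenInt(int, metaclass=PredicateMeta):
    """Never instantiated: it only describes ints that satisfy the predicate."""

    @staticmethod
    def predicate(x: int) -> bool:
        return x % 2 == 0


# Plain ints "are" EvenInt instances exactly when they are even.
checks = {size: isinstance(size, EvenInt) for size in (0, 2, 3, 4)}
print(checks)
```

A shape validator can then call isinstance on each dimension size, which is why a size of 3 would be rejected for an even-only dimension.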
Compatibility with Runtime Type Checkers
parse is not the only way to perform runtime validation using phantom tensors: they work out of the box with 3rd party runtime type checkers like beartype! How is this possible?
...We do something tricky here! At runtime, Tensor[A, B] actually returns a phantom type. This means that isinstance(arr, NDArray[A, B]) is, at runtime, actually performing isinstance(arr, PhantomNDArrayAB), where PhantomNDArrayAB is dynamically generated and is able to perform the type and shape checks.
Thanks to the ability to bind dimensions within a specified context, all beartype needs to do is faithfully call isinstance(...) within said context and we can have the inputs and outputs of a phantom-tensor-annotated function get checked!
from typing import Any
from beartype import beartype # type: ignore
import pytest
import torch as tr
from phantom_tensors.alphabet import A, B, C
from phantom_tensors.torch import Tensor
from phantom_tensors import dim_binding_scope, parse
# @dim_binding_scope:
# ensures A, B, C consistent across all input/output tensor shapes
# within scope of function
@dim_binding_scope
@beartype # <-- adds isinstance checks on inputs & outputs
def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
    a, _ = x.shape
    _, c = y.shape
    return parse(tr.rand(a, c), Tensor[A, C])
@beartype
def needs_vector(x: Tensor[Any]): ...
x, y = parse(
    (tr.rand(3, 4), Tensor[A, B]),
    (tr.rand(4, 5), Tensor[B, C]),
)
z = matrix_multiply(x, y)
z # type revealed: Tensor[A, C]
with pytest.raises(Exception):
    # beartype raises error: input Tensor[A, C] doesn't match Tensor[A]
    needs_vector(z) # <- pyright also raises an error!
with pytest.raises(Exception):
    # beartype raises error: inputs Tensor[A, B], Tensor[A, B] don't match signature
    matrix_multiply(x, x) # <- pyright also raises an error!
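The "dynamically generated phantom type" idea can be sketched with a metaclass and the three-argument form of class creation. This is an assumption-laden illustration rather than the library's mechanism: Array, ShapeCheckingMeta, and shaped are hypothetical names, and the check only covers rank and repeated names within a single shape (no cross-tensor scope).

```python
class Array:
    """Stand-in for an array class: it only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


class ShapeCheckingMeta(type):
    """isinstance(obj, cls) checks the base type, the rank, and repeated names."""

    def __instancecheck__(cls, obj: object) -> bool:
        if not isinstance(obj, cls.base):
            return False
        if len(obj.shape) != len(cls.names):
            return False
        seen = {}
        # A name that appears twice must map to the same size both times.
        return all(seen.setdefault(n, s) == s for n, s in zip(cls.names, obj.shape))


def shaped(base: type, *names: str) -> type:
    """Build a subclass of `base` whose isinstance check validates shapes."""
    cls_name = f"Phantom{base.__name__}{''.join(names)}"
    return ShapeCheckingMeta(cls_name, (base,), {"base": base, "names": names})


ArrayAB = shaped(Array, "A", "B")
ArrayAA = shaped(Array, "A", "A")

print(ArrayAB.__name__)
print(isinstance(Array(2, 3), ArrayAB))
print(isinstance(Array(2, 3), ArrayAA))
```

Because the generated class subclasses its base, any tool that simply calls isinstance on annotated arguments would trigger the shape check without knowing anything about shapes.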