
Tensor-like types – with variadic shapes – that support both static and runtime type checking, and convenient parsing.


Phantom Tensors

Tensor types with variadic shapes, for any array-based library, that work with both static and runtime type checkers


This project is currently just a rough prototype! Inspired by: phantom-types

The goal of this project is to let users write tensor-like types with variadic shapes (via PEP 646) that are:

  • Amenable to static type checking (without mypy plugins).

    E.g., pyright can tell the difference between Tensor[Batch, Channel] and Tensor[Batch, Feature]

  • Useful for performing runtime checks of tensor types and shapes.

    E.g., one can validate -- at runtime -- that arrays of types NDArray[A, B] and NDArray[B, A] indeed have transposed shapes with respect to each other.

  • Compatible with any array-based library (numpy, pytorch, xarray, cupy, mygrad, etc.)

    E.g., a function annotated with x: torch.Tensor can be passed a phantom_tensors.torch.Tensor[N, B, D]. It is trivial to write custom phantom-tensor flavored types for any array-based library.

phantom_tensors.parse makes it easy to declare shaped tensor types in a way that static type checkers understand, and that are validated at runtime:

from typing import NewType

import numpy as np

from phantom_tensors import parse
from phantom_tensors.numpy import NDArray

A = NewType("A", int)
B = NewType("B", int)

# static: declare that x is of type NDArray[A, B]
#         declare that y is of type NDArray[B, A]
# runtime: check that shapes (2, 3) and (3, 2)
#          match (A, B) and (B, A) pattern across
#          tensors
x, y = parse(
    (np.ones((2, 3)), NDArray[A, B]),
    (np.ones((3, 2)), NDArray[B, A]),
)

x  # static type checker sees: NDArray[A, B]
y  # static type checker sees: NDArray[B, A]

Passing inconsistent types to parse will result in a runtime validation error.

# Runtime: raises `ParseError` because A=10 and A=2 do not match
z, w = parse(
    (np.ones((10, 3)), NDArray[A, B]),
    (np.ones((3, 2)), NDArray[B, A]),
)

These shaped tensor types are amenable to static type checking:

from typing import Any

import numpy as np

from phantom_tensors import parse
from phantom_tensors.numpy import NDArray
from phantom_tensors.alphabet import A, B  # these are just NewType(..., int) types

def func_on_2d(x: NDArray[Any, Any]): ...
def func_on_3d(x: NDArray[Any, Any, Any]): ...
def func_on_any_arr(x: np.ndarray): ...

# runtime: ensures shape of arr_3d matches (A, B, A) patterns
arr_3d = parse(np.ones((3, 5, 3)), NDArray[A, B, A])

func_on_2d(arr_3d)  # static type checker: Error! (expects 2D array, got 3D)

func_on_3d(arr_3d)  # static type checker: OK
func_on_any_arr(arr_3d)  # static type checker: OK
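The dimension names A and B above are described as plain typing.NewType(..., int) types, so they exist only for the static type checker: at runtime, calling A(3) simply returns the int 3. A quick standard-library check (this illustrates NewType itself, not phantom_tensors):

```python
from typing import NewType

# Same kind of definitions that phantom_tensors.alphabet is described as providing
A = NewType("A", int)
B = NewType("B", int)

size = A(3)  # at runtime, a NewType call just returns its argument
print(size, type(size).__name__)  # 3 int
print(A.__name__)  # A
```

Because the names carry no runtime data of their own, checking actual sizes is left to parse or to a runtime type checker, as described in the rest of this README.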

Write easy-to-understand interfaces using common dimension names (or make up your own):

from phantom_tensors.torch import Tensor
from phantom_tensors.words import Batch, Embed, Vocab

def embedder(x: Tensor[Batch, Vocab]) -> Tensor[Batch, Embed]:
    ...

Using a runtime type checker, such as beartype or typeguard, in conjunction with phantom_tensors means that the typed shape information will be validated at runtime across a function's inputs and outputs, whenever that function is called.

from typing import TypeVar, cast
from typing_extensions import assert_type

import torch as tr
from beartype import beartype

from phantom_tensors import dim_binding_scope, parse
from phantom_tensors.torch import Tensor
from phantom_tensors.alphabet import A, B, C

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


@dim_binding_scope
@beartype  # <- adds runtime type checking to function's interfaces
def buggy_matmul(x: Tensor[T1, T2], y: Tensor[T2, T3]) -> Tensor[T1, T3]:
    # This is the wrong operation!
    # Will return shape-(T1, T1) tensor, not (T1, T3)
    out = x @ x.T
    
    # We lie to the static type checker to try to get away with it
    return cast(Tensor[T1, T3], out)

x, y = parse(
    (tr.ones(3, 4), Tensor[A, B]),
    (tr.ones(4, 5), Tensor[B, C]),
)

# At runtime beartype raises:
#   Function should return shape-(A, C) but returned shape-(A, A)
z = buggy_matmul(x, y)  # Runtime validation error!
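To make the idea concrete without relying on phantom_tensors' internals, here is a minimal, stdlib-only sketch of what a shape-checking wrapper does conceptually: bind dimension names from the inputs, then check the output against those same bindings. The names `shape_checked`, `Mat`, and `ShapeMismatch` are hypothetical; this is not how beartype or phantom_tensors are actually implemented.

```python
import functools
from typing import Callable, Dict, Tuple

Pattern = Tuple[str, ...]  # dimension names, e.g. ("T1", "T2")


class ShapeMismatch(Exception):
    """Hypothetical error raised when a shape breaks its pattern."""


class Mat:
    """Minimal stand-in for a 2D tensor: it only tracks its shape."""

    def __init__(self, rows: int, cols: int) -> None:
        self.shape = (rows, cols)


def _bind(shape: Tuple[int, ...], pattern: Pattern, bindings: Dict[str, int]) -> None:
    """Bind unseen names to sizes; require already-bound names to agree."""
    if len(shape) != len(pattern):
        raise ShapeMismatch(f"expected {len(pattern)} dims, got shape-{shape}")
    for size, name in zip(shape, pattern):
        if bindings.setdefault(name, size) != size:
            raise ShapeMismatch(f"shape-{shape} conflicts with {name}={bindings[name]}")


def shape_checked(*arg_patterns: Pattern, returns: Pattern) -> Callable:
    """Hypothetical decorator: check positional args and the return value."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args):
            bindings: Dict[str, int] = {}  # fresh bindings per call, like a per-call scope
            for arg, pattern in zip(args, arg_patterns):
                _bind(arg.shape, pattern, bindings)
            out = fn(*args)
            _bind(out.shape, returns, bindings)  # the output must reuse the input bindings
            return out

        return wrapper

    return decorator


@shape_checked(("T1", "T2"), ("T2", "T3"), returns=("T1", "T3"))
def matmul(x: Mat, y: Mat) -> Mat:
    return Mat(x.shape[0], y.shape[1])


@shape_checked(("T1", "T2"), ("T2", "T3"), returns=("T1", "T3"))
def buggy_matmul(x: Mat, y: Mat) -> Mat:
    return Mat(x.shape[0], x.shape[0])  # mimics `x @ x.T`: shape (T1, T1)


print(matmul(Mat(3, 4), Mat(4, 5)).shape)  # (3, 5)
```

Calling `buggy_matmul(Mat(3, 4), Mat(4, 5))` raises `ShapeMismatch`, because the returned shape (3, 3) conflicts with T3=5 bound from the second input.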

Installation

pip install phantom-tensors

typing-extensions is the only strict dependency. Using features from phantom_tensors.torch (or phantom_tensors.numpy) requires that torch (or numpy) is installed too.

Some Lower-Level Details and Features

Everything on display here is achieved using relatively minimal hacks (no mypy plugin necessary, no monkeypatching). Presently, torch.Tensor and numpy.ndarray are explicitly supported by phantom-tensors, but it is trivial to add support for other array-like classes.

Note that, at the time of writing, mypy does not support PEP 646 yet, but pyright does. You can run pyright on the following examples to see that they do indeed type-check as expected!

Dimension-Binding Contexts

phantom_tensors.parse validates inputs against types-with-shapes and performs type narrowing so that static type checkers are privy to the newly proven type information about those inputs. It performs inter-tensor shape consistency checks within a "dimension-binding context". Tensor-likes that are parsed simultaneously are automatically checked within a common dimension-binding context.

import numpy as np
import torch as tr

from phantom_tensors import parse
from phantom_tensors.alphabet import A, B, C
from phantom_tensors.numpy import NDArray
from phantom_tensors.torch import Tensor

t1, arr, t2 = parse(
    # <- Runtime: enter dimension-binding context
    (tr.rand(9, 2, 9), Tensor[B, A, B]),  # <-binds A=2 & B=9
    (np.ones((2,)), NDArray[A]),  # <- checks A==2
    (tr.rand(9), Tensor[B]),  # <- checks B==9
)  # <- Runtime: exit dimension-binding scope 
   #    Statically: casts t1, arr, t2 to shape-typed Tensors

# static type checkers now see
# t1: Tensor[B, A, B] 
# arr: NDArray[A]
# t2: Tensor[B]

w = parse(tr.rand(78), Tensor[A])  # <- new, separate context: binds A=78 without conflict

As indicated above, the type checker sees the shaped tensor/array types. Additionally, these are subclasses of their rightful parents, so we can pass them to functions annotated with vanilla torch.Tensor and numpy.ndarray, and type checkers will be a-ok with that.

def vanilla_numpy(x: np.ndarray): ...
def vanilla_torch(x: tr.Tensor): ...

vanilla_numpy(arr)  # type checker: OK
vanilla_torch(arr)  # type checker: Error! 
vanilla_torch(t1)  # type checker: OK 

Basic forms of runtime validation performed by parse

# runtime type checking
>>> parse(1, Tensor[A])
---------------------------------------------------------------------------
ParseError: Expected <class 'torch.Tensor'>, got: <class 'int'>

# dimensionality mismatch
>>> parse(tr.ones(3), Tensor[A, A, A])
---------------------------------------------------------------------------
ParseError: shape-(3,) doesn't match shape-type (A=?, A=?, A=?)

# unsatisfied shape pattern
>>> parse(tr.ones(1, 2), Tensor[A, A])
---------------------------------------------------------------------------
ParseError: shape-(1, 2) doesn't match shape-type (A=1, A=1)

# inconsistent dimension sizes across tensors
>>> x, y = parse(
...     (tr.ones(1, 2), Tensor[A, B]),
...     (tr.ones(4, 1), Tensor[B, A]),
... )

---------------------------------------------------------------------------
ParseError: shape-(4, 1) doesn't match shape-type (B=2, A=1)
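The shape checks listed above (dimensionality, repeated names within one shape, and consistency across tensors) can all be expressed as a single loop over a shared name-to-size table. The following stdlib-only sketch is a simplified model, not phantom_tensors' implementation; `check_shapes` and `ShapeMismatch` are hypothetical names, and shapes are plain tuples rather than real tensors (so the type check is omitted).

```python
from typing import Dict, Sequence, Tuple

Shape = Tuple[int, ...]
Pattern = Tuple[str, ...]  # dimension names, e.g. ("A", "B")


class ShapeMismatch(Exception):
    """Hypothetical stand-in for phantom_tensors' ParseError."""


def check_shapes(pairs: Sequence[Tuple[Shape, Pattern]]) -> Dict[str, int]:
    """Validate (shape, pattern) pairs against one shared set of bindings."""
    bindings: Dict[str, int] = {}  # one "dimension-binding context" for all pairs
    for shape, pattern in pairs:
        # dimensionality mismatch: wrong number of dims
        if len(shape) != len(pattern):
            raise ShapeMismatch(f"shape-{shape} doesn't match shape-type {pattern}")
        for size, name in zip(shape, pattern):
            # the first occurrence of a name binds it; later ones must agree
            expected = bindings.setdefault(name, size)
            if size != expected:
                raise ShapeMismatch(
                    f"shape-{shape} doesn't match shape-type {pattern}: "
                    f"{name}={expected}, got {size}"
                )
    return bindings


print(check_shapes([((2, 3), ("A", "B")), ((3, 2), ("B", "A"))]))  # {'A': 2, 'B': 3}
```

Passing `[((1, 2), ("A", "B")), ((4, 1), ("B", "A"))]` fails on the second pair, mirroring the last example above: B was bound to 2 by the first shape, but the second shape puts 4 in B's position.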

To reiterate, parse compares shapes across multiple tensors by entering a "dimension-binding scope". One can also enter this context explicitly:

>>> from phantom_tensors import dim_binding_scope

>>> x = parse(np.zeros((2,)), NDArray[B])  # binds B=2
>>> y = parse(np.zeros((3,)), NDArray[B])  # binds B=3
>>> with dim_binding_scope:
...     x = parse(np.zeros((2,)), NDArray[B])  # binds B=2
...     y = parse(np.zeros((3,)), NDArray[B])  # raises!
---------------------------------------------------------------------------
ParseError: shape-(3,) doesn't match shape-type (B=2,)

Support for Literal dimensions:

from typing import Literal as L

from phantom_tensors import parse
from phantom_tensors.torch import Tensor

import torch as tr

parse(tr.zeros(1, 3), Tensor[L[1], L[3]])  # static + runtime: OK
parse(tr.zeros(2, 3), Tensor[L[1], L[3]])  # Runtime: ParseError - mismatch at dim 0

Support for variadic shapes:

In Python 3.11+ you can write shape types like Tensor[int, *Ts, int], where *Ts represents zero or more entries between two required dimensions. phantom-tensors supports this "unpack" dimension. In this README we use typing_extensions.Unpack[Ts] instead of *Ts for the sake of backwards compatibility.

from phantom_tensors import parse
from phantom_tensors.torch import Tensor

import torch as tr
from typing_extensions import Unpack as U, TypeVarTuple

Ts = TypeVarTuple("Ts")

# U[Ts] represents an arbitrary number of entries
parse(tr.ones(1, 3), Tensor[int, U[Ts], int])  # static + runtime: OK
parse(tr.ones(1, 0, 0, 0, 3), Tensor[int, U[Ts], int])  # static + runtime: OK

parse(tr.ones(1), Tensor[int, U[Ts], int])  # Runtime: not enough dimensions
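Matching a shape against a pattern with one variadic entry reduces to checking a fixed prefix and suffix and letting the middle absorb any number of dimensions (including zero). In this stdlib-only sketch, `...` stands in for `U[Ts]`, the type `int` means "any size", and an integer such as `3` plays the role of `L[3]`; the `matches` function is hypothetical and only models the idea.

```python
from typing import Any, Tuple


def matches(shape: Tuple[int, ...], pattern: Tuple[Any, ...]) -> bool:
    """Hypothetical matcher: `int` = any size, 3 = exactly 3, `...` = zero or more dims."""
    if pattern.count(...) > 1:
        raise ValueError("at most one variadic entry is allowed")
    if ... in pattern:
        i = pattern.index(...)
        head, tail = pattern[:i], pattern[i + 1:]
        if len(shape) < len(head) + len(tail):
            return False  # not enough dimensions for the required entries
        # Keep only the dims matched by the fixed prefix and suffix;
        # everything in between is absorbed by the variadic entry.
        shape = shape[:len(head)] + shape[len(shape) - len(tail):]
        pattern = head + tail
    if len(shape) != len(pattern):
        return False  # dimensionality mismatch
    return all(entry is int or entry == size for size, entry in zip(shape, pattern))


print(matches((1, 3), (int, ..., int)))  # True
print(matches((1, 0, 0, 0, 3), (int, ..., int)))  # True
print(matches((1,), (int, ..., int)))  # False: not enough dimensions
```

Slicing the suffix as `shape[len(shape) - len(tail):]` rather than `shape[-len(tail):]` matters when the suffix is empty, since `shape[-0:]` would return the whole tuple.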

Support for phantom types:

Supports phantom type dimensions (i.e., int subclasses that customize isinstance checks via __instancecheck__):

from phantom_tensors import parse
from phantom_tensors.torch import Tensor

import torch as tr
from phantom import Phantom

class EvenOnly(int, Phantom, predicate=lambda x: x%2 == 0): ...

parse(tr.ones(1, 0), Tensor[int, EvenOnly])  # static return type: Tensor[int, EvenOnly] 
parse(tr.ones(1, 2), Tensor[int, EvenOnly])  # static return type: Tensor[int, EvenOnly] 
parse(tr.ones(1, 4), Tensor[int, EvenOnly])  # static return type: Tensor[int, EvenOnly] 

parse(tr.ones(1, 3), Tensor[int, EvenOnly])  # runtime: ParseError (3 is not an even number)
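Predicate-based dimensions work because isinstance can be customized through a metaclass's `__instancecheck__`. The stdlib-only sketch below imitates the `predicate=` class keyword used above; `PredicateMeta`, `PredicateInt`, and `check_dims` are hypothetical names, and this is a simplified model rather than the phantom-types implementation.

```python
from typing import Any, Callable, Tuple


class PredicateMeta(type):
    """Metaclass: isinstance(x, cls) means "x is an int and cls's predicate holds"."""

    def __instancecheck__(cls, instance: object) -> bool:
        return isinstance(instance, int) and bool(cls._predicate(instance))


class PredicateInt(int, metaclass=PredicateMeta):
    """Hypothetical base for predicate-constrained int dimensions."""

    _predicate = staticmethod(lambda x: True)

    def __init_subclass__(cls, *, predicate: Callable[[int], bool], **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        cls._predicate = staticmethod(predicate)


class EvenOnly(PredicateInt, predicate=lambda x: x % 2 == 0):
    ...


def check_dims(shape: Tuple[int, ...], dim_types: Tuple[type, ...]) -> bool:
    """Each size must pass isinstance(size, dim_type); for EvenOnly that runs the predicate."""
    return len(shape) == len(dim_types) and all(
        isinstance(size, t) for size, t in zip(shape, dim_types)
    )


print(check_dims((1, 4), (int, EvenOnly)))  # True
print(check_dims((1, 3), (int, EvenOnly)))  # False: 3 is odd
```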

Compatibility with Runtime Type Checkers

parse is not the only way to perform runtime validation using phantom tensors: they work out of the box with third-party runtime type checkers like beartype! How is this possible?

...We do something tricky here! At runtime, Tensor[A, B] actually returns a phantom type. This means that isinstance(arr, NDArray[A, B]) is, at runtime, actually performing isinstance(arr, PhantomNDArrayAB), a dynamically generated class that performs the type and shape checks.

Thanks to the ability to bind dimensions within a specified context, all beartype needs to do is faithfully call isinstance(...) within said context, and the inputs and outputs of a phantom-tensor-annotated function get checked!

from typing import Any

from beartype import beartype  # type: ignore
import pytest
import torch as tr

from phantom_tensors.alphabet import A, B, C
from phantom_tensors.torch import Tensor
from phantom_tensors import dim_binding_scope, parse

# @dim_binding_scope:
#   ensures A, B, C consistent across all input/output tensor shapes
#   within scope of function
@dim_binding_scope 
@beartype  # <-- adds isinstance checks on inputs & outputs
def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
    a, _ = x.shape
    _, c = y.shape
    return parse(tr.rand(a, c), Tensor[A, C])

@beartype
def needs_vector(x: Tensor[Any]): ...

x, y = parse(
    (tr.rand(3, 4), Tensor[A, B]),
    (tr.rand(4, 5), Tensor[B, C]),
)

z = matrix_multiply(x, y)
z  # type revealed: Tensor[A, C]

with pytest.raises(Exception):
    # beartype raises error: input Tensor[A, C] doesn't match Tensor[A]
    needs_vector(z)  # <- pyright also raises an error!

with pytest.raises(Exception):
    # beartype raises error: inputs Tensor[A, B], Tensor[A, B] don't match signature
    matrix_multiply(x, x)  # <- pyright also raises an error!
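The isinstance trick described above can be modeled with a class whose subscription builds (and caches) a new subclass whose metaclass overrides `__instancecheck__` to inspect the shape. In this stdlib-only sketch, `Array` is a stand-in for an array type, and `ShapedArray` and `_ShapedMeta` are hypothetical names; it only checks repeated names within a single array (no shared binding scope) and is not phantom_tensors' actual implementation.

```python
from typing import Dict, Tuple


class Array:
    """Minimal stand-in for an array type: it only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = shape


class _ShapedMeta(type):
    """Metaclass whose isinstance() checks the base type *and* the shape pattern."""

    def __instancecheck__(cls, instance: object) -> bool:
        if not isinstance(instance, Array):
            return False  # type check
        pattern = cls._pattern
        if len(instance.shape) != len(pattern):
            return False  # dimensionality check
        bindings: Dict[str, int] = {}  # repeated names must agree within this array
        return all(bindings.setdefault(n, s) == s for s, n in zip(instance.shape, pattern))


class ShapedArray(Array):
    """Hypothetical analogue of NDArray/Tensor: ShapedArray["A", "B"] builds a class."""

    _cache: Dict[Tuple[str, ...], type] = {}

    def __class_getitem__(cls, names):
        names = names if isinstance(names, tuple) else (names,)
        if names not in cls._cache:
            # Dynamically generate e.g. `ShapedArray_A_B`, a real subclass of Array
            cls._cache[names] = _ShapedMeta(
                "ShapedArray_" + "_".join(names), (Array,), {"_pattern": names}
            )
        return cls._cache[names]


AB = ShapedArray["A", "B"]
print(AB.__name__)  # ShapedArray_A_B
print(isinstance(Array(2, 3), AB))  # True
print(isinstance(Array(2, 3), ShapedArray["A", "A"]))  # False
print(issubclass(AB, Array))  # True
```

Because the generated class is a genuine subclass of `Array`, it can be used wherever an `Array` is expected, while any tool that calls `isinstance` (such as a runtime type checker) automatically gets the shape check.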



Download files


Source Distribution

phantom_tensors-0.3.0.tar.gz (26.3 kB)

Built Distribution

phantom_tensors-0.3.0-py3-none-any.whl (16.7 kB)

File details

Details for the file phantom_tensors-0.3.0.tar.gz.

File metadata

  • Download URL: phantom_tensors-0.3.0.tar.gz
  • Size: 26.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.4

File hashes

Hashes for phantom_tensors-0.3.0.tar.gz
Algorithm Hash digest
SHA256 fc68cffb8d9fbd59da8e65007868a0e772965dda0a935a9903454764e13576f3
MD5 139a35c9075d6db77a9d288ff510eec2
BLAKE2b-256 93942f75060c5af9e16917b64cb5c4a337aa4f4b2f6b4ac716163719477630a6


File details

Details for the file phantom_tensors-0.3.0-py3-none-any.whl.

File hashes

Hashes for phantom_tensors-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d96cca6d810be6b5fef58f2decc93cccd8775d55048e56e3a24bc02936208fb0
MD5 859948159d917aafb94dc90451669aef
BLAKE2b-256 9ec33e0f8c2b3668e8ef359a8e4bb48a493213a9d9b10df57df510db6936420f

