
Annotates the shapes of PyTorch tensors using Python 3 type annotations, and provides optional runtime shape validation.


Tensor Type



This library comes in very handy when debugging complex programs that manipulate large torch.Tensors whose shapes (dimensions) vary widely and are hard to keep track of.

I got tired of writing assert my_tensor.shape == (batch, channels, height, width) over and over, so I wrote a small utility for it. Then I got tired of copy/pasting that utility from my Gist into every new project, so I finally turned it into a library that I can pip install everywhere.

Getting started

pip3 install tensor_type

tensor_type only works with PyTorch, but only because I made the annotation types inherit from torch.Tensor to satisfy static type checkers.

Usage

from tensor_type import Tensor, Tensor3d, Tensor4d
import torch

# You can use the types in function signatures

def my_obscure_function(x: Tensor4d) -> Tensor3d:
    return x.sum(dim=3)/x.mean()

t = my_obscure_function(x=torch.rand(3,2,4,2))

# You can check the shape with an explicit assert
assert Tensor3d(t)

# Or you can check it with the .check() method, which produces a nicer error message
Tensor3d.check(t)

# Check for a specific shape
assert Tensor[3, 2, 4](t)

# This will match no matter the size of the second axis
assert Tensor[3, :, 4](t)

batch = 64
channels = 3
h, w = 256, 512

# You can statically annotate the shape like so.
# This is NOT checked at run time; it's only for readability.

my_tensor: Tensor[batch, channels, h, w] = torch.rand(batch, channels, h, w)  # stand-in for your own loader, e.g. load_images(...)

# You can assert it later like so:
assert Tensor[batch, channels, h, w](my_tensor)

# You can define new "types" like so:
ImageBatch = Tensor[64, 3, :, :]

# And then use the new type
assert ImageBatch(torch.rand(64, 3, 256, 256))
assert ImageBatch(torch.rand(64, 3, 512, 512))
assert not ImageBatch(torch.rand(64, 1, 256, 256))
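The README does not say how these types are implemented. As a rough, hypothetical sketch of one way such an API can work: a metaclass can provide class-level subscripting (Tensor[3, :, 4] hands __getitem__ the tuple (3, slice(None, None, None), 4)), make calling the class return a bool instead of constructing an instance, and add a .check() that raises with a readable message. The names ShapeMeta, shape_matches and FakeTensor are invented for illustration, and the sketch leaves out the torch.Tensor inheritance mentioned in "Getting started", checking any object with a .shape attribute instead.

```python
from typing import Any, Tuple, Union

# A dimension spec is either a required size (int) or a wildcard written as `:`,
# which Python passes to __getitem__ as slice(None, None, None).
Dim = Union[int, slice]
WILDCARD = slice(None)


def shape_matches(shape: Tuple[int, ...], spec: Tuple[Dim, ...]) -> bool:
    """True if shape has the same rank as spec and every non-wildcard size agrees."""
    if len(shape) != len(spec):
        return False
    return all(want == WILDCARD or want == size for size, want in zip(shape, spec))


def _fmt(spec: Tuple[Dim, ...]) -> str:
    """Render a spec like (64, 3, :, :) for names and error messages."""
    return "(" + ", ".join(":" if d == WILDCARD else str(d) for d in spec) + ")"


class ShapeMeta(type):
    """Hypothetical metaclass: Cls[...] builds a shape type, Cls(x) returns a bool."""

    def __getitem__(cls, key: Any) -> "ShapeMeta":
        # Cls[3, :, 4] passes a tuple; Cls[3] passes a bare int, so normalise.
        spec = key if isinstance(key, tuple) else (key,)
        return ShapeMeta(f"{cls.__name__}{_fmt(spec)}", (cls,), {"spec": spec})

    def __call__(cls, x: Any) -> bool:
        # Calling the class validates x instead of constructing an instance.
        return shape_matches(tuple(x.shape), cls.spec)

    def check(cls, x: Any) -> None:
        """Like calling the class, but raises a descriptive error on mismatch."""
        if not cls(x):
            raise ValueError(f"expected shape {_fmt(cls.spec)}, got {tuple(x.shape)}")


class Tensor(metaclass=ShapeMeta):
    spec: Tuple[Dim, ...] = ()  # the unsubscripted root only matches 0-d inputs


class FakeTensor:
    """Hypothetical stand-in exposing .shape, so the sketch runs without torch."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


Tensor3d = Tensor[:, :, :]
ImageBatch = Tensor[64, 3, :, :]

t = FakeTensor(3, 2, 4)
assert Tensor3d(t) and Tensor[3, :, 4](t)
assert ImageBatch(FakeTensor(64, 3, 512, 512))
assert not ImageBatch(FakeTensor(64, 1, 256, 256))
```

Under this design, a module-level annotation such as my_tensor: Tensor[batch, channels, h, w] is evaluated (the subscript runs and builds a class), but the resulting class is never called, which is consistent with the README's note that such annotations are not checked at run time.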

Development

To install the latest version from GitHub, run:

git clone git@github.com:sam1902/tensor_type.git tensor_type
cd tensor_type
pip3 install -U .



Download files

Download the file for your platform.

Source Distribution

tensor_type-0.1.0.tar.gz (4.7 kB)

Built Distribution

tensor_type-0.1.0-py3-none-any.whl (5.1 kB)

File details

Details for the file tensor_type-0.1.0.tar.gz.

File metadata

  • Download URL: tensor_type-0.1.0.tar.gz
  • Size: 4.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1

File hashes

Hashes for tensor_type-0.1.0.tar.gz

  • SHA256: 01c0f3a66daa5d0f81810e1e6ede91fa45fc3096d44d93d66ef33be8ee49e2e1
  • MD5: 40d26b80a3f8a62d3b04fe171b41ffc1
  • BLAKE2b-256: 4f5ca73971c3d34ea6ae8f35fdc9f1d32f2c0b0e85d13e7dd34f6543c88c77f9


File details

Details for the file tensor_type-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: tensor_type-0.1.0-py3-none-any.whl
  • Size: 5.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1

File hashes

Hashes for tensor_type-0.1.0-py3-none-any.whl

  • SHA256: b18328776771fb88b2e042ded4cc83523758c1fda345de7d847d59016b331e4e
  • MD5: 556836f581ab2165b96a9c93183ae4e6
  • BLAKE2b-256: 2e11516a7bb4704e64ecaade0b3af80647b213d13eafb01e74c9caf556b207e9
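
To confirm that a downloaded file matches the digests listed on this page, one option is to hash it locally and compare. Below is a minimal sketch using Python's standard hashlib module; the helper names sha256_of and verify are hypothetical, and the expected values are copied from the SHA256 entries above for the 0.1.0 files.

```python
import hashlib
from pathlib import Path

# SHA256 digests listed on this page for the 0.1.0 release files.
EXPECTED_SHA256 = {
    "tensor_type-0.1.0.tar.gz": "01c0f3a66daa5d0f81810e1e6ede91fa45fc3096d44d93d66ef33be8ee49e2e1",
    "tensor_type-0.1.0-py3-none-any.whl": "b18328776771fb88b2e042ded4cc83523758c1fda345de7d847d59016b331e4e",
}


def sha256_of(path: Path, chunk_size: int = 65536) -> str:
    """Hex SHA256 of a file, read in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path) -> bool:
    """True only if the file name is known and its SHA256 matches the listed digest."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

For example, verify(Path("tensor_type-0.1.0.tar.gz")) after downloading the sdist. pip can also enforce digests on its own in hash-checking mode, using a requirements line of the form tensor_type==0.1.0 --hash=sha256:<digest>.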

