Annotates shapes of PyTorch Tensors using type annotation in Python3, and provides optional runtime shape validation.
Tensor Type
This library is very handy when debugging complex programs that manipulate huge torch.Tensor objects whose shapes (dimensions) vary widely and are hard to keep track of.
I got tired of writing tons of assert my_tensor.shape == (batch, channels, height, width) over and over, so I wrote a small utility for it. Then I got tired of copy/pasting it from my Gist into every new project, so I finally turned it into a library that I can pip install everywhere.
Getting started
pip3 install tensor_type
tensor_type only works with PyTorch, but that's only because I made the annotation types inherit from torch.Tensor so that they satisfy static type annotations.
Usage
from tensor_type import Tensor, Tensor3d, Tensor4d
import torch
# You can use the types in function signatures
def my_obscure_function(x: Tensor4d) -> Tensor3d:
    return x.sum(dim=3) / x.mean()
t = my_obscure_function(x=torch.rand(3,2,4,2))
# You can check the shape with an explicit assert
assert Tensor3d(t)
# Or you can check it with the .check() method, which produces a nicer error message
Tensor3d.check(t)
# Check a specific shape
assert Tensor[3, 2, 4](t)
# This will match no matter the size of the second axis
assert Tensor[3, :, 4](t)
batch = 64
channels = 3
h, w = 256, 512
# You can statically annotate the shape like so
# This will NOT be checked at runtime; it's just for clarity
# (load_images(...) stands in for your own data-loading code)
my_tensor: Tensor[batch, channels, h, w] = load_images(...)
# You can assert it later like so:
assert Tensor[batch, channels, h, w](my_tensor)
# You can define new "types" like so:
ImageBatch = Tensor[64, 3, :, :]
# And then use the new type
assert ImageBatch(torch.rand(64, 3, 256, 256))
assert ImageBatch(torch.rand(64, 3, 512, 512))
assert not ImageBatch(torch.rand(64, 1, 256, 256))
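The matching rules shown above (an integer axis must match exactly, a bare : accepts any size) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not tensor_type's actual implementation: ShapeSpec and FakeTensor are hypothetical names, the sketch additionally assumes that the number of axes must agree, and it only requires the checked object to expose a .shape tuple (as torch.Tensor and NumPy arrays do), so it runs without PyTorch installed.

```python
from typing import Any, Tuple, Union

# Hypothetical sketch: an axis pattern is either an exact size (int)
# or the bare wildcard ":" (which Python passes as slice(None)).
Dim = Union[int, slice]


class ShapeSpec:
    """Hypothetical shape pattern for illustration; not tensor_type's real class."""

    def __init__(self, dims: Tuple[Dim, ...]) -> None:
        for d in dims:
            if not isinstance(d, (int, slice)):
                raise TypeError(f"axis pattern must be an int or ':', got {d!r}")
            if isinstance(d, slice) and d != slice(None):
                # Only a bare ':' is accepted; bounded slices such as 1:3 are rejected.
                raise TypeError(f"only a bare ':' is supported as a wildcard, got {d!r}")
        self.dims = dims

    def __class_getitem__(cls, dims: Union[Dim, Tuple[Dim, ...]]) -> "ShapeSpec":
        # ShapeSpec[3, :, 4] receives the tuple (3, slice(None), 4);
        # ShapeSpec[3] receives a bare int, so normalise it to a 1-tuple.
        if not isinstance(dims, tuple):
            dims = (dims,)
        return cls(dims)

    def matches(self, shape: Tuple[int, ...]) -> bool:
        shape = tuple(shape)
        # The number of axes must agree, then every non-wildcard axis must match exactly.
        if len(shape) != len(self.dims):
            return False
        return all(isinstance(d, slice) or d == s for d, s in zip(self.dims, shape))

    def __call__(self, tensor: Any) -> bool:
        # Works with anything exposing .shape (torch.Tensor, NumPy arrays, ...).
        return self.matches(tensor.shape)

    def check(self, tensor: Any) -> None:
        # Like __call__, but raises an error naming both the expected and actual shape.
        if not self(tensor):
            raise ValueError(f"expected shape {self!r}, got {tuple(tensor.shape)}")

    def __repr__(self) -> str:
        axes = ", ".join(":" if isinstance(d, slice) else str(d) for d in self.dims)
        return f"ShapeSpec[{axes}]"


class FakeTensor:
    """Hypothetical stand-in exposing only .shape, so the sketch runs without torch."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


# Usage mirroring the README examples.
ImageBatch = ShapeSpec[64, 3, :, :]
assert ImageBatch(FakeTensor(64, 3, 256, 256))
assert not ImageBatch(FakeTensor(64, 1, 256, 256))
Any3d = ShapeSpec[:, :, :]  # rank-only pattern, similar in spirit to Tensor3d
assert Any3d(FakeTensor(3, 2, 4))
ShapeSpec[3, 2, 4].check(FakeTensor(3, 2, 4))  # passes silently
```

The sketch uses __class_getitem__ for brevity so that ShapeSpec[...] returns a callable pattern object; a metaclass defining __getitem__ would work just as well. With PyTorch installed, the same calls accept real tensors such as torch.rand(64, 3, 256, 256).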
Development
To install the latest version from GitHub, run:
git clone git@github.com:sam1902/tensor_type.git tensor_type
cd tensor_type
pip3 install -U .
Download files
Source Distribution: tensor_type-0.1.0.tar.gz
Built Distribution: tensor_type-0.1.0-py3-none-any.whl
File details
Details for the file tensor_type-0.1.0.tar.gz.
File metadata
- Download URL: tensor_type-0.1.0.tar.gz
- Upload date:
- Size: 4.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 01c0f3a66daa5d0f81810e1e6ede91fa45fc3096d44d93d66ef33be8ee49e2e1
MD5 | 40d26b80a3f8a62d3b04fe171b41ffc1
BLAKE2b-256 | 4f5ca73971c3d34ea6ae8f35fdc9f1d32f2c0b0e85d13e7dd34f6543c88c77f9
File details
Details for the file tensor_type-0.1.0-py3-none-any.whl.
File metadata
- Download URL: tensor_type-0.1.0-py3-none-any.whl
- Upload date:
- Size: 5.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | b18328776771fb88b2e042ded4cc83523758c1fda345de7d847d59016b331e4e
MD5 | 556836f581ab2165b96a9c93183ae4e6
BLAKE2b-256 | 2e11516a7bb4704e64ecaade0b3af80647b213d13eafb01e74c9caf556b207e9
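If you download one of these files manually, you can compare its digest with the SHA256 value listed for it. A minimal sketch using only Python's standard hashlib module follows; the helper names sha256_of and verify are hypothetical, and the expected digests are copied from the file hashes listed for each distribution.

```python
import hashlib
import sys
from pathlib import Path

# SHA256 digests copied from the file hashes published for tensor_type 0.1.0.
EXPECTED_SHA256 = {
    "tensor_type-0.1.0.tar.gz": "01c0f3a66daa5d0f81810e1e6ede91fa45fc3096d44d93d66ef33be8ee49e2e1",
    "tensor_type-0.1.0-py3-none-any.whl": "b18328776771fb88b2e042ded4cc83523758c1fda345de7d847d59016b331e4e",
}


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large files never need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path) -> bool:
    """True only if the file name is known and its SHA256 matches the published digest."""
    expected = EXPECTED_SHA256.get(Path(path).name)
    return expected is not None and sha256_of(path) == expected


if __name__ == "__main__":
    # Example: python verify_hashes.py tensor_type-0.1.0-py3-none-any.whl
    for arg in sys.argv[1:]:
        print(f"{arg}: {'OK' if verify(arg) else 'MISMATCH or unknown file'}")
```

The lookup is keyed on the file name only, so a renamed download is reported as unknown rather than silently checked against the wrong digest.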