
ShapeGuard

ShapeGuard allows you to very succinctly assert the expected shapes of tensors in a dynamic, einsum-inspired way.

Turn this:

def batch_outer_product(x, y):
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

Into this:

def batch_outer_product(x, y):
    x.sg(("batch", "x_channels"))
    y.sg(("batch", "y_channels"))

    return (x.unsqueeze(-1) * y.unsqueeze(-2)).sg(("batch", "x_channels", "y_channels"))

Installation

pip install torch-shapeguard

Motivation

It’s easy to introduce bugs in ML code. One particularly rich source of bugs is the flexibility of the operators: a*b works whether a and b are scalar and vector, vector and vector, or any other broadcastable combination. Similarly, .sum() works regardless of the shape of your tensor. And since we’re doing optimization, whatever computation we end up performing can probably be optimized to work reasonably well, even if it isn’t doing what we intended. So our algorithm might "work" even when it has bugs (just less well). This makes bugs very hard to discover.

The best way I’ve found to avoid bugs is to religiously check the shapes of all my tensors, all the time, so I end up spending a lot of time debugging and writing comments like #(bs, n_samples, z_size) all over the place.

So why not check the shapes algorithmically? Well, it gets ugly fast.

You have to add assert foo.shape == (bs, n_samples, x_size) everywhere, which essentially doubles your line count, and you have to define all your dimension sizes (bs, etc.), which might vary across train/test, batches, and so on. So I made a small helper that makes this much nicer. I call it ShapeGuard.

Usage

When you import shapeguard, it adds the sg method to torch.Tensor and torch.distributions.Distribution.

You can use the sg method like an assert:

def forward(self, x, y):
    x.sg("bchw")
    y.sg("by")

This will verify that x has 4 dimensions, that y has 2 dimensions, and that x and y have the same size in the first dimension, 'b'.

If the assert passes, the tensor is returned. This means you can also chain it inline on results of operations:

z = f(x).sg("bnz").mean(axis=1).sg("bz")

If the assert fails it produces a nice error message:

AssertionError: expected 'b' to be 2 but was 4
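
For example, the following sketch (sizes chosen purely for illustration) would produce exactly that error, because the first call records 'b' as 2 and the second call then sees 4:

import torch
import shapeguard  # importing adds the sg method to torch.Tensor

x = torch.zeros(2, 3)  # the call below records b=2, x=3
y = torch.zeros(4, 5)

x.sg("bx")
y.sg("by")  # raises AssertionError: expected 'b' to be 2 but was 4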

If you want to verify an exact dimension, you can pass an int as part of the shape, e.g.

def forward(self, x, y):
    x.sg(("b", 1, "h", "w"))
    y.sg("by")

The special shape '*' is reserved for dimensions that should not be asserted; e.g. x.sg("*chw") will assert all dimensions except the first.
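
A small sketch of the wildcard (tensor sizes are illustrative): both calls below pass even though the first dimension differs, because '*' is never recorded or checked:

import torch
import shapeguard  # adds the sg method to torch.Tensor

a = torch.zeros(8, 3, 32, 32)
b = torch.zeros(2, 3, 32, 32)  # different first dimension

a.sg("*chw")  # records c=3, h=32, w=32; the first dimension is ignored
b.sg("*chw")  # passes: only c, h and w are checked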

How it works

The first time sg is called for an unseen shape, the size of the tensor for that shape is saved in the ShapeGuard.shapes global dict. Subsequent calls are checked against this stored shape.

You can call ShapeGuard.reset(shape) to reset a specific shape. This can be useful if e.g. your batch size varies between runs. ShapeGuard.reset() resets all shapes.
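
For example, here is a sketch of resetting the batch dimension between runs with different batch sizes. The from shapeguard import ShapeGuard import path is an assumption; only the ShapeGuard.reset / ShapeGuard.shapes API described above is documented:

import torch
import shapeguard  # adds the sg method to torch.Tensor
from shapeguard import ShapeGuard  # assumed import location of the ShapeGuard class

x = torch.zeros(32, 10)
x.sg("bz")                    # records b=32, z=10 in ShapeGuard.shapes

ShapeGuard.reset("b")         # forget only the recorded batch size
torch.zeros(17, 10).sg("bz")  # passes: 'b' is re-recorded as 17, 'z' must still be 10

ShapeGuard.reset()            # forget all recorded shapes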

