
torchsaber: Elegant Dimensions for a more Civilized Age

It is a period of civil war. Rebel spaceships have won their first victory against the awesome RuntimeError (experimental support for Named Tensors in PyTorch).

Motivation

How often have you written or seen code like this?

images = torch.randn(10, 3, 64, 64)
flipped = images.transpose(2, 3)
cropped = flipped[:, :, x:x+32, y:y+32]
flatten = cropped.view(-1, 32 * 32)
gray = flatten.sum(dim=2) / 3.

or gotten a mysterious error like

RuntimeError: The size of tensor x (2) must match the size of tensor y (4) at non-singleton dimension 2

I don't care what universe you're from, that's got to hurt your eyes. What do all those numbers mean? What order are the dimensions in by the end? Is dim=2 the right summation to average the channels? Who knows? Unnamed dimensions lead to anger; anger leads to bugs; bugs lead to suffering.

But, a ray of hope! PyTorch has released experimental support for named tensors: that is, tensors whose dimensions have names rather than simply numeric indices. It's wonderful news, but it still has some rough edges. For example, you have to refer to names using hardcoded string literals everywhere; a typo (e.g. 'hieght') can break your code in unexpected ways that aren't caught until runtime.
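For reference, the stock named-tensor API that torchsaber builds on looks roughly like this (a small illustration; the misspelling near the end is deliberate):

```python
import torch

# Factory functions accept a `names` tuple alongside the sizes
t = torch.randn(10, 3, names=('batch', 'channel'))

# Reductions and many other ops accept a name instead of an index
per_image = t.sum('channel')
assert per_image.names == ('batch',)

# But names are bare string literals, so a typo surfaces only when
# this exact line runs:
try:
    t.sum('chanel')  # misspelled 'channel'
except RuntimeError:
    print('typo detected only at run time')
```

Nothing checks the string 'chanel' until the op executes, which is exactly the rough edge described above.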

torchsaber is a minimal layer of syntactic sugar for named tensors. Its goal is to give you power (unlimited power!) by letting all manipulation of dimensions happen by name rather than by numeric index or hardcoded string literal. Dimensions and their manipulations are first-class objects that interface cleanly with PyTorch's user-facing API. For example, the code snippet above becomes:

from torchsaber import dims
batch, channel, height, width, features = \
    dims('batch', 'channel', 'height', 'width', 'features')

images = torch.randn | batch(10) + channel(3) + height(64) + width(64)
flipped = images.permute(~width, ~height)
cropped = flipped | height[:32] + width[:32]
flatten = cropped.flatten([~height, ~width], ~features)
gray = flatten.sum(dim=~channel) / 3.

By "minimal" I mean the entire implementation is around 100 lines of code.
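For the curious, here is a rough sketch of how that kind of interface can be built with ordinary Python operator overloading. This is an illustration of the general technique, not torchsaber's actual source; `fake_randn` is a stand-in for torch.randn so the sketch is self-contained:

```python
class Spec:
    """An ordered list of (name, size) pairs, built up with `+`."""
    def __init__(self, entries):
        self.entries = entries

    def __add__(self, other):
        # batch(10) + channel(3) concatenates the two dimension specs
        return Spec(self.entries + other.entries)

    def __ror__(self, factory):
        # In `factory | spec`, `|` binds looser than `+`, and a plain
        # function defines no __or__, so Python falls back to this
        # reflected __ror__ on the right-hand Spec.
        names = tuple(name for name, _ in self.entries)
        sizes = tuple(size for _, size in self.entries)
        return factory(*sizes, names=names)

class Dim:
    """A named dimension; calling it with a size yields a one-entry Spec."""
    def __init__(self, name):
        self.name = name

    def __call__(self, size):
        return Spec([(self.name, size)])

    def __invert__(self):
        # ~height -> 'height', for ops that accept dimension names
        return self.name

def dims(*names):
    return tuple(Dim(name) for name in names)

# Demo with a stub standing in for torch.randn:
batch, channel = dims('batch', 'channel')

def fake_randn(*sizes, names=None):
    return (sizes, names)

print(fake_randn | batch(10) + channel(3))  # ((10, 3), ('batch', 'channel'))
print(~channel)                             # channel
```

The trick is that every piece of the pipeline (`batch(10)`, `+`, `|`, `~`) is just a regular Python operator on small objects, which is how an interface like this can stay around 100 lines.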

Can I use it?

Sure! The easiest way to learn is to read the big comment at the top of torchsaber.py: it's a literate doctest! Then run pip install torch torchvision torchsaber, write from torchsaber import dims, and enjoy. torchsaber tries to be compatible with the named tensor docs and should work with the operators that named tensors support.

However, because named tensors are experimental, so is torchsaber. The real goal of the project is to provoke some discussion around human-friendly designs for the tensor programs of the future.

(Some!) references

Here's a non-exhaustive list of prior work (many of these have their own bibliographies you can follow)…

More to say, have you?

Open an issue, file a PR, send me an email!

Download files

Download the file for your platform.

Source Distribution

torchsaber-1.0.2.tar.gz (5.4 kB)

Built Distribution

torchsaber-1.0.2-py3-none-any.whl (5.5 kB)
