A PyTorch add-on for working with image mappings and displacement fields, including Spatial Transformers

torchfields

Torchfields provides an abstraction that neatly encapsulates the functionality of displacement fields as used in Spatial Transformer Networks and Optical Flow Estimation.

Fields can be treated as normal PyTorch tensors for most purposes, and also include additional functionality for composing displacements and sampling from tensors.

Installation

To install torchfields, run:

pip install torchfields

Introduction

A displacement field represents a mapping or flow that indicates how an image should be warped.

It is essentially a spatial tensor holding one displacement vector per pixel, where each vector gives the distance and direction of the displacement at that pixel.

Displacement field conventions

Units

The standard unit of displacement is a half-image, so a displacement vector of magnitude 2 means that the displacement distance is equal to the side length of the displaced image.

Note: This convention comes from the original Spatial Transformer Networks paper, where such fields were presented as mappings, with -1 representing the left or top edge of the image and +1 representing the right or bottom edge.

torchfields also supports seamlessly converting to and from units of pixels using the pixels() and from_pixels() functions.
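The arithmetic behind the unit convention can be sketched in plain Python. This is only an illustration of the half-image convention described above, not the torchfields implementation; the helper names half_image_to_pixels and pixels_to_half_image are hypothetical and are not part of the package.

```python
def half_image_to_pixels(d, size):
    """Convert a displacement from half-image units to pixels.

    `size` is the side length of the image in pixels, so one
    half-image unit corresponds to size / 2 pixels.
    """
    return d * size / 2


def pixels_to_half_image(d_px, size):
    """Convert a displacement from pixels back to half-image units."""
    return d_px * 2 / size


size = 256  # assumed side length, for illustration only

# A displacement of magnitude 2 spans the whole side of the image.
print(half_image_to_pixels(2.0, size))  # 256.0

# A one-pixel displacement is a small fraction of a half-image.
print(pixels_to_half_image(1.0, size))  # 0.0078125
```

For a non-square image, the same conversion would presumably apply per axis, using the width for the horizontal component and the height for the vertical one.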

Displacement direction

The most common way to warp an image by a displacement field is to sample the image at the locations that the field vectors point to. This is often referred to as the Eulerian or pull convention, since the vectors in the field point to the locations from which the image should be pulled. This is achieved by calling the sample() function (which wraps PyTorch's built-in grid_sample(), converting between conventions as necessary).
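To make the pull convention concrete, here is a minimal sketch written directly against PyTorch's grid_sample(). It is not the code behind sample(); the helper name pull_sample, the (N, 2, H, W) layout with channel 0 as x and channel 1 as y, and the choice of align_corners=False are all assumptions made for this illustration.

```python
import torch
import torch.nn.functional as F


def pull_sample(image, displacement):
    """Warp `image` by `displacement` in the pull (Eulerian) direction.

    image:        (N, C, H, W) tensor
    displacement: (N, 2, H, W) tensor in half-image units,
                  channel 0 = x, channel 1 = y (assumed layout)
    """
    n, _, h, w = image.shape
    # Identity mapping: each output pixel points at its own location in [-1, 1].
    theta = torch.eye(2, 3).unsqueeze(0).repeat(n, 1, 1)
    identity = F.affine_grid(theta, (n, 1, h, w), align_corners=False)
    # grid_sample expects absolute sampling locations shaped (N, H, W, 2).
    grid = identity + displacement.permute(0, 2, 3, 1)
    return F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)


image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# One pixel to the right is 2 / W half-image units.
disp = torch.zeros(1, 2, 4, 4)
disp[:, 0] = 2.0 / 4

warped = pull_sample(image, disp)
print(warped[0, 0])
```

Because every output pixel pulls from its right-hand neighbour, the image content appears shifted one pixel to the left, and the last column, which samples outside the image, is filled with zeros.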

An alternative way to warp an image by a displacement field is to send each pixel of the image along the corresponding displacement vector to its new location. This is referred to as the Lagrangian or push convention, since the vectors of the field indicate where each image pixel should be pushed to. This direction, while seemingly intuitive, is much less straightforward to implement, since there is no definitive way to handle the discretization (for instance, what to do when the destinations are not whole pixel coordinates, when two sources map to the same destination, or when nothing maps into a destination pixel). The solution for warping in the Lagrangian direction is to first invert the field using inverse(), and then warp the image normally using sample().
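The discretization problems of the push direction can be seen in a toy one-dimensional example. This sketch uses plain Python lists and whole-pixel displacements, which is a deliberate simplification; the functions pull_warp and push_warp are hypothetical and unrelated to how torchfields implements sample() or inverse().

```python
def pull_warp(src, disp):
    """Pull: output pixel i reads from src[i + disp[i]] (None if out of range)."""
    out = []
    for i, d in enumerate(disp):
        j = i + d
        out.append(src[j] if 0 <= j < len(src) else None)
    return out


def push_warp(src, disp):
    """Push: source pixel i is sent to i + disp[i].

    Each destination collects every value that lands on it, so
    collisions and holes are visible in the result.
    """
    out = [[] for _ in src]
    for i, d in enumerate(disp):
        j = i + d
        if 0 <= j < len(src):
            out[j].append(src[i])
    return out


src = ["a", "b", "c", "d"]
disp = [1, 0, -1, -1]  # whole-pixel displacements, for illustration only

print(pull_warp(src, disp))  # ['b', 'b', 'b', 'c']
print(push_warp(src, disp))  # [[], ['a', 'b', 'c'], ['d'], []]
```

In the pull direction every output pixel receives exactly one value. In the push direction, three sources collide at index 1 while indices 0 and 3 receive nothing, which is the ambiguity described above.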

To read more about the two ways to describe flow fields, see the Wikipedia article on the Lagrangian and Eulerian specification of the flow field.

Relationship to PyTorch tensors

Displacement fields inherit from torch.Tensor, so all functionality from PyTorch tensors also works with displacement fields. That is, any PyTorch function that accepts a torch.Tensor type will also implicitly accept a torchfields displacement field.
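The general mechanism of subclassing torch.Tensor can be sketched as follows. ToyField is a hypothetical stand-in defined only for this example; it is not the torchfields class, and its magnitude() method is invented for illustration.

```python
import torch


class ToyField(torch.Tensor):
    """Hypothetical tensor subclass; not the torchfields field type."""

    def magnitude(self):
        # Treat dimension 1 as the vector components (an assumption)
        # and return a plain tensor of per-pixel vector lengths.
        plain = self.as_subclass(torch.Tensor)
        return plain.pow(2).sum(dim=1).sqrt()


# View an ordinary tensor as the subclass without copying its data.
f = torch.ones(1, 2, 4, 4).as_subclass(ToyField)

# The subclass instance is still a torch.Tensor ...
print(isinstance(f, torch.Tensor))

# ... so ordinary PyTorch functions accept it ...
total = torch.sum(f)

# ... and the extra method is available alongside them.
lengths = f.magnitude()
print(lengths.shape)
```

The same pattern is what lets a subclass add domain-specific methods while remaining usable anywhere a regular tensor is expected.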

Furthermore, the module installs itself (through monkey patching) as

torch.Field

mirroring torch.Tensor, so all the functionality of the torchfields package can be conveniently accessed through that shortcut. The shortcut is activated the first time the package is imported (with import torchfields).

Note, however, that the torchfields package is neither endorsed by nor maintained by the PyTorch developer community, and is instead a separate project maintained by researchers at Princeton University.

Tutorial

To learn more and get started with torchfields, check out the tutorial.
