Neural network sampling utilities.

Project description

Lightweight PyTorch functions for neural network featuremap sampling.

WARNING: The API is not yet stable and is subject to change!

Introduction

Sampling neural network featuremaps at explicit coordinates has become increasingly common in recent model designs.

PyTorch provides the tools necessary to sample featuremaps at coordinates, but using them directly results in a large amount of error-prone boilerplate. TorchSample intends to make this simple so you can focus on other parts of the model.
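
Under the hood this kind of sampling typically boils down to torch.nn.functional.grid_sample plus reshaping. The sketch below shows roughly what a single point-sampling call looks like when written by hand; the tensor shapes are illustrative, and this is the general pattern rather than torchsample's exact implementation.

import torch
import torch.nn.functional as F

# Illustrative shapes; not part of the torchsample API.
featmap = torch.rand(2, 64, 32, 32)      # (b, feat, h, w)
coords = torch.rand(2, 4096, 2) * 2 - 1  # (b, n, 2), normalized to [-1, 1], (x, y) order

# grid_sample expects a 4-D grid, so the point list must be reshaped to
# (b, n, 1, 2), sampled, then squeezed and permuted back to (b, n, feat).
grid = coords.unsqueeze(2)                                    # (b, n, 1, 2)
sampled = F.grid_sample(featmap, grid, align_corners=False)   # (b, feat, n, 1)
sampled = sampled.squeeze(-1).permute(0, 2, 1)                # (b, n, feat)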

Usage

Installation

Requires Python >=3.8. Install torchsample via pip:

pip install torchsample

Or, if you want to install the nightly version:

pip install git+https://github.com/BrianPugh/torchsample.git@main

Training

A common training scenario is to randomly sample points from a featuremap and from the corresponding ground truth.

import torchsample as ts

b, c, h, w = batch["image"].shape
coords = ts.coord.rand(b, 4096, 2)  # (b, 4096, 2) where the last dim is (x, y)

featmap = feature_extractor(batch["image"])  # (b, feat, h, w)
sampled = ts.sample(coords, featmap)  # (b, 4096, feat)
gt_sample = ts.sample(coords, batch["gt"])  # (b, 4096, gt_channels)
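
From there, the sampled featuremap values and ground truth drive an ordinary point-wise loss. A minimal sketch continuing the snippet above, where point_head is a hypothetical module mapping (b, n, feat) to the same last-dimension size as gt_sample:

import torch.nn.functional as F

pred = point_head(sampled)          # (b, 4096, out); point_head is hypothetical
loss = F.mse_loss(pred, gt_sample)  # supervise only the sampled locations
loss.backward()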

Inference

During inference, it is common to query the network densely at every pixel to form a complete output image.

import torch
import torchsample as ts

b, c, h, w = batch["image"].shape
coords = ts.coord.full_like(batch["image"])
featmap = encoder(batch["image"])  # (b, feat, h, w)
feat_sampled = ts.sample(coords, featmap)  # (b, h, w, feat)
output = model(feat_sampled)  # (b, h, w, pred)
output = output.permute(0, 3, 1, 2)  # (b, pred, h, w)

Positional Encoding

Common positional encoding schemes are available.

import torchsample as ts

coords = ts.coord.rand(b, 4096, 2)
pos_enc = ts.encoding.gamma(coords)
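
Here gamma refers to a Fourier-feature style encoding: sines and cosines of the coordinates at several frequencies. The stand-alone sketch below illustrates the idea; the number of frequencies and the exact frequency/argument convention are assumptions, not necessarily what ts.encoding.gamma uses.

import math
import torch

def gamma_like(coords, num_freqs=10):
    """Fourier-feature positional encoding; a sketch, not ts.encoding.gamma itself."""
    freqs = 2.0 ** torch.arange(num_freqs) * math.pi       # (num_freqs,)
    angles = coords.unsqueeze(-1) * freqs                   # (..., 2, num_freqs)
    enc = torch.cat([angles.sin(), angles.cos()], dim=-1)   # (..., 2, 2 * num_freqs)
    return enc.flatten(-2)                                  # (..., 4 * num_freqs)

coords = torch.rand(4, 4096, 2) * 2 - 1   # normalized (x, y) coordinates
pos_enc = gamma_like(coords)              # (4, 4096, 40)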

A common task is concatenating the positional encoding to the sampled values. You can do this by passing a callable to ts.sample via the encoder argument:

import torchsample as ts

encoder = ts.encoding.Gamma()
sampled = ts.sample(coords, featmap, encoder=encoder)
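
This is roughly equivalent to concatenating the encoding onto the sampled features yourself along the last dimension. The snippet below sketches that intent; the concatenation order is an assumption, not a documented detail.

import torch
import torchsample as ts

coords = ts.coord.rand(2, 4096, 2)
featmap = torch.rand(2, 64, 32, 32)                # (b, feat, h, w)
sampled = ts.sample(coords, featmap)               # (b, 4096, feat)
pos_enc = ts.encoding.gamma(coords)                # (b, 4096, enc)
combined = torch.cat([sampled, pos_enc], dim=-1)   # (b, 4096, feat + enc)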

Models

torchsample has some common built-in models:

import torchsample as ts

# Properly handles (..., feat) tensors.
model = ts.models.MLP(256, 256, 512, 512, 1024, 1024, 1)
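
Reading the call above as a list of layer widths (first = input features, last = output features; an assumption based on the call, not documented here), a sampled (b, n, feat) tensor can be passed straight through:

import torch
import torchsample as ts

model = ts.models.MLP(256, 256, 512, 512, 1024, 1024, 1)
features = torch.rand(8, 4096, 256)   # e.g. sampled featuremap values, (b, n, feat)
out = model(features)                 # expected (8, 4096, 1); leading dims pass through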

Design Decisions

  • align_corners=False by default (same as PyTorch). You should probably not touch it; see the explanation in the documentation.

  • Everything is in normalized coordinates [-1, 1] by default (see the sketch after this list).

  • Coordinates are always in order (x, y, ...).

  • Whenever a size is given, it will be in (w, h) order, i.e. matching coordinate order. This keeps the implementation simpler, and a consistent rule helps prevent bugs.

  • When coords is a function argument, it comes first.

  • Simple wrapper functions (like ts.coord.rand) are provided to make the intentions of calling code more clear.

  • Try to mimic native PyTorch and torchvision interfaces as much as possible.

  • Try to make the common use case as simple and intuitive as possible.
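
For concreteness, the sketch below shows how integer pixel indices relate to these normalized, (x, y)-ordered coordinates under align_corners=False semantics; it is plain PyTorch arithmetic for illustration, not a torchsample API.

import torch

h, w = 32, 64
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

# align_corners=False: the center of pixel i maps to (2 * i + 1) / size - 1.
x_norm = (2 * xs + 1) / w - 1
y_norm = (2 * ys + 1) / h - 1

# Coordinates are (x, y) ordered in the last dim; sizes are given as (w, h).
coords = torch.stack([x_norm, y_norm], dim=-1)  # (h, w, 2), values in [-1, 1]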

