Neural network sampling utilities.
Project description
Lightweight PyTorch functions for neural network featuremap sampling.
WARNING: The API is not yet stable and is subject to change!
Introduction
Sampling neural network featuremaps at explicit coordinates has become increasingly common with popular developments like:
Learning Continuous Image Representation with Local Implicit Image Function
NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
PyTorch provides the tools necessary to sample coordinates, but using them directly results in a large amount of error-prone boilerplate. TorchSample intends to make this simple so you can focus on other parts of the model.
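For context, sampling a feature map at a set of normalized (x, y) points with plain PyTorch typically looks something like the following sketch of the grid_sample boilerplate that torchsample wraps (tensor names and shapes here are illustrative):
import torch
import torch.nn.functional as F
b, c, h, w = 2, 64, 32, 32
featmap = torch.rand(b, c, h, w)  # (b, c, h, w)
coords = torch.rand(b, 4096, 2) * 2 - 1  # (b, n, 2), normalized (x, y) in [-1, 1]
# grid_sample expects a 4D grid, so the point list must be reshaped,
# and the output comes back channels-first and must be permuted.
grid = coords.unsqueeze(2)  # (b, n, 1, 2)
sampled = F.grid_sample(featmap, grid, align_corners=False)  # (b, c, n, 1)
sampled = sampled.squeeze(-1).permute(0, 2, 1)  # (b, n, c)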
Usage
Installation
Requires Python >=3.8. Install torchsample via pip:
pip install torchsample
Or, if you want to install the nightly version:
pip install git+https://github.com/BrianPugh/torchsample.git@main
Training
A common scenario is to randomly sample points from a featmap and from the ground truth.
import torchsample as ts
b, c, h, w = batch["image"].shape
coords = ts.coord.rand(b, 4096, 2) # (b, 4096, 2) where the last dim is (x, y)
featmap = feature_extractor(batch["image"]) # (b, feat, h, w)
sampled = ts.sample(coords, featmap) # (b, 4096, feat)
gt_sample = ts.sample(coords, batch["gt"])
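From here, a point-wise decoder and loss fit directly on top of the sampled tensors. A minimal sketch (the linear decoder and MSE loss below are placeholders, not part of torchsample, and assume a regression-style ground truth):
import torch
import torch.nn.functional as F
# `sampled` is (b, 4096, feat) and `gt_sample` is (b, 4096, c_gt) from above.
decoder = torch.nn.Linear(sampled.shape[-1], gt_sample.shape[-1])  # placeholder point-wise decoder
pred = decoder(sampled)  # (b, 4096, c_gt)
loss = F.mse_loss(pred, gt_sample.float())
loss.backward()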
Inference
During inference, it is common to exhaustively query the network over a full coordinate grid to form a complete image.
import torch
import torchsample as ts
b, c, h, w = batch["image"].shape
coords = ts.coord.full_like(batch["image"])
featmap = encoder(batch["image"]) # (b, feat, h, w)
feat_sampled = ts.sample(coords, featmap) # (b, h, w, feat)
output = model(feat_sampled) # (b, h, w, pred)
output = output.permute(0, 3, 1, 2) # (b, pred, h, w)
Positional Encoding
Common positional encoding schemes are available.
import torchsample as ts
coords = ts.coord.rand(b, 4096, 2)
pos_enc = ts.encoding.gamma(coords)
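The name gamma comes from the γ(p) positional encoding used in NeRF; conceptually, such an encoding evaluates sines and cosines of the coordinates at exponentially increasing frequencies. A simplified sketch of the idea (not torchsample's exact implementation or signature):
import math
import torch
def gamma_sketch(coords, num_frequencies=10):
    # coords: (..., dim) in [-1, 1]; returns (..., dim * 2 * num_frequencies).
    frequencies = (2.0 ** torch.arange(num_frequencies)) * math.pi  # 2^k * pi
    scaled = coords.unsqueeze(-1) * frequencies  # (..., dim, num_frequencies)
    encoded = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
    return encoded.flatten(start_dim=-2)  # (..., dim * 2 * num_frequencies)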
A common task is concatenating the positional encoding to the sampled values. You can do this by passing a callable into ts.sample:
import torchsample as ts
encoder = ts.encoding.Gamma()
sampled = ts.sample(coords, featmap, encoder=encoder)
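Continuing from the snippet above, without the encoder argument the equivalent manual step would be concatenating along the last (feature) dimension, roughly:
import torch
sampled = ts.sample(coords, featmap)  # (b, 4096, feat)
pos_enc = ts.encoding.gamma(coords)  # (b, 4096, enc)
sampled = torch.cat([sampled, pos_enc], dim=-1)  # (b, 4096, feat + enc)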
Models
torchsample has some common built-in models:
import torchsample as ts
# Properly handles (..., feat) tensors.
model = ts.models.MLP(256, 256, 512, 512, 1024, 1024, 1)
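A minimal usage sketch, assuming the first and last constructor arguments are the input and output feature sizes:
import torch
x = torch.rand(8, 4096, 256)  # any (..., feat)-shaped input
pred = model(x)  # (8, 4096, 1)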
Design Decisions
align_corners=False by default (same as PyTorch); you should probably not touch it.
Everything is in normalized coordinates [-1, 1] by default (see the sketch after this list).
Coordinates are always in order (x, y, ...).
Whenever a size is given, it will be in (w, h) order, i.e. matching coordinate order. This keeps the implementation simpler, and a consistent rule helps prevent bugs.
When coords is a function argument, it comes first.
Simple wrapper functions (like ts.coord.rand) are provided to make the intentions of calling code more clear.
Try to mimic native PyTorch and torchvision interfaces as much as possible.
Try to make the common use case as simple and intuitive as possible.
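To make the coordinate conventions above concrete, this is how pixel indices map to normalized coordinates under the align_corners=False convention (a sketch of the convention itself, not a torchsample function):
import torch
def pixel_to_normalized(xy_pixel, size_wh):
    # xy_pixel: (..., 2) pixel indices in (x, y) order.
    # size_wh: (w, h), matching coordinate order.
    # With align_corners=False, a pixel center i maps to (i + 0.5) / extent * 2 - 1,
    # so the image spans [-1, 1] from edge to edge.
    size = torch.as_tensor(size_wh, dtype=torch.float32)
    return (xy_pixel.float() + 0.5) / size * 2 - 1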