
Thin Plate Spline and Polyharmonic Spline implementation with PyTorch

Project description

Torch-TPS (Thin Plate Spline)


PyTorch implementation of the generalized Polyharmonic Spline interpolation (also known as Thin Plate Spline in 2D). It learns a smooth elastic mapping between two Euclidean spaces with support for:

  • Arbitrary input and output dimensions
  • Arbitrary spline order k
  • Optional regularization
  • CPU and GPU parallelization

Useful for interpolation, deformation fields, and smooth non-linear regression.

For a NumPy implementation, see tps.

This implementation is much faster than the NumPy one, thanks to CPU parallelism. Using the GPU does not seem to speed up fitting much (fitting solves a linear system), but it makes transforming much faster (transforming is essentially a matrix multiplication).

🚀 Install

Pip

$ pip install torch-tps

From source

git clone git@github.com:raphaelreme/torch-tps.git  # OR https://github.com/raphaelreme/torch-tps.git
cd torch-tps
pip install .

Getting started

import torch
from torch_tps import ThinPlateSpline

# Control points
X_train = torch.randn(800, 3)  # 800 points in R^3
Y_train = torch.randn(800, 2)  # Values for each point (800 values in R^2)

# New source points to interpolate
X_test = torch.randn(3000, 3)

# Initialize spline model (Regularization is controlled with alpha parameter)
tps = ThinPlateSpline(alpha=0.5)  # Use device="cuda" to switch to gpu

# Fit spline from control points
tps.fit(X_train, Y_train)

# Interpolate new points
Y_test = tps.transform(X_test)

Examples

See the example/ folder for scripts showing:

  • Interpolation in 1D, 2D, 3D
  • Arbitrary input and output dimensions
  • Image warping with elastic deformation

Image Warping

Enlarging, shrinking, and randomly deforming a dog's face using sparse control points.

[Figure panels: Original, Increased, Decreased, Random]

Code: example/image_warping.py

🧠 Theory Summary

The model solves the regularized interpolation problem:

$$ \min_f \sum_{i=1}^n \|y_i - f(x_i)\|_2^2 + \alpha \int \|\nabla^{\text{order}} f\|_2^2 \, dx $$

With solution:

$$ f(x) = P(x) + \sum_{i=1}^n w_i G(\|x - x_i\|_2) $$

Where:

  • $G(r)$: a radial basis function (RBF) that depends on the order and on the input dimension $d$
  • $P(x)$: a polynomial of degree at most $\text{order} - 1$
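For reference, this is the standard linear system used to fit such a spline. That torch-tps solves exactly this form, and how its $\alpha$ is scaled relative to the objective above, are assumptions. Let $K_{ij} = G(\|x_i - x_j\|_2)$, let $Q_{il} = p_l(x_i)$ for a basis $p_1, \dots, p_m$ of polynomials of degree at most $\text{order} - 1$, let $W \in \mathbb{R}^{n \times v}$ stack the weights $w_i$, and let $C \in \mathbb{R}^{m \times v}$ hold the coefficients of $P$. Then:

$$ \begin{pmatrix} K + \alpha I_n & Q \\ Q^\top & 0 \end{pmatrix} \begin{pmatrix} W \\ C \end{pmatrix} = \begin{pmatrix} Y \\ 0 \end{pmatrix} $$

The bottom block row enforces $Q^\top W = 0$, which makes the solution unique when the control points are in general position; with $\alpha = 0$ the spline interpolates the control points exactly.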

TPS kernel (order = 2 with d = 2, or any d when enforce_tps_kernel=True):

  • $G(r) = r^2 \log(r)$

General kernel:

  • $G(r) = r^{2\,\text{order} - d}$ if $d$ is odd
  • $G(r) = r^{2\,\text{order} - d} \log(r)$ otherwise
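To make the theory concrete, here is a from-scratch sketch of the fit and transform steps under the standard formulation (kernel matrix plus a polynomial block, solved densely). It is not torch-tps's internal code. For simplicity it only builds an affine polynomial part, so it is restricted to `order=2` (valid for d = 1, 2, 3), and it ignores `enforce_tps_kernel`. The names `polyharmonic_kernel`, `pairwise_distances`, `fit`, and `transform` are hypothetical.

```python
import torch


def polyharmonic_kernel(r: torch.Tensor, order: int, d: int) -> torch.Tensor:
    """G(r) = r^(2*order - d) if d is odd, r^(2*order - d) * log(r) otherwise, with G(0) = 0."""
    power = 2 * order - d
    if d % 2 == 1:
        return r**power
    # Clamp to avoid log(0); r^power * log(r) tends to 0 at r = 0 when power > 0
    safe_r = r.clamp_min(1e-12)
    return torch.where(r > 0, safe_r**power * torch.log(safe_r), torch.zeros_like(r))


def pairwise_distances(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Exact Euclidean distances of shape (len(A), len(B)); A vs A has an exact zero diagonal
    return torch.linalg.vector_norm(A[:, None, :] - B[None, :, :], dim=-1)


def fit(X: torch.Tensor, Y: torch.Tensor, alpha: float = 0.0, order: int = 2):
    """Solve [[K + alpha*I, Q], [Q^T, 0]] @ [W; C] = [Y; 0] for W (n, v) and C (d + 1, v)."""
    if order != 2:
        raise NotImplementedError("this sketch only builds an affine polynomial part (order=2)")
    n, d = X.shape
    K = polyharmonic_kernel(pairwise_distances(X, X), order, d)
    Q = torch.cat([torch.ones(n, 1, dtype=X.dtype), X], dim=1)  # affine basis: 1, x_1, ..., x_d

    size = n + d + 1
    A = torch.zeros(size, size, dtype=X.dtype)
    A[:n, :n] = K + alpha * torch.eye(n, dtype=X.dtype)
    A[:n, n:] = Q
    A[n:, :n] = Q.T
    rhs = torch.zeros(size, Y.shape[1], dtype=X.dtype)
    rhs[:n] = Y

    solution = torch.linalg.solve(A, rhs)
    return solution[:n], solution[n:]


def transform(X_new: torch.Tensor, X_ctrl: torch.Tensor, W: torch.Tensor, C: torch.Tensor, order: int = 2):
    """Evaluate f(x) = P(x) + sum_i w_i G(||x - x_i||) at every row of X_new."""
    d = X_ctrl.shape[1]
    K_new = polyharmonic_kernel(pairwise_distances(X_new, X_ctrl), order, d)  # (n', n)
    Q_new = torch.cat([torch.ones(X_new.shape[0], 1, dtype=X_new.dtype), X_new], dim=1)
    return K_new @ W + Q_new @ C


torch.manual_seed(0)
X = torch.randn(30, 2, dtype=torch.float64)
Y = torch.randn(30, 3, dtype=torch.float64)
W, C = fit(X, Y)
# With alpha = 0 the spline should reproduce the control values up to round-off
max_error = (transform(X, X, W, C) - Y).abs().max().item()
```

A dense `torch.linalg.solve` keeps the sketch short; it also shows why fitting cost grows with the cube of the number of control points, while `transform` is a single matrix product.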

🔧 API

ThinPlateSpline(alpha=0.0, order=2, enforce_tps_kernel=False, device="cpu")

Creates a general polyharmonic spline interpolator (defaults to TPS in 2D and to natural cubic splines in 1D).

  • alpha (float): Regularization strength (default 0.0)
  • order (int): Spline order k (default 2, which gives TPS in 2D)
  • enforce_tps_kernel (bool): Force TPS kernel r^2 log r, even when mathematically suboptimal
  • device (torch.device): Use "cuda" to run computations on the GPU. Defaults to "cpu".

.fit(X, Y)

Fits the model to control point pairs.

  • X: (n, d) input coordinates
  • Y: (n, v) target coordinates

Returns: self

.transform(X)

Applies the learned mapping to new input points.

  • X: (n', d) points

Returns: (n', v) interpolated values

License

MIT License

Download files

Source Distribution

torch_tps-1.2.2.tar.gz (8.2 kB)

Uploaded Source

Built Distribution


torch_tps-1.2.2-py3-none-any.whl (10.1 kB)

Uploaded Python 3

File details

Details for the file torch_tps-1.2.2.tar.gz.

File metadata

  • Download URL: torch_tps-1.2.2.tar.gz
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for torch_tps-1.2.2.tar.gz
Algorithm Hash digest
SHA256 2be5e7b09853299418212400307c0a05c218e81dc8368d743ed7fa09af371103
MD5 69f3787e48c1b3578c18c26b80112476
BLAKE2b-256 0092f196a5ecf575fa7a9dfc060f900e1241828bfcfa4af8bb2e800814a7c68c


File details

Details for the file torch_tps-1.2.2-py3-none-any.whl.

File metadata

  • Download URL: torch_tps-1.2.2-py3-none-any.whl
  • Size: 10.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for torch_tps-1.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 4d99488656c835fc731260497c9e34ce6eafaca623e4f9ea3dfe5e584f614ea3
MD5 cdb13dfa92008821b18079c2d56aa488
BLAKE2b-256 43c8bd4b081701cd37c4b1fb6914089d0f73bb6185f3888779e644a5dae4c2a1

