Thin Plate Spline and Polyharmonic Spline implementation with PyTorch
Torch-TPS (Thin Plate Spline)
PyTorch implementation of the generalized Polyharmonic Spline interpolation (also known as Thin Plate Spline in 2D). It learns a smooth elastic mapping between two Euclidean spaces with support for:
- Arbitrary input and output dimensions
- Arbitrary spline order k
- Optional regularization
- CPU and GPU parallelization
Useful for interpolation, deformation fields, and smooth non-linear regression.
For a NumPy implementation, see tps.
This implementation is much faster than the NumPy one, thanks to CPU parallelization. Using the GPU does not seem to speed up fitting much (fitting solves a linear system), but it makes transforming much faster (transforming is essentially a matrix multiplication).
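To check this fit/transform asymmetry on your own hardware, here is a minimal sketch. Assumptions: random matrices stand in for the spline's real linear system and kernel matrix, building the kernel matrix itself (an element-wise step) is not timed, and `compare_solve_and_matmul` is a hypothetical helper, not part of torch-tps.

```python
import time

import torch


def _timed(fn, device: torch.device) -> float:
    """Run fn once as a warm-up, then time a second call in wall-clock seconds."""
    fn()  # warm-up: first CUDA calls pay one-off library initialization costs
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # CUDA kernels run asynchronously
    start = time.perf_counter()
    fn()
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return time.perf_counter() - start


def compare_solve_and_matmul(n_ctrl: int = 1000, n_test: int = 10000, v: int = 2, device: str = "cpu"):
    """Time a dense solve (fit-like) against a matmul (transform-like) of matching sizes."""
    dev = torch.device(device)
    # Well-conditioned dense system, about the size fitting has to solve (n_ctrl x n_ctrl).
    a = torch.randn(n_ctrl, n_ctrl, device=dev) + n_ctrl * torch.eye(n_ctrl, device=dev)
    b = torch.randn(n_ctrl, v, device=dev)
    # Test-to-control kernel matrix times weights: the core of transforming.
    k_test = torch.randn(n_test, n_ctrl, device=dev)
    w = torch.randn(n_ctrl, v, device=dev)
    solve_s = _timed(lambda: torch.linalg.solve(a, b), dev)
    matmul_s = _timed(lambda: k_test @ w, dev)
    return solve_s, matmul_s
```

For example, call it with `device="cpu"` and, when `torch.cuda.is_available()`, with `device="cuda"`, then compare the two timings per device. Single runs are noisy, so repeat a few times before drawing conclusions.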
🚀 Install
Pip
$ pip install torch-tps
From source
git clone git@github.com:raphaelreme/torch-tps.git # OR https://github.com/raphaelreme/torch-tps.git
cd torch-tps
pip install .
Getting started
import torch
from torch_tps import ThinPlateSpline
# Control points
X_train = torch.randn(800, 3)  # 800 points in R^3
Y_train = torch.randn(800, 2)  # Values for each point (800 values in R^2)
# New source points to interpolate
X_test = torch.randn(3000, 3)
# Initialize spline model (Regularization is controlled with alpha parameter)
tps = ThinPlateSpline(alpha=0.5) # Use device="cuda" to switch to gpu
# Fit spline from control points
tps.fit(X_train, Y_train)
# Interpolate new points
Y_test = tps.transform(X_test)
Examples
See the example/ folder for scripts showing:
- Interpolation in 1D, 2D, 3D
- Arbitrary input and output dimensions
- Image warping with elastic deformation
Image Warping
Example: enlarging, shrinking, or randomly deforming a dog's face using sparse control points.
Code: example/image_warping.py
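The example script is not reproduced here. A common way to warp an image with a fitted spline is backward mapping: for every output pixel, compute where to read from in the input image, then resample. The sketch below assumes a (C, H, W) float image with H, W > 1 and any callable `mapping` from output pixel coordinates (x, y) to input coordinates; `warp_image` is a hypothetical helper, not part of torch-tps.

```python
import torch
import torch.nn.functional as F


def warp_image(image: torch.Tensor, mapping) -> torch.Tensor:
    """Backward-warp a (C, H, W) image.

    mapping: callable from (N, 2) output pixel coordinates (x, y) to the (N, 2)
    input pixel coordinates each output pixel should be sampled from.
    """
    _, h, w = image.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=image.dtype, device=image.device),
        torch.arange(w, dtype=image.dtype, device=image.device),
        indexing="ij",
    )
    out_xy = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1)  # (H*W, 2), row-major
    src_xy = mapping(out_xy)  # (H*W, 2): where to read in the input image
    # grid_sample wants (x, y) normalized to [-1, 1]; with align_corners=True,
    # -1 and 1 are the centers of the first and last pixels.
    grid = torch.stack(
        [2 * src_xy[:, 0] / (w - 1) - 1, 2 * src_xy[:, 1] / (h - 1) - 1], dim=1
    ).reshape(1, h, w, 2)
    warped = F.grid_sample(
        image[None], grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )
    return warped[0]  # back to (C, H, W); samples falling outside the input are zero
```

With torch-tps, one option (untested here) is to fit the spline from target control points to source control points, e.g. `tps.fit(dst_points, src_points)`, and pass `tps.transform` as `mapping`. Fitting in this backward direction avoids holes that forward-mapping pixels would leave.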
🧠 Theory Summary
The model solves the regularized interpolation problem:
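$$ \min_f \sum_{i=1}^n (y_i - f(x_i))^2 + \alpha \int \|\nabla^{\text{order}} f\|_2^2 \, dx $$
where $\alpha \ge 0$ is the regularization strength (the alpha parameter); with $\alpha = 0$ the spline interpolates the control points exactly.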
With solution:
$$ f(x) = P(x) + \sum_{i=1}^n w_i G(\|x - x_i\|_2) $$
Where:
- $G(r)$: radial basis function (RBF), which depends on `order` and the input dimension $d$
- $P(x)$: a polynomial of degree `order - 1`
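For reference, the coefficients are usually obtained from one square linear system. In the standard formulation (an assumption about the usual setup, not a statement of the exact scaling torch-tps uses; $\alpha$ may differ by a kernel-dependent constant factor), let $K_{ij} = G(\|x_i - x_j\|_2)$, let $q(x)$ be the vector of the $m$ monomials of degree at most `order - 1`, let $Q$ be the $n \times m$ matrix with rows $q(x_i)^\top$, let $W$ stack the $w_i$ as rows ($n \times v$), and let $A$ hold the polynomial coefficients ($m \times v$):

$$ \begin{pmatrix} K + \alpha I_n & Q \\ Q^\top & 0 \end{pmatrix} \begin{pmatrix} W \\ A \end{pmatrix} = \begin{pmatrix} Y \\ 0 \end{pmatrix}, \qquad P(x) = q(x)^\top A $$

The second block row, $Q^\top W = 0$, keeps the radial part orthogonal to polynomials of degree at most `order - 1`; the system is typically nonsingular as long as the control points are not degenerate (e.g. not all collinear in 2D). For `order = 2` (TPS), $q(x) = (1, x^\top)^\top$, so $m = d + 1$.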
Default kernel (TPS):
- $G(r) = r^2 \log(r)$
General kernel:
- $G(r) = r^{2\,\text{order} - d}$ if $d$ is odd
- $G(r) = r^{2\,\text{order} - d} \log(r)$ if $d$ is even
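The formulas above can be turned into a compact from-scratch sketch in PyTorch. This is not torch-tps's source code: the function names are hypothetical, it assumes 2·order > d (so $G$ is finite at $r = 0$), adds $\alpha$ to the kernel diagonal as the regularization, and uses all monomials of degree at most `order - 1` for $P$.

```python
import itertools

import torch


def polyharmonic_kernel(r: torch.Tensor, order: int, d: int) -> torch.Tensor:
    """G(r) = r^(2*order - d) for odd d, r^(2*order - d) * log(r) for even d.

    Assumes 2 * order > d so that G(r) -> 0 as r -> 0.
    """
    p = 2 * order - d
    if d % 2 == 1:
        return r**p
    safe_r = torch.where(r > 0, r, torch.ones_like(r))  # avoid log(0)
    return torch.where(r > 0, safe_r**p * torch.log(safe_r), torch.zeros_like(r))


def monomials(x: torch.Tensor, degree: int) -> torch.Tensor:
    """Evaluate every monomial of total degree <= degree at x (n, d) -> (n, m)."""
    n, d = x.shape
    cols = [torch.ones(n, dtype=x.dtype, device=x.device)]
    for deg in range(1, degree + 1):
        for dims in itertools.combinations_with_replacement(range(d), deg):
            cols.append(x[:, list(dims)].prod(dim=1))
    return torch.stack(cols, dim=1)


def pairwise_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Exact Euclidean distances (no matmul shortcut), so self-distances are exactly 0.
    return torch.cdist(a, b, compute_mode="donot_use_mm_for_euclid_dist")


def fit_spline(x: torch.Tensor, y: torch.Tensor, order: int = 2, alpha: float = 0.0):
    """Solve [[K + alpha*I, Q], [Q^T, 0]] [W; A] = [Y; 0] and return (W, A)."""
    n, d = x.shape
    k = polyharmonic_kernel(pairwise_dist(x, x), order, d)
    k = k + alpha * torch.eye(n, dtype=x.dtype, device=x.device)
    q = monomials(x, order - 1)
    m = q.shape[1]
    zeros_mm = torch.zeros(m, m, dtype=x.dtype, device=x.device)
    lhs = torch.cat([torch.cat([k, q], dim=1), torch.cat([q.T, zeros_mm], dim=1)], dim=0)
    rhs = torch.cat([y, torch.zeros(m, y.shape[1], dtype=y.dtype, device=y.device)], dim=0)
    params = torch.linalg.solve(lhs, rhs)
    return params[:n], params[n:]


def transform_spline(x_new, x_ctrl, w, a, order: int = 2):
    """f(x) = P(x) + sum_i w_i G(||x - x_i||): one kernel matrix and two matmuls."""
    d = x_ctrl.shape[1]
    k_new = polyharmonic_kernel(pairwise_dist(x_new, x_ctrl), order, d)
    return k_new @ w + monomials(x_new, order - 1) @ a
```

In 1D with `order=2`, $G(r) = r^3$ and the same code gives a natural cubic spline. Using float64 for the solve is a deliberate choice here: kernel systems can be poorly conditioned, and float32 loses accuracy quickly as the number of control points grows.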
🔧 API
ThinPlateSpline(alpha=0.0, order=2, enforce_tps_kernel=False, device="cpu")
Creates a general polyharmonic spline interpolator (defaults to TPS in 2D and natural cubic splines in 1D).
- alpha (float): Regularization strength (default 0.0)
- order (int): Spline order (default is 2 for TPS)
- enforce_tps_kernel (bool): Force TPS kernel r^2 log r, even when mathematically suboptimal
- device (torch.device): Use "cuda" to enable GPU computations. Defaults to "cpu".
.fit(X, Y)
Fits the model to control point pairs.
- X: (n, d) input coordinates
- Y: (n, v) target coordinates
Returns: self
.transform(X)
Applies the learned mapping to new input points.
- X: (n', d) points
Returns: (n', v) interpolated values
License
MIT License
Download files
File details
Details for the file torch_tps-1.2.2.tar.gz.
File metadata
- Download URL: torch_tps-1.2.2.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2be5e7b09853299418212400307c0a05c218e81dc8368d743ed7fa09af371103 |
| MD5 | 69f3787e48c1b3578c18c26b80112476 |
| BLAKE2b-256 | 0092f196a5ecf575fa7a9dfc060f900e1241828bfcfa4af8bb2e800814a7c68c |
File details
Details for the file torch_tps-1.2.2-py3-none-any.whl.
File metadata
- Download URL: torch_tps-1.2.2-py3-none-any.whl
- Upload date:
- Size: 10.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4d99488656c835fc731260497c9e34ce6eafaca623e4f9ea3dfe5e584f614ea3 |
| MD5 | cdb13dfa92008821b18079c2d56aa488 |
| BLAKE2b-256 | 43c8bd4b081701cd37c4b1fb6914089d0f73bb6185f3888779e644a5dae4c2a1 |