torchrbf: Radial Basis Function Interpolation in PyTorch

This is a PyTorch module for Radial Basis Function (RBF) interpolation, translated from SciPy's implementation. It benefits from GPU acceleration, which makes it significantly faster than the SciPy version and better suited to larger interpolation problems.

Installation

pip install torchrbf

The only dependencies are PyTorch and NumPy. If you want to run the tests and benchmarks, you also need SciPy installed.

A note on numerical precision

If TF32 is enabled, you may see reduced numerical precision. TF32 is enabled by default for CUDA matrix multiplications in PyTorch versions 1.7 to 1.11 (see the PyTorch documentation on TF32). To disable it, use

torch.backends.cuda.matmul.allow_tf32 = False

torchrbf will issue a warning if TF32 is enabled.
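
If you would rather not change the flag globally, one option is to turn TF32 off only around the interpolation code and restore the previous setting afterwards. The following is a minimal sketch of that idea; the helper name `tf32_disabled` is hypothetical and not part of torchrbf, and it only touches the matmul flag shown above.

```python
import contextlib

import torch


@contextlib.contextmanager
def tf32_disabled():
    """Temporarily disable TF32 for CUDA matmuls (hypothetical helper, not in torchrbf)."""
    previous = torch.backends.cuda.matmul.allow_tf32  # remember the current setting
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        yield
    finally:
        # Restore whatever the caller had configured before entering the block.
        torch.backends.cuda.matmul.allow_tf32 = previous
```

Code that builds and queries the interpolator can then be placed inside a `with tf32_disabled():` block, leaving the rest of the program unaffected.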

Usage

Here is a simple example for interpolating 3D data in a 2D domain:

import torch
import matplotlib.pyplot as plt
from torchrbf import RBFInterpolator

y = torch.rand(100, 2) # Data coordinates
d = torch.rand(100, 3) # Data vectors at each point

interpolator = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')

# Query coordinates (100x100 grid of points)
grid_x = torch.linspace(0, 1, 100)
grid_y = torch.linspace(0, 1, 100)  # separate name, so the data coordinates `y` are not overwritten
grid_points = torch.meshgrid(grid_x, grid_y, indexing='ij')
grid_points = torch.stack(grid_points, dim=-1).reshape(-1, 2)

# Query RBF on grid points
interp_vals = interpolator(grid_points)

# Plot the interpolated values in 2D
plt.scatter(grid_points[:, 0], grid_points[:, 1], c=interp_vals[:, 0])
plt.title('Interpolated values in 2D')
plt.show()
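
To make the `kernel` and `smoothing` arguments more concrete, here is a minimal sketch of the kind of linear system that SciPy-style RBF interpolation solves for the thin plate spline kernel with a degree-1 polynomial term. It is an illustration under simplifying assumptions, not torchrbf's actual code: the helper names (`tps_kernel`, `pairwise_dist`, `poly_matrix`, `fit_rbf`, `eval_rbf`) are hypothetical, only one kernel is supported, and the conditioning steps a real implementation would use are omitted.

```python
import torch


def tps_kernel(r):
    """Thin plate spline kernel phi(r) = r^2 * log(r), with phi(0) = 0."""
    return torch.xlogy(r**2, r)


def pairwise_dist(a, b):
    """Euclidean distances between rows of a (m, dim) and rows of b (n, dim)."""
    return (a[:, None, :] - b[None, :, :]).norm(dim=-1)


def poly_matrix(x):
    """Degree-1 polynomial terms [1, x_1, ..., x_dim] for each row of x."""
    return torch.cat([x.new_ones(x.shape[0], 1), x], dim=1)


def fit_rbf(y, d, smoothing=0.0):
    """Solve [[K + smoothing*I, P], [P^T, 0]] [w; c] = [d; 0] for w and c."""
    n = y.shape[0]
    K = tps_kernel(pairwise_dist(y, y))
    K = K + smoothing * torch.eye(n, dtype=y.dtype, device=y.device)
    P = poly_matrix(y)
    m = P.shape[1]
    top = torch.cat([K, P], dim=1)
    bottom = torch.cat([P.T, y.new_zeros(m, m)], dim=1)
    lhs = torch.cat([top, bottom], dim=0)
    rhs = torch.cat([d, d.new_zeros(m, d.shape[1])], dim=0)
    coeffs = torch.linalg.solve(lhs, rhs)
    return coeffs[:n], coeffs[n:]  # kernel weights w, polynomial coefficients c


def eval_rbf(x, y, w, c):
    """Evaluate the fitted interpolant at query points x."""
    return tps_kernel(pairwise_dist(x, y)) @ w + poly_matrix(x) @ c


torch.manual_seed(0)
y = torch.rand(30, 2, dtype=torch.float64)  # data coordinates
d = torch.rand(30, 3, dtype=torch.float64)  # data vectors at each point
w, c = fit_rbf(y, d, smoothing=0.0)
fitted = eval_rbf(y, y, w, c)
print((fitted - d).abs().max())  # residual at the data points
```

With `smoothing=0.0` the fitted function passes through the data points up to solver error. Larger values add to the diagonal of the kernel matrix, so the result trades exact interpolation for a smoother fit.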

Performance versus SciPy

Since the module is implemented in PyTorch, it benefits from GPU acceleration. For larger interpolation problems, torchrbf is significantly faster than SciPy's implementation (more than 100x faster on an RTX 3090).
