torchrbf: Radial Basis Function Interpolation in PyTorch

This is a PyTorch module for Radial Basis Function (RBF) interpolation, translated from SciPy's implementation. Because it runs on PyTorch, it can use GPU acceleration, which makes it significantly faster than the SciPy version on larger interpolation problems.

Installation

pip install torchrbf

The only dependencies are PyTorch and NumPy. If you want to run the tests and benchmarks, you also need SciPy installed.

A note on numerical precision

If TF32 is enabled, you may see reduced numerical precision. TF32 matrix multiplication is enabled by default in PyTorch versions 1.7 to 1.11 and disabled by default from 1.12 onward (see the PyTorch documentation on TF32). To disable it explicitly, use

torch.backends.cuda.matmul.allow_tf32 = False

torchrbf will issue a warning if TF32 is enabled.
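
If TF32 should stay on for the rest of a program, the flag can be switched off only around the interpolation. The following is a minimal sketch, not part of torchrbf: `tf32_disabled` is a hypothetical helper name, and also turning off the separate cuDNN flag is an extra precaution assumed here rather than something torchrbf is documented to require.

```python
import contextlib
import torch


@contextlib.contextmanager
def tf32_disabled():
    """Temporarily turn off TF32 and restore the previous settings on exit.

    Hypothetical helper for illustration; not provided by torchrbf.
    """
    prev_matmul = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False  # assumption: extra precaution
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul
        torch.backends.cudnn.allow_tf32 = prev_cudnn


before = torch.backends.cuda.matmul.allow_tf32
with tf32_disabled():
    # Build and query the interpolator inside this block.
    inside = torch.backends.cuda.matmul.allow_tf32
after = torch.backends.cuda.matmul.allow_tf32
```

Both flags are process-wide, so the helper affects every matrix multiplication that runs while the block is active, not only those issued by the interpolator.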

Usage

Here is a simple example for interpolating 3D data in a 2D domain:

import torch
import matplotlib.pyplot as plt
from torchrbf import RBFInterpolator

y = torch.rand(100, 2) # Data coordinates
d = torch.rand(100, 3) # Data vectors at each point

interpolator = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')

# Query coordinates (100x100 grid of points)
# Named gx/gy so the data coordinates `y` above are not overwritten
gx = torch.linspace(0, 1, 100)
gy = torch.linspace(0, 1, 100)
grid_points = torch.meshgrid(gx, gy, indexing='ij')
grid_points = torch.stack(grid_points, dim=-1).reshape(-1, 2)

# Query RBF on grid points
interp_vals = interpolator(grid_points)

# Plot the interpolated values in 2D
plt.scatter(grid_points[:, 0], grid_points[:, 1], c=interp_vals[:, 0])
plt.title('Interpolated values in 2D')
plt.show()
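
For intuition about what a thin plate spline interpolator computes, the sketch below solves the linear system from SciPy's documented formulation, `(K + smoothing * I) w + P c = d` with `P^T w = 0`, directly in PyTorch. It assumes torchrbf follows the same formulation because it is a translation of SciPy's code. It is not torchrbf's actual implementation, and the names `pairwise_dist`, `tps_kernel`, `poly_basis`, `fit_tps`, and `eval_tps` are hypothetical.

```python
import torch


def pairwise_dist(a, b):
    # Euclidean distance between every row of a (m, dim) and every row of b (n, dim).
    return (a[:, None, :] - b[None, :, :]).norm(dim=-1)


def tps_kernel(r):
    # Thin plate spline phi(r) = r^2 * log(r); xlogy returns 0 where r == 0.
    return torch.xlogy(r**2, r)


def poly_basis(p):
    # Degree-1 polynomial basis [1, p_1, ..., p_dim] for each point.
    return torch.cat([torch.ones(p.shape[0], 1, dtype=p.dtype), p], dim=1)


def fit_tps(y, d, smoothing=0.0):
    # Assemble and solve the block system [[K + s*I, P], [P^T, 0]] [w; c] = [d; 0].
    n = y.shape[0]
    K = tps_kernel(pairwise_dist(y, y)) + smoothing * torch.eye(n, dtype=y.dtype)
    P = poly_basis(y)
    m = P.shape[1]
    top = torch.cat([K, P], dim=1)
    bottom = torch.cat([P.T, torch.zeros(m, m, dtype=y.dtype)], dim=1)
    A = torch.cat([top, bottom], dim=0)
    rhs = torch.cat([d, torch.zeros(m, d.shape[1], dtype=d.dtype)], dim=0)
    coeffs = torch.linalg.solve(A, rhs)
    return coeffs[:n], coeffs[n:]  # kernel weights w, polynomial coefficients c


def eval_tps(x, y, w, c):
    # Evaluate the fitted interpolant at query points x.
    return tps_kernel(pairwise_dist(x, y)) @ w + poly_basis(x) @ c


torch.manual_seed(0)
y = torch.rand(20, 2, dtype=torch.float64)  # data coordinates
d = torch.rand(20, 3, dtype=torch.float64)  # data vectors at each point
w, c = fit_tps(y, d, smoothing=0.0)

# With smoothing=0 the interpolant passes through the data points.
max_err = (eval_tps(y, y, w, c) - d).abs().max()
print(max_err)
```

With `smoothing=0.0` the fit reproduces the data at the data coordinates up to round-off. A positive value such as the `smoothing=1.0` used in the example above trades exact reproduction for a smoother surface.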

Performance versus SciPy

Since the module is implemented in PyTorch, it benefits from GPU acceleration. For larger interpolation problems, torchrbf is significantly faster than SciPy's implementation (more than 100x faster on an RTX 3090).
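
To get a baseline on your own hardware, the SciPy side of such a comparison can be timed as in the sketch below. It uses `scipy.interpolate.RBFInterpolator` with the same `smoothing` and `kernel` arguments as the usage example; the problem size (200 data points, a 50x50 query grid) is an arbitrary assumption chosen to keep the run short and is not the configuration behind the figure quoted above.

```python
import time

import numpy as np
from scipy.interpolate import RBFInterpolator

rng = np.random.default_rng(0)
y = rng.random((200, 2))  # data coordinates
d = rng.random((200, 3))  # data vectors at each point

# Query coordinates (50x50 grid of points)
axis = np.linspace(0, 1, 50)
grid_points = np.stack(np.meshgrid(axis, axis, indexing='ij'), axis=-1).reshape(-1, 2)

# Time both the fit and the query, since both contribute to the total cost.
start = time.perf_counter()
interp = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')
vals = interp(grid_points)
elapsed = time.perf_counter() - start

print(f"SciPy: {elapsed:.4f} s for {len(grid_points)} query points")
```

When timing the PyTorch side on a GPU, call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously and the timer would otherwise stop before the work has finished.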
