SplineTorch is a Python package for fitting splines in PyTorch.
Project description
SplineTorch
SplineTorch is a Python package for fitting splines in PyTorch. Spline with enforced and/or penalized constraints are supported.
Installation
pip install splinetorch
Features
- Univariate and multivariate B-spline regression
- Binary and categorical classification
- Probability density function (PDF) fitting
- Constraint support:
- Derivative constraints (e.g., monotonicity, convexity)
- Point constraints
- PDF constraints (non-negativity and integration to 1)
Examples
Simple Univariate Regression
import torch
from splinetorch.b_spline import BSpline
x = torch.linspace(0, 1, 100).view(-1, 1)
y = torch.sin(2 * torch.pi * x) + torch.randn_like(x) * 0.1
spline = BSpline(x=x, y=y)
spline.fit(x, y)
y_pred = spline.predict(x)
spline.plot_fit(x, y)
Binary Classification with monotonely increasing logit
import torch
from splinetorch.b_spline import BSpline
x = torch.linspace(-3, 3, 100).view(-1, 1)
y = (torch.sin(2.2 * torch.pi * x) > 0).float()
# Force monotone increasing logit
derivative_constraints = {1: {'>': torch.tensor(0.0)}} # 1st derivative > 0
spline = BSpline(x=x, y=y, output_type="binary")
spline.fit(x, y, derivative_constraints=derivative_constraints)
probabilities = spline.predict(x)
spline.plot_fit(x, y)
Multivariate Regression with Point Constraints
import torch
from splinetorch.b_spline import BSpline
x = torch.rand(1000, 2) # 2D input
y = torch.sum(2*x**2, dim=1).view(-1, 1) - 2
# Constrain degree 0 derivative (the function itself). In the example we want both x values of 0 to return 0.
point_constraints = { 0: {'=': (torch.tensor([[0.0, 0.0]]), torch.tensor([[0.0]]))} }
spline = BSpline(x=x, y=y)
spline.fit(x, y, point_constraints=point_constraints)
y_pred = spline.predict(x)
# Verify the constraint at (0,0)
test_point = torch.tensor([[0.0, 0.0]])
predicted_value = spline.predict(test_point)
print(f"Value at (0,0): {predicted_value.item():.6f}")
PDF Fitting with Constraints (encourage integration to 1 within the interval)
import torch
from splinetorch.b_spline import BSpline
x = torch.linspace(0, 1, 1000).view(-1, 1)
y = torch.exp(-((x - 0.5)**2) / 0.1).view(-1, 1) # Shape: (n, 1)
spline = BSpline(x=x, y=y)
spline.fit(x, y, pdf_constraint=True) # Non-negative and integrates to 1
density = spline.predict(x)
spline.plot_fit(x, y)
Three class classification
import torch
from splinetorch.b_spline import BSpline
import numpy as np
import matplotlib.pyplot as plt
x = torch.rand(300, 2) # 2D input
distances = torch.sum(x**2, dim=1)
y = torch.zeros(len(x), dtype=torch.long)
y[distances < 0.5] = 0
y[(distances >= 0.5) & (distances < 1.0)] = 1
y[distances >= 1.0] = 2
# Fit the spline
spline = BSpline(x=x, y=y, output_type="categorical")
spline.fit(x, y)
# Create a grid of points and get predictions
grid_size = 100
x1_min, x1_max = 0, 1
x2_min, x2_max = 0, 1
x1_grid, x2_grid = np.meshgrid(np.linspace(x1_min, x1_max, grid_size), np.linspace(x2_min, x2_max, grid_size))
grid_points = torch.tensor(np.column_stack([x1_grid.ravel(), x2_grid.ravel()]), dtype=torch.float32)
probs = spline.predict(grid_points)
predictions = probs.argmax(dim=1).numpy()
# Plot decision boundaries
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
axes = axes.ravel()
predictions_2d = predictions.reshape(grid_size, grid_size)
im = axes[0].contourf(x1_grid, x2_grid, predictions_2d, levels=np.arange(4)-0.5, cmap='viridis')
scatter = axes[0].scatter(x[:, 0], x[:, 1], c=y, cmap='viridis', edgecolor='black', s=50)
axes[0].set_title('Decision Boundaries')
axes[0].set_xlabel('X1')
axes[0].set_ylabel('X2')
plt.colorbar(im, ax=axes[0], label='Class')
# Plot probability landscapes for each class
for i in range(3):
class_probs = probs[:, i].numpy().reshape(grid_size, grid_size)
im = axes[i+1].contourf(x1_grid, x2_grid, class_probs, levels=20, cmap='RdYlBu')
# Plot points belonging to this class
class_mask = (y == i)
axes[i+1].scatter(x[class_mask, 0], x[class_mask, 1], color='black',
edgecolor='white', s=50, label=f'Class {i} points')
axes[i+1].set_title(f'Class {i} Probability')
axes[i+1].set_xlabel('X1')
axes[i+1].set_ylabel('X2')
plt.colorbar(im, ax=axes[i+1], label='Probability')
axes[i+1].legend()
plt.tight_layout()
plt.show()
# Print classification accuracy
class_probs = spline.predict(x)
_, predictions = class_probs.max(dim=1)
accuracy = (predictions == y).float().mean()
print(f"Classification accuracy: {accuracy:.3f}")
License
MIT
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
splinetorch-0.0.1.tar.gz
(17.1 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file splinetorch-0.0.1.tar.gz.
File metadata
- Download URL: splinetorch-0.0.1.tar.gz
- Upload date:
- Size: 17.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.5
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
79b82033ed9695847030663981c588ea5330200f98fc0a39bbca3d3836927e58
|
|
| MD5 |
8aa54813bf849e6f45ee881b796fbb27
|
|
| BLAKE2b-256 |
bf49e0541764453780130c29d0ca2d4dc01266e2908ef5faa55509367ab98af2
|
File details
Details for the file splinetorch-0.0.1-py3-none-any.whl.
File metadata
- Download URL: splinetorch-0.0.1-py3-none-any.whl
- Upload date:
- Size: 11.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.5
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ba2025acd42618f465f43c830ec32c6b1e8a4cc028fae30c1b21d6fd9227f3d3
|
|
| MD5 |
4e1b33129c95a975a0e2176f122e27d8
|
|
| BLAKE2b-256 |
ba576122bbe9e650188e11c10d78cc08adff7e32ed05f0033da1f6ef2d64f079
|