SplineTorch is a Python package for fitting splines in PyTorch.


SplineTorch

SplineTorch is a Python package for fitting splines in PyTorch. Splines with enforced and/or penalized constraints are supported.

Installation

pip install splinetorch

Features

  • Univariate and multivariate B-spline regression
  • Binary and categorical classification
  • Probability density function (PDF) fitting
  • Constraint support:
    • Derivative constraints (e.g., monotonicity, convexity)
    • Point constraints
    • PDF constraints (non-negativity and integration to 1)
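The regression features rest on B-spline basis functions. As a hedged illustration in plain Python (not splinetorch's implementation), the sketch below evaluates them with the Cox-de Boor recursion on a clamped, uniform knot vector, and builds a 2D basis from products of 1D bases. Tensor products are one common way to construct multivariate splines; whether splinetorch uses them is an assumption. The helper names `clamped_knots`, `basis`, `basis_row` and `tensor_basis_row` are hypothetical.

```python
# Hypothetical, dependency-free sketch (not splinetorch code): B-spline basis
# functions via the Cox-de Boor recursion, plus a 2D tensor-product basis.

def clamped_knots(n_spans, degree, lo=0.0, hi=1.0):
    """Uniform knots on [lo, hi] with each end knot repeated degree + 1 times."""
    inner = [lo + (hi - lo) * k / n_spans for k in range(n_spans + 1)]
    return [lo] * degree + inner + [hi] * degree


def basis(i, p, t, x):
    """Value at x of the i-th B-spline of degree p on knot vector t."""
    if p == 0:
        inside = t[i] <= x < t[i + 1]
        # Close the last non-empty span so the basis is defined at the right end
        at_right_end = x == t[-1] and t[i] < t[i + 1] == t[-1]
        return 1.0 if inside or at_right_end else 0.0
    value = 0.0
    if t[i + p] > t[i]:
        value += (x - t[i]) / (t[i + p] - t[i]) * basis(i, p - 1, t, x)
    if t[i + p + 1] > t[i + 1]:
        value += (t[i + p + 1] - x) / (t[i + p + 1] - t[i + 1]) * basis(i + 1, p - 1, t, x)
    return value


def basis_row(p, t, x):
    """All basis values at x; there are len(t) - p - 1 basis functions."""
    return [basis(i, p, t, x) for i in range(len(t) - p - 1)]


def tensor_basis_row(p, t1, t2, x1, x2):
    """2D basis at (x1, x2): every product B_i(x1) * B_j(x2)."""
    r1 = basis_row(p, t1, x1)
    r2 = basis_row(p, t2, x2)
    return [a * b for a in r1 for b in r2]


t = clamped_knots(n_spans=4, degree=3)  # cubic splines on [0, 1]
row = basis_row(3, t, 0.3)
print(len(row), sum(row))               # 7 non-negative values that sum to 1 (up to round-off)
```

A fitted spline is then a weighted sum of these functions, so fitting reduces to choosing one coefficient per basis function (per product, in 2D).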

Examples

Simple Univariate Regression

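```python
import torch
from splinetorch.b_spline import BSpline

# 100 noisy samples of one sine period; inputs are shaped (n, 1)
x = torch.linspace(0, 1, 100).view(-1, 1)
y = torch.sin(2 * torch.pi * x) + torch.randn_like(x) * 0.1

spline = BSpline(x=x, y=y)  # x and y go to the constructor as well as to fit()
spline.fit(x, y)
y_pred = spline.predict(x)
spline.plot_fit(x, y)       # plot the fitted spline against the data
```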
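In the simplest setting, fitting a spline means choosing the basis coefficients that minimize the squared error. The sketch below assumes plain, unconstrained least squares solved through the normal equations; splinetorch works in PyTorch and may add penalties, so this is an illustration of the idea rather than its routine. All function names are hypothetical.

```python
# Hypothetical, dependency-free sketch (not splinetorch code): least-squares
# B-spline regression, minimizing sum((B c - y)^2) over the coefficients c.

def basis(i, p, t, x):
    """Cox-de Boor recursion for the i-th degree-p B-spline on knots t."""
    if p == 0:
        inside = t[i] <= x < t[i + 1]
        at_right_end = x == t[-1] and t[i] < t[i + 1] == t[-1]
        return 1.0 if inside or at_right_end else 0.0
    value = 0.0
    if t[i + p] > t[i]:
        value += (x - t[i]) / (t[i + p] - t[i]) * basis(i, p - 1, t, x)
    if t[i + p + 1] > t[i + 1]:
        value += (t[i + p + 1] - x) / (t[i + p + 1] - t[i + 1]) * basis(i + 1, p - 1, t, x)
    return value


def solve(A, b):
    """Gaussian elimination with partial pivoting for a small square system."""
    n = len(A)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for k in range(col, n + 1):
                M[r][k] -= f * M[col][k]
    c = [0.0] * n
    for r in range(n - 1, -1, -1):
        c[r] = (M[r][n] - sum(M[r][k] * c[k] for k in range(r + 1, n))) / M[r][r]
    return c


def fit_spline(xs, ys, p, t):
    """Solve the normal equations (B^T B) c = B^T y for the coefficients c."""
    n = len(t) - p - 1
    B = [[basis(i, p, t, x) for i in range(n)] for x in xs]  # design matrix
    BtB = [[sum(row[i] * row[j] for row in B) for j in range(n)] for i in range(n)]
    Bty = [sum(row[i] * y for row, y in zip(B, ys)) for i in range(n)]
    return solve(BtB, Bty)


def predict(c, p, t, x):
    return sum(ci * basis(i, p, t, x) for i, ci in enumerate(c))


t = [0.0] * 4 + [0.25, 0.5, 0.75] + [1.0] * 4  # clamped cubic knots on [0, 1]
xs = [k / 49 for k in range(50)]
ys = [x * x for x in xs]                        # cubic splines contain x**2 exactly
c = fit_spline(xs, ys, 3, t)
print(max(abs(predict(c, 3, t, x) - y) for x, y in zip(xs, ys)))  # round-off sized
```

Because the target here lies inside the spline space, the residual is only floating-point noise; with the noisy sine data above, the same procedure returns the closest spline in the least-squares sense.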

Binary Classification with Monotonically Increasing Logit

```python
import torch
from splinetorch.b_spline import BSpline

x = torch.linspace(-3, 3, 100).view(-1, 1)
# Binary labels that switch back and forth along x
y = (torch.sin(2.2 * torch.pi * x) > 0).float()

# Constrain the 1st derivative of the logit to be > 0 (monotonically increasing logit)
derivative_constraints = {1: {'>': torch.tensor(0.0)}}
spline = BSpline(x=x, y=y, output_type="binary")
spline.fit(x, y, derivative_constraints=derivative_constraints)
probabilities = spline.predict(x)  # an increasing logit gives increasing probabilities
spline.plot_fit(x, y)
```
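Derivative constraints can be enforced or penalized. The sketch below illustrates only the penalized flavour, using a hypothetical squared-hinge penalty on a central-difference estimate of the first derivative (it treats the strict `>` as `>=`); splinetorch's actual formulation may differ. It also shows why constraining the logit suffices: the logistic function is strictly increasing, so a monotone logit yields monotone probabilities.

```python
import math

# Hypothetical sketch of a penalized derivative constraint (not splinetorch code).

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def monotonicity_penalty(f, grid, h=1e-4):
    """Sum of squared violations of f'(x) >= 0, with f' from central differences."""
    total = 0.0
    for x in grid:
        slope = (f(x + h) - f(x - h)) / (2 * h)
        total += max(0.0, -slope) ** 2
    return total


def increasing_logit(x):
    return 2.0 * x - 1.0


def wiggly_logit(x):
    return math.sin(2.2 * math.pi * x)


grid = [-3 + 6 * k / 99 for k in range(100)]          # 100 points on [-3, 3]
print(monotonicity_penalty(increasing_logit, grid))   # 0.0: no violations
print(monotonicity_penalty(wiggly_logit, grid) > 0)   # True: decreasing stretches are penalized
probs = [sigmoid(increasing_logit(x)) for x in grid]
print(all(a < b for a, b in zip(probs, probs[1:])))   # True: monotone logit, monotone probability
```

In a penalized fit, a term like this is added to the classification loss with some weight, so small violations can survive; an enforced constraint rules them out.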

Multivariate Regression with Point Constraints

```python
import torch
from splinetorch.b_spline import BSpline

x = torch.rand(1000, 2)  # 2D inputs in [0, 1)^2
y = torch.sum(2 * x**2, dim=1).view(-1, 1) - 2  # target: 2*x1^2 + 2*x2^2 - 2
# Point constraint on the derivative of order 0 (the function value itself):
# force the prediction at (x1, x2) = (0, 0) to equal 0. The target there is -2,
# so the constraint pulls the fit away from the data near the origin.
point_constraints = {0: {'=': (torch.tensor([[0.0, 0.0]]), torch.tensor([[0.0]]))}}
spline = BSpline(x=x, y=y)
spline.fit(x, y, point_constraints=point_constraints)
y_pred = spline.predict(x)

# Verify the constraint at (0, 0)
test_point = torch.tensor([[0.0, 0.0]])
predicted_value = spline.predict(test_point)
print(f"Value at (0,0): {predicted_value.item():.6f}")
```
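For a spline written as s(x) = sum_i c_i * b_i(x), a point constraint s(x0) = v is linear in the coefficients: it is the hyperplane b(x0) . c = v. One standard way to impose it exactly is to project the coefficients onto that hyperplane, as sketched below; splinetorch may instead use a penalty or a constrained solver, and the helper names and numbers are hypothetical. With clamped knots, only one tensor-product basis function is non-zero at the corner (0, 0), so the constraint there simply pins that function's coefficient.

```python
# Hypothetical sketch (not splinetorch code): enforce s(x0) = v exactly by
# projecting the coefficient vector onto the hyperplane b(x0) . c = v.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def project_onto_constraint(c, b0, v):
    """Closest coefficients to c (Euclidean norm) that satisfy dot(b0, c) == v."""
    residual = v - dot(b0, c)
    norm_sq = dot(b0, b0)
    return [ci + residual * bi / norm_sq for ci, bi in zip(c, b0)]


c = [1.0, -2.0, 0.5]   # hypothetical coefficients from an unconstrained fit
b0 = [0.1, 0.6, 0.3]   # hypothetical basis values at the constrained point x0
c_new = project_onto_constraint(c, b0, v=0.0)
print(abs(dot(b0, c_new)) < 1e-12)  # True: the constraint now holds

# Only one basis function active at x0 (as at a clamped corner): only its coefficient moves
corner = project_onto_constraint(c, [1.0, 0.0, 0.0], v=0.0)
print(corner)  # [0.0, -2.0, 0.5]
```

This is the smallest change in coefficient space, not the constrained least-squares optimum; that would weight the correction by the fit's B^T B matrix instead of the identity.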

PDF Fitting with Constraints (encouraging integration to 1 over the interval)

```python
import torch
from splinetorch.b_spline import BSpline

x = torch.linspace(0, 1, 1000).view(-1, 1)
# Unnormalized Gaussian bump centred at 0.5; x is already (n, 1), so y is too
y = torch.exp(-((x - 0.5)**2) / 0.1)
spline = BSpline(x=x, y=y)
spline.fit(x, y, pdf_constraint=True)  # non-negative and integrating to 1
density = spline.predict(x)
spline.plot_fit(x, y)
```
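Whatever the fitting method, both PDF properties can be checked numerically on a grid. The plain-Python sketch below (independent of splinetorch; `trapezoid` is an illustrative name) integrates the example's target bump with the trapezoid rule: its area on [0, 1] is a little under 0.55, so the raw target is not itself a density, and dividing by the area gives a curve that integrates to 1.

```python
import math


def trapezoid(ys, xs):
    """Trapezoid-rule integral of samples ys taken at increasing points xs."""
    return sum((x1 - x0) * (y0 + y1) / 2 for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))


xs = [k / 999 for k in range(1000)]                     # 1000 points on [0, 1]
bump = [math.exp(-((x - 0.5) ** 2) / 0.1) for x in xs]   # the example's target curve
area = trapezoid(bump, xs)
print(area)  # a little under 0.55: not a density yet

density = [y / area for y in bump]                      # rescale so the area is 1
print(min(density) >= 0.0, abs(trapezoid(density, xs) - 1.0) < 1e-9)  # True True
```

With the example's tensors, `torch.trapezoid(density.squeeze(), x.squeeze())` performs the same check on the fitted curve, assuming `predict` returns an (n, 1) tensor.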

Three-Class Classification

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from splinetorch.b_spline import BSpline

x = torch.rand(300, 2)  # 2D inputs in [0, 1)^2
distances = torch.sum(x**2, dim=1)  # squared distance from the origin
y = torch.zeros(len(x), dtype=torch.long)
y[distances < 0.5] = 0
y[(distances >= 0.5) & (distances < 1.0)] = 1
y[distances >= 1.0] = 2

# Fit the spline
spline = BSpline(x=x, y=y, output_type="categorical")
spline.fit(x, y)

# Evaluate class probabilities on a regular grid over the unit square
grid_size = 100
x1_min, x1_max = 0, 1
x2_min, x2_max = 0, 1
x1_grid, x2_grid = np.meshgrid(np.linspace(x1_min, x1_max, grid_size), np.linspace(x2_min, x2_max, grid_size))
grid_points = torch.tensor(np.column_stack([x1_grid.ravel(), x2_grid.ravel()]), dtype=torch.float32)
probs = spline.predict(grid_points).detach()  # one column per class; detach in case gradients are tracked
grid_predictions = probs.argmax(dim=1).numpy()
x_np, y_np = x.numpy(), y.numpy()

# Plot decision boundaries
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
axes = axes.ravel()
predictions_2d = grid_predictions.reshape(grid_size, grid_size)
im = axes[0].contourf(x1_grid, x2_grid, predictions_2d, levels=np.arange(4) - 0.5, cmap='viridis')
axes[0].scatter(x_np[:, 0], x_np[:, 1], c=y_np, cmap='viridis', edgecolor='black', s=50)
axes[0].set_title('Decision Boundaries')
axes[0].set_xlabel('X1')
axes[0].set_ylabel('X2')
plt.colorbar(im, ax=axes[0], label='Class')

# Plot the probability landscape for each class
for i in range(3):
    class_probs = probs[:, i].numpy().reshape(grid_size, grid_size)
    im = axes[i + 1].contourf(x1_grid, x2_grid, class_probs, levels=20, cmap='RdYlBu')
    # Overlay the training points belonging to this class
    class_mask = y_np == i
    axes[i + 1].scatter(x_np[class_mask, 0], x_np[class_mask, 1], color='black',
                        edgecolor='white', s=50, label=f'Class {i} points')
    axes[i + 1].set_title(f'Class {i} Probability')
    axes[i + 1].set_xlabel('X1')
    axes[i + 1].set_ylabel('X2')
    plt.colorbar(im, ax=axes[i + 1], label='Probability')
    axes[i + 1].legend()

plt.tight_layout()
plt.show()

# Training accuracy (measured on the same points used for fitting)
train_predictions = spline.predict(x).argmax(dim=1)
accuracy = (train_predictions == y).float().mean().item()
print(f"Training accuracy: {accuracy:.3f}")
```
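The categorical model returns one probability column per class. Assuming these come from a softmax over per-class spline outputs (the usual choice, not confirmed for splinetorch), the sketch below shows a numerically stable softmax, the labelling rule used to build `y`, and the argmax-based accuracy computed at the end of the example. All helper names and inputs are hypothetical.

```python
import math

# Hypothetical helpers illustrating the categorical example (not splinetorch code).

def softmax(logits):
    """Per-class scores to probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label(squared_distance):
    """Labelling rule from the example: 0 below 0.5, 1 below 1.0, otherwise 2."""
    if squared_distance < 0.5:
        return 0
    if squared_distance < 1.0:
        return 1
    return 2


def accuracy(prob_rows, labels):
    """Fraction of rows whose most probable class (argmax) equals the label."""
    preds = [max(range(len(p)), key=p.__getitem__) for p in prob_rows]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


rows = [softmax([2.0, 0.5, -1.0]), softmax([0.0, 3.0, 0.1]), softmax([0.2, 0.1, 0.3])]
print([label(d) for d in (0.2, 0.7, 1.5)])  # [0, 1, 2]
print(accuracy(rows, [0, 1, 0]))            # 2 of 3 rows correct
print(softmax([1000.0, 0.0, 0.0])[0])       # 1.0, no overflow
```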

License

MIT



Download files


Source Distribution

splinetorch-0.0.2.tar.gz (17.8 kB)

Uploaded Source

Built Distribution


splinetorch-0.0.2-py3-none-any.whl (11.9 kB)

Uploaded Python 3

File details

Details for the file splinetorch-0.0.2.tar.gz.

File metadata

  • Download URL: splinetorch-0.0.2.tar.gz
  • Upload date:
  • Size: 17.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.5

File hashes

Hashes for splinetorch-0.0.2.tar.gz
  • SHA256: f7fdb033732d475a8f09c39440e6bc9f67e62b50d1af3de9af1e4ff1bc9b2bf2
  • MD5: 2cfaf8c7dc7aa39cfe15bf1b0dbc23d8
  • BLAKE2b-256: eda1e60246d649d476ff46f56cb2060d8f33190786ba36989d9c9f2ebd4bfb0c


File details

Details for the file splinetorch-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: splinetorch-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 11.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.5

File hashes

Hashes for splinetorch-0.0.2-py3-none-any.whl
  • SHA256: 501ece9530a99f06d70cf1d90e420d03115b0bbab5e43d05a931b1c5173b0ead
  • MD5: 71b2a8e657f788c09f67b439421cc9d4
  • BLAKE2b-256: 2285b65d22e1fb606f7864ff5a1336191ddfc3b48293db07614c385ff792f169
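To check a download against the SHA256 digests listed above, the standard-library `hashlib` module is enough. This is a generic sketch, not part of splinetorch; `sha256_of_file` and `verify` are illustrative names, and the demo hashes a temporary file rather than the real wheel.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks to keep memory use small."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of_file(path) == expected_hex.lower()


# Digest published for the wheel; verify() should return True for an intact download
WHEEL_SHA256 = "501ece9530a99f06d70cf1d90e420d03115b0bbab5e43d05a931b1c5173b0ead"
# verify("splinetorch-0.0.2-py3-none-any.whl", WHEEL_SHA256)

# Demo on a throwaway file (not the real wheel)
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"example bytes")
demo_ok = verify(tmp.name, hashlib.sha256(b"example bytes").hexdigest())
demo_bad = verify(tmp.name, "0" * 64)
os.remove(tmp.name)
print(demo_ok, demo_bad)  # True False
```

pip can also enforce digests itself in hash-checking mode, for example with a requirements line `splinetorch==0.0.2 --hash=sha256:<digest>` installed via `pip install --require-hashes -r requirements.txt`.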

