
Quartic B-Spline extension: CPU and CUDA extension for quartic B-spline-based potential functions.


B-Spline extension

This package implements quartic (midpoint cardinal) B-spline potential functions. By such a potential function we mean a parameter-dependent function $\rho(\cdot, \gamma)$ of the form

$$ \rho(x, \gamma) = \sum_{\nu = 1}^{N}\gamma_{\nu}M_{4}\left(\frac{x - t_{\nu}}{s}\right), $$

where $\lbrace t_\nu\rbrace_{\nu=1}^N$ is an equidistant partition of the interval $I=[a, b]$, $s>0$ is a scaling parameter, and $M_4$ denotes the central quartic (midpoint) cardinal B-spline (see [1]).
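As a reference for the formula, the following pure-Python sketch evaluates $\rho$ at a single point. It rests on two assumptions that the text does not pin down: $M_4$ is taken to be the centered cardinal B-spline of degree 4 (support $[-5/2, 5/2]$), written in truncated-power form (some authors index B-splines by order instead of degree), and the centers are $t_\nu = a + (\nu - 1)h$ with spacing $h = (b - a)/(N - 1)$. The names quartic_bspline and potential are hypothetical and not part of the package.

```python
from math import comb


def quartic_bspline(x: float) -> float:
    """Centered cardinal B-spline of degree 4, supported on [-2.5, 2.5].

    Assumption: this is the M_4 of the formula above. Truncated-power form:
    M(x) = 1/4! * sum_{k=0}^{5} (-1)^k C(5, k) (x + 5/2 - k)_+^4.
    """
    if abs(x) >= 2.5:
        return 0.0  # outside the support
    u = x + 2.5  # shift the support to [0, 5]
    total = 0.0
    for k in range(6):
        d = u - k
        if d > 0.0:  # truncated power (u - k)_+^4
            total += (-1) ** k * comb(5, k) * d ** 4
    return total / 24.0  # divide by 4! = 24


def potential(x: float, weights: list[float], centers: list[float], scale: float) -> float:
    """rho(x, gamma) = sum_nu gamma_nu * M_4((x - t_nu) / s)."""
    return sum(g * quartic_bspline((x - t) / scale) for g, t in zip(weights, centers))


# Hypothetical setup: equidistant centers t_nu = a + (nu - 1) * h on [a, b], s = h.
a, b, n = -3.0, 3.0, 61
h = (b - a) / (n - 1)
centers = [a + i * h for i in range(n)]
print(potential(0.37, [1.0] * n, centers, h))  # approximately 1.0
```

With all weights equal to one and $s = h$, the shifted B-splines form a partition of unity, so $\rho \approx 1$ wherever all neighbouring centers are present (away from the ends of $[a, b]$); this makes a convenient sanity check.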


Features

  • CUDA and CPU kernels for the forward and backward passes
  • Custom autograd functions built on these kernels, which in particular compute the gradients of the potential with respect to the input and the parameters.
  • Feature-wise implementation:
    • The potential function is designed for input tensors of shape [bs, f, w, h].
    • For a weight tensor of shape [f, N], the potential defined by the weights [$c$, :] is applied to the $c$-th channel of the input tensor.
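The channel-wise convention can be illustrated with a small NumPy reference. This is a sketch under assumptions, not the package's kernel: it treats $M_4$ as the centered degree-4 cardinal B-spline on $[-5/2, 5/2]$, uses hypothetical names (quartic_bspline, featurewise_potential), and evaluates all N basis functions for every input entry, which is simple but memory-hungry.

```python
from math import comb

import numpy as np


def quartic_bspline(x: np.ndarray) -> np.ndarray:
    """Centered degree-4 cardinal B-spline on [-2.5, 2.5] (assumed M_4), elementwise."""
    u = np.asarray(x, dtype=np.float64) + 2.5  # shift the support to [0, 5]
    out = np.zeros_like(u)
    for k in range(6):
        # truncated power (u - k)_+^4 with alternating binomial weights
        out += (-1) ** k * comb(5, k) * np.clip(u - k, 0.0, None) ** 4
    out /= 24.0  # divide by 4! = 24
    out[(u <= 0.0) | (u >= 5.0)] = 0.0  # exact zeros outside the support
    return out


def featurewise_potential(x, weights, centers, scale):
    """Apply the potential with weights[c, :] to channel c of x.

    x: [bs, f, w, h], weights: [f, N], centers: [N], scale: float -> [bs, f, w, h]
    """
    arg = (x[..., None] - centers) / scale  # [bs, f, w, h, N]
    basis = quartic_bspline(arg)  # M_4((x - t_nu) / s)
    # sum over nu, pairing weight row c with input channel c
    return np.einsum("bcijn,cn->bcij", basis, weights)


centers = np.linspace(-3.0, 3.0, 61)
scale = centers[1] - centers[0]
x = np.linspace(-2.0, 2.0, 2 * 2 * 3 * 4).reshape(2, 2, 3, 4)  # [bs=2, f=2, w=3, h=4]
weights = np.stack([np.ones(61), 2.0 * np.ones(61)])  # channel 0: gamma = 1, channel 1: gamma = 2
y = featurewise_potential(x, weights, centers, scale)
print(y.shape)  # (2, 2, 3, 4)
```

Because the weights are constant per channel and the inputs lie well inside $[-3, 3]$, the partition-of-unity property makes channel 0 evaluate to about 1 and channel 1 to about 2, which shows that each weight row acts only on its own channel.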

Build

Both build processes require a (virtual) Python environment.

CUDA extension

Within the Python environment, install the packages torch and setuptools. Then, to build (compile) and install the CUDA extension, run the following from the top-level directory of this repository:

    python setup.py install

Alternatively, run make install.

Python wheel

Activate the Python environment and install the packages setuptools, build and wheel. Then run the following from the top-level directory of this repository:

    python -m build --wheel --outdir artefacts

Alternatively, run make build.

The generated Python wheel is stored in the subdirectory artefacts.

Installation

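To install the package from a built wheel, pass the wheel file to pip. The file name depends on the package version, the Python version and the platform; for example: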

    pip install bspline_cuda_extension-0.2.0-cp311-cp311-linux_x86_64.whl

Usage

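    import torch

    from quartic_bspline_extension.functions import QuarticBSplineFunction

    # Choose device and dtype (example choices, assumed to be supported by the extension)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # Equidistant centers t_nu on [box_lower, box_upper]
    box_lower = -3.0
    box_upper = 3.0
    num_centers = 77
    centers = torch.linspace(box_lower, box_upper, num_centers)

    # One weight row per feature channel: weights has shape [f, N] = [2, 77]
    weights_1 = torch.log(1 + centers ** 2)
    weights_2 = torch.abs(centers)
    weights = torch.stack([weights_1, weights_2], dim=0)

    # Scale s equal to the spacing between neighbouring centers
    scale = (box_upper - box_lower) / (num_centers - 1)

    # Input of shape [bs, f, w, h] = [1, f, 1, 111]
    f = 2
    t = torch.stack([torch.linspace(box_lower, box_upper, 111)
                     for _ in range(f)]).unsqueeze(dim=1).unsqueeze(dim=0)

    centers = centers.to(device=device, dtype=dtype)
    weights = weights.to(device=device, dtype=dtype)
    t = t.to(device=device, dtype=dtype)
    t.requires_grad_(True)

    # The function returns a pair; only the first entry (the potential values) is used here
    y, _ = QuarticBSplineFunction.apply(t, weights, centers, scale)
    # Gradient of the summed potential with respect to the input (same shape as t)
    dy_dt = torch.autograd.grad(inputs=t, outputs=torch.sum(y))[0]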
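The quantities that autograd computes here follow from the chain rule applied to the definition of $\rho$. Since the potential acts entrywise (each output entry depends only on the corresponding input entry and on the weight row of its channel), the gradient of torch.sum(y) with respect to t is the pointwise derivative of $\rho$, and the derivative with respect to a weight is the corresponding basis function; both exist because the B-spline is continuously differentiable:

$$ \frac{\partial \rho}{\partial x}(x, \gamma) = \frac{1}{s}\sum_{\nu=1}^{N}\gamma_{\nu}M_{4}'\left(\frac{x - t_{\nu}}{s}\right), \qquad \frac{\partial \rho}{\partial \gamma_{\nu}}(x, \gamma) = M_{4}\left(\frac{x - t_{\nu}}{s}\right). $$

For the gradient of the summed output with respect to the weights of channel $c$, the second expression is summed over all input entries of that channel.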

Contributing

  1. Fork the repository
  2. Create a feature branch
  3. Submit a pull request

References

[1] I. J. Schoenberg. Cardinal Spline Interpolation. SIAM, 1973.

License

MIT License



Download files


Source Distribution

quartic_bspline_extension-0.3.0.tar.gz (10.1 kB)

Built Distribution


quartic_bspline_extension-0.3.0-py3-none-any.whl (7.9 MB)

File details

Hashes for quartic_bspline_extension-0.3.0.tar.gz:

  • SHA256: be5c7546d69f58c704429a2d9a2059aeef764a43780b1b442d45ea4b50cb0b6f
  • MD5: 1c5d8a362ca99bcd8dda1cd02e7932f2
  • BLAKE2b-256: a8fbf0e6fa2d0c304c5bfa5e2eeed28cff9fae5d0579d77b3aaa184cca3c0a11

Hashes for quartic_bspline_extension-0.3.0-py3-none-any.whl:

  • SHA256: 18e6fd3a8c35ea7d867ab798c97edc10485d14b06e75135073d3abebd246a45b
  • MD5: 48da276b22d7dec4847eca2c1543b158
  • BLAKE2b-256: 094cf7e9ca0b6b15aa9158a27923d200218d0140e50cb62e8dc7d509709bce2d
