Quartic B-Spline extension: CPU and CUDA extension for quartic B-spline-based potential functions.
Project description
B-Spline extension
This package implements quartic (midpoint cardinal) B-spline potential functions. By such a potential function we mean a parameter-dependent function $\rho(\cdot, \gamma)$ with
$$ \rho(x, \gamma) = \sum_{\nu = 1}^{N}\gamma_{\nu}M_{4}\left(\frac{x - t_{\nu}}{s}\right), $$
where $\lbrace t_\nu\rbrace_{\nu=1}^N$ is an equidistant partition of the interval $I=[a, b]$, $\gamma = (\gamma_1, \ldots, \gamma_N)$ is the vector of weights, $s>0$ is a scaling parameter, and $M_4$ denotes the central quartic (midpoint cardinal) B-spline (see [1]).
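To make the definition concrete, here is a minimal pure-Python sketch of $\rho$. It is not the package's kernel code. It assumes that $M_4$ denotes the central B-spline of degree 4 with knots at $\pm 1/2, \pm 3/2, \pm 5/2$, and it evaluates that spline with the standard truncated-power formula. The names `quartic_bspline` and `potential` are illustrative and not part of the package.

```python
from math import comb, factorial


def quartic_bspline(x: float) -> float:
    """Central B-spline of degree 4 (assumed meaning of M_4); support [-5/2, 5/2]."""
    if abs(x) >= 2.5:
        return 0.0
    total = 0.0
    for k in range(6):
        u = x + 2.5 - k  # argument of the k-th truncated power
        if u > 0.0:
            total += (-1) ** k * comb(5, k) * u ** 4
    return total / factorial(4)


def potential(x: float, weights: list[float], centers: list[float], scale: float) -> float:
    """rho(x, gamma) = sum_nu gamma_nu * M_4((x - t_nu) / s)."""
    return sum(g * quartic_bspline((x - t) / scale) for g, t in zip(weights, centers))


# Equidistant centers on [-3, 3]; the scale is chosen equal to the spacing.
num_centers = 7
centers = [-3.0 + 6.0 * i / (num_centers - 1) for i in range(num_centers)]
scale = 6.0 / (num_centers - 1)
ones = [1.0] * num_centers

peak = quartic_bspline(0.0)                       # equals 115/192
interior = potential(0.25, ones, centers, scale)  # unit weights sum to 1 away from the boundary
```

With unit weights the shifted splines form a partition of unity in the interior of the interval, which gives a quick sanity check for any implementation of $M_4$.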
Features
- CUDA and CPU kernels for forward and backward step
- Custom autograd functions based on these kernels, which in particular compute the gradients of the potential w.r.t. the input and the parameters.
- Feature-wise implementation:
  - The potential function is designed for input tensors of shape [bs, f, w, h].
  - For a weight tensor of shape [f, N], the potential with weights [j, :] is applied to the $j$-th channel of the input tensor.
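The feature-wise rule can be sketched in plain Python on nested lists. This is an illustration of the indexing only, not the extension's implementation: the helper `apply_featurewise` is hypothetical, and $M_4$ is assumed to be the degree-4 central B-spline, evaluated here by truncated powers.

```python
from math import comb


def quartic_bspline(x: float) -> float:
    """Central degree-4 B-spline via truncated powers (assumed meaning of M_4)."""
    if abs(x) >= 2.5:
        return 0.0
    return sum((-1) ** k * comb(5, k) * max(x + 2.5 - k, 0.0) ** 4 for k in range(6)) / 24.0


def apply_featurewise(inputs, weights, centers, scale):
    """Apply the potential with weights[j] to channel j of a nested list shaped [bs, f, w, h]."""
    return [
        [
            [
                [
                    sum(g * quartic_bspline((x - t) / scale) for g, t in zip(weights[j], centers))
                    for x in row
                ]
                for row in channel
            ]
            for j, channel in enumerate(sample)
        ]
        for sample in inputs
    ]


centers = [-2.0, -1.0, 0.0, 1.0, 2.0]    # N = 5 equidistant centers
weights = [[1.0] * 5, [2.0] * 5]         # shape [f, N] with f = 2
inputs = [[[[0.0, 0.5]], [[0.0, 0.5]]]]  # shape [bs=1, f=2, w=1, h=2]
outputs = apply_featurewise(inputs, weights, centers, scale=1.0)
```

Both channels receive the same input values here, so any difference in `outputs` comes from the per-channel weight rows alone.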
Build
Both build processes require a (virtual) Python environment.
CUDA extension
Within the Python environment install the packages torch and setuptools. Then,
to build/compile the CUDA extension, from the top directory of this repository execute
python setup.py install
Alternatively call make install.
Python wheel
Activate the Python environment and install the packages setuptools, build and
wheel. Then from the top directory of this repository run
python -m build --wheel --outdir artefacts
Alternatively call make build.
The generated Python wheel is stored in the subdirectory artefacts.
Installation
To install the package as a Python wheel simply run
pip install bspline_cuda_extension-0.2.0-cp311-cp311-linux_x86_64.whl
Usage
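import torch
from quartic_bspline_extension.functions import QuarticBSplineFunction
# Assumed setup (not defined in the original snippet): choose a device and dtype.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
# Equidistant spline centers on [box_lower, box_upper].
box_lower = -3.0
box_upper = 3.0
num_centers = 77
centers = torch.linspace(box_lower, box_upper, num_centers)
# One weight vector per feature channel: shape [f, N] = [2, 77].
weights_1 = torch.log(1 + centers ** 2)
weights_2 = torch.abs(centers)
weights = torch.stack([weights_1, weights_2], dim=0)
# The scale equals the spacing between neighbouring centers.
scale = (box_upper - box_lower) / (num_centers - 1)
# Input tensor of shape [bs, f, w, h] = [1, 2, 1, 111].
f = 2
t = torch.stack([torch.linspace(box_lower, box_upper, 111)
                 for _ in range(f)]).unsqueeze(dim=1).unsqueeze(dim=0)
centers = centers.to(device=device, dtype=dtype)
weights = weights.to(device=device, dtype=dtype)
t = t.to(device=device, dtype=dtype)
t.requires_grad_(True)
# Forward pass; the second return value is not used here.
y, _ = QuarticBSplineFunction.apply(t, weights, centers, scale)
# Gradient of the summed output w.r.t. the input.
dy_dt = torch.autograd.grad(inputs=t, outputs=torch.sum(y))[0]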
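For orientation, the gradients requested from autograd follow from the definition of $\rho$ by linearity and the chain rule. The following is a worked sketch, assuming that $M_4$ is the degree-4 central B-spline (and hence differentiable) and that y holds the element-wise potential values:
$$ \frac{\partial \rho}{\partial \gamma_{\nu}}(x, \gamma) = M_{4}\left(\frac{x - t_{\nu}}{s}\right), \qquad \frac{\partial \rho}{\partial x}(x, \gamma) = \frac{1}{s}\sum_{\nu = 1}^{N}\gamma_{\nu}M_{4}'\left(\frac{x - t_{\nu}}{s}\right). $$
Because the output is summed before differentiation, each entry of dy_dt should then correspond to the second expression evaluated at the matching entry of t, using the weight row of that entry's channel.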
Contributing
- Fork the repository
- Create a feature branch
- Submit a pull request
References
[1] Schoenberg, Isaac J, 1973. Cardinal spline interpolation. SIAM.
License
MIT License
Download files
File details
Details for the file quartic_bspline_extension-0.3.0.tar.gz.
File metadata
- Download URL: quartic_bspline_extension-0.3.0.tar.gz
- Upload date:
- Size: 10.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | be5c7546d69f58c704429a2d9a2059aeef764a43780b1b442d45ea4b50cb0b6f |
| MD5 | 1c5d8a362ca99bcd8dda1cd02e7932f2 |
| BLAKE2b-256 | a8fbf0e6fa2d0c304c5bfa5e2eeed28cff9fae5d0579d77b3aaa184cca3c0a11 |
File details
Details for the file quartic_bspline_extension-0.3.0-py3-none-any.whl.
File metadata
- Download URL: quartic_bspline_extension-0.3.0-py3-none-any.whl
- Upload date:
- Size: 7.9 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 18e6fd3a8c35ea7d867ab798c97edc10485d14b06e75135073d3abebd246a45b |
| MD5 | 48da276b22d7dec4847eca2c1543b158 |
| BLAKE2b-256 | 094cf7e9ca0b6b15aa9158a27923d200218d0140e50cb62e8dc7d509709bce2d |