Fast, Conditioned KAN
KANditioned: Fast, Generalizable Training of KANs via Lookup Interpolation
Training is accelerated by orders of magnitude by exploiting the structure of a linear (C⁰) spline with uniformly spaced control points: spline(x) can be computed as a linear interpolation between the two nearest control points. This is in contrast to the summation over basis functions typically used for B-splines; it reduces the computation required and enables effectively sublinear scaling along the control-point dimension.
Install
pip install kanditioned
Usage
It is highly recommended to use this layer with torch.compile, which provides very significant speedups, and to place a normalization layer before each KANLayer.
from kanditioned.kan_layer import KANLayer
layer = KANLayer(in_features=3, out_features=3, init="random_normal", num_control_points=8)
layer.visualize_all_mappings(save_path="kan_mappings.png")
Args:
in_features (int) – size of each input sample
out_features (int) – size of each output sample
init (str) – initialization method:
"random_normal": The slope of each spline is drawn from a normal distribution and normalized so that each "neuron" has unit "weight" norm.
"identity": Identity mapping (requires in_features == out_features). At initialization, the layer's output equals its input.
"zero": All splines are initialized to zero.
num_control_points (int) – number of uniformly spaced control points per input feature. Defaults to 32.
spline_width (float) – width of the spline's domain [-spline_width / 2, spline_width / 2]. Defaults to 4.0.
Methods:
visualize_all_mappings(save_path=path) – plots the shape of each spline together with its corresponding input and output feature. The save_path argument is optional.
How This Works
This implementation of KAN uses a linear (C⁰) spline, with uniformly spaced control points (see Figure 1 and Equation 1).
Figure 1. Linear B-spline example.
Equation 1. B-spline formula.
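As a rough illustration of the lookup idea described in this section, the sketch below evaluates a uniform linear spline in two ways: by interpolating between the two nearest control points, and by summing over every hat (linear B-spline) basis function. It is plain Python and is not the package's implementation. The function names are hypothetical, and two details are assumptions of the sketch: the first and last control points sit on the edges of the domain, and out-of-range inputs are clamped to it.

```python
import math


def spline_lookup(x, control_points, spline_width=4.0):
    """Evaluate a uniform linear spline from its two nearest control points."""
    n = len(control_points)  # requires n >= 2
    half = spline_width / 2.0
    # Map x from [-half, half] to a fractional index in [0, n - 1].
    # Clamping out-of-range inputs is an assumption of this sketch.
    t = (min(max(x, -half), half) + half) / spline_width * (n - 1)
    i = min(int(math.floor(t)), n - 2)  # index of the left control point
    frac = t - i                        # position between the two neighbors
    return (1.0 - frac) * control_points[i] + frac * control_points[i + 1]


def spline_hat_sum(x, control_points, spline_width=4.0):
    """Same spline, evaluated as a sum over all n hat basis functions."""
    n = len(control_points)
    half = spline_width / 2.0
    t = (min(max(x, -half), half) + half) / spline_width * (n - 1)
    # Each hat function is 1 at its own control point and 0 at its neighbors.
    return sum(c * max(0.0, 1.0 - abs(t - k)) for k, c in enumerate(control_points))


control_points = [0.0, 1.0, 4.0, 9.0, 16.0]  # values at x = -2, -1, 0, 1, 2
print(spline_lookup(0.5, control_points))   # 6.5
print(spline_hat_sum(0.5, control_points))  # 6.5
```

Both functions return the same value, but the lookup version reads only two control points no matter how many there are, while the summation visits all of them.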
Roadmap
- Update the package with cleaned-up, efficient Discrete Cosine Transform and parallel scan (prefix sum) reparameterizations. Both give an isotropic, κ ~ O(1) conditioned discrete second-difference penalty, as opposed to the κ ~ O(N^4) conditioning of the naive B-spline parameterization. This only matters if you use regularization.
- Benchmark forward and backward passes against MLP and various other KAN implementations
- Add a feature-major variant
- Add optimized Triton kernel
- Clean up writing
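To make the regularization item in the roadmap more concrete, here is a minimal sketch of a discrete second-difference penalty and of one possible prefix-sum reparameterization. It is not the planned implementation; the function names are hypothetical, and the choice of two successive cumulative sums is an assumption. With that choice, each second difference of the control points equals one of the underlying parameters, so the penalty reduces to a plain sum of squares.

```python
from itertools import accumulate


def second_difference_penalty(control_points):
    """Sum of squared discrete second differences of the control points."""
    c = control_points
    return sum((c[i - 1] - 2.0 * c[i] + c[i + 1]) ** 2 for i in range(1, len(c) - 1))


def control_points_from_params(params):
    """Hypothetical reparameterization: two successive prefix sums."""
    slopes = list(accumulate(params))  # first prefix sum
    return list(accumulate(slopes))    # second prefix sum


params = [1.0, 2.0, 3.0, -1.0, 0.5]
points = control_points_from_params(params)

# The second differences of `points` are params[2:], so both lines agree.
print(second_difference_penalty(points))  # 10.25
print(sum(p * p for p in params[2:]))     # 10.25
```

In this parameterization the first two parameters set the offset and the initial slope, and they are not penalized.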
LICENSE
This project is licensed under the Apache License 2.0.
Download files
File details
Details for the file kanditioned-1.0.1.tar.gz.
File metadata
- Download URL: kanditioned-1.0.1.tar.gz
- Upload date:
- Size: 8.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a9dfdfdb9d04ba79141dcf6e58b0e77f9a1e04efcce6758afd3aa191db45c632 |
| MD5 | bdf2dad02498d73b230240dfa093fe02 |
| BLAKE2b-256 | 415195ee942a2e78620bff473a5c27d7fa5d3c1e13399a1fb7fbec0f4b036b94 |
File details
Details for the file kanditioned-1.0.1-py3-none-any.whl.
File metadata
- Download URL: kanditioned-1.0.1-py3-none-any.whl
- Upload date:
- Size: 9.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bf331852855862d443cd42005f49bbc1153cc58da39236f51894a7351fbdfba9 |
| MD5 | 09b9ee88cc4e6587444cc02ab342b3d5 |
| BLAKE2b-256 | a1cba2d65f9c087405610bdc476fc878448bc5431759509c1251206723f097ea |