
Fast, Conditioned KAN

Project description

KANditioned: Fast, Generalizable Training of KANs via Lookup Interpolation

Training is accelerated by orders of magnitude by exploiting the structure of linear (C⁰) splines with uniformly spaced control points: spline(x) can be computed as a linear interpolation between the two nearest control points. This contrasts with the summation over basis functions typically used to evaluate B-splines. It reduces the amount of computation required and enables effectively sublinear scaling along the control-point dimension.
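The difference can be seen in a small pure-Python sketch. Both functions below evaluate the same linear spline on a uniform grid: one sums every hat-shaped degree-1 B-spline basis function (O(N) work per input, N = number of control points), the other computes the interval index directly and interpolates between two control points (O(1) work per input). The function names are hypothetical, and the assumption that the control points include both domain endpoints, with spacing spline_width / (num_control_points - 1), is illustrative rather than taken from the package.

```python
import math


def hat_basis_sum(x, control_points, spline_width):
    """Naive evaluation: sum over every degree-1 B-spline (hat) basis function.

    Cost grows linearly with the number of control points.
    """
    n = len(control_points)
    lo = -spline_width / 2
    h = spline_width / (n - 1)  # assumed uniform spacing, endpoints included
    total = 0.0
    for k, c in enumerate(control_points):
        knot = lo + k * h
        # Hat function centered at knot k, supported on [knot - h, knot + h].
        total += c * max(0.0, 1.0 - abs(x - knot) / h)
    return total


def lookup_interp(x, control_points, spline_width):
    """Lookup evaluation: find the interval by index arithmetic, then
    linearly interpolate between its two endpoint control points.

    Cost does not depend on the number of control points.
    """
    n = len(control_points)
    lo = -spline_width / 2
    h = spline_width / (n - 1)
    t = (x - lo) / h                             # position in units of the spacing
    i = min(max(math.floor(t), 0), n - 2)        # left control point, clamped to a valid interval
    frac = t - i                                 # fractional offset inside [i, i + 1]
    return (1.0 - frac) * control_points[i] + frac * control_points[i + 1]
```

The two functions agree for inputs inside [-spline_width / 2, spline_width / 2]. Outside that range they differ (the clamped index extrapolates linearly, while the hat sum decays to zero), and how the package itself treats out-of-domain inputs is not specified here. In a real layer the index and fraction computation would be vectorized over the batch and over every (input, output) spline pair; the sketch keeps to scalars for clarity.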

Install

pip install kanditioned

Usage

It is highly recommended to use this layer with torch.compile, which provides significant speedups, and to place a normalization layer before each KANLayer.

from kanditioned.kan_layer import KANLayer

layer = KANLayer(in_features=3, out_features=3, init="random_normal", num_control_points=8)

layer.visualize_all_mappings(save_path="kan_mappings.png")

Args:

in_features (int) – size of each input sample
out_features (int) – size of each output sample
init (str) – initialization method:
    "random_normal": The slope of each spline is drawn from a normal distribution and normalized so that each "neuron" has unit "weight" norm.
    "identity": Identity mapping (requires in_features == out_features). At initialization, the layer's output equals its input.
    "zero": All splines are initialized to zero.
num_control_points (int) – number of uniformly spaced control points per input feature. Defaults to 32.
spline_width (float) – width of the spline's domain, [-spline_width / 2, spline_width / 2]. Defaults to 4.0.
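The three init modes can be read as follows. This is a minimal sketch under stated assumptions, not the package's code: it assumes each spline starts as a straight line y = slope * x (a linear spline reproduces a line exactly when its control points lie on it), that a "neuron" is one output unit whose "weights" are the slopes from all inputs, and that the control points include both domain endpoints. The helpers knot_positions and init_control_points are hypothetical names.

```python
import math
import random


def knot_positions(num_control_points, spline_width):
    # Assumption: evenly spaced control points that include both domain endpoints.
    h = spline_width / (num_control_points - 1)
    return [-spline_width / 2 + k * h for k in range(num_control_points)]


def init_control_points(in_features, out_features, init,
                        num_control_points=32, spline_width=4.0, seed=0):
    """Hypothetical sketch of the init modes (not the package's implementation).

    Returns cps[o][i]: control-point values of the spline mapping input i to output o.
    """
    rng = random.Random(seed)
    if init == "random_normal":
        slopes = [[rng.gauss(0.0, 1.0) for _ in range(in_features)]
                  for _ in range(out_features)]
        for row in slopes:
            # Normalize each output neuron's slope vector to unit L2 norm.
            norm = math.sqrt(sum(s * s for s in row))
            row[:] = [s / norm for s in row]
    elif init == "identity":
        if in_features != out_features:
            raise ValueError("identity init requires in_features == out_features")
        # Output o sums spline(o, i)(x_i) over inputs i, so only the diagonal is nonzero.
        slopes = [[1.0 if i == o else 0.0 for i in range(in_features)]
                  for o in range(out_features)]
    elif init == "zero":
        slopes = [[0.0] * in_features for _ in range(out_features)]
    else:
        raise ValueError(f"unknown init: {init!r}")
    knots = knot_positions(num_control_points, spline_width)
    # Place every control point on the line y = slope * x.
    return [[[s * t for t in knots] for s in row] for row in slopes]
```

Under these assumptions the identity init reproduces x only inside the spline's domain, which is one reason a normalization layer in front of each KANLayer is useful.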

Methods:

visualize_all_mappings(save_path) – plots the shape of each spline together with its corresponding input and output feature. save_path is optional.

How This Works

This implementation of KAN uses linear (C⁰) splines with uniformly spaced control points (see Figure 1 and Equation 1).

Figure 1. Linear B-spline example.

Equation 1. B-spline formula.
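For reference, the textbook form of a degree-1 B-spline with uniformly spaced knots is written below, together with its reduction to two-point interpolation. The symbols N (number of control points), c_k (control-point values), t_k (knot positions), h (spacing), and w (corresponding to spline_width) are notation introduced here, and placing the knots at both domain endpoints is an assumption.

```latex
% Degree-1 (hat) basis on uniform knots t_k = -w/2 + k h, with h = w / (N - 1)
\mathrm{spline}(x) = \sum_{k=0}^{N-1} c_k \, B_k(x),
\qquad
B_k(x) = \max\!\left(0,\; 1 - \frac{|x - t_k|}{h}\right)

% For x in [t_i, t_{i+1}], only B_i and B_{i+1} are nonzero, so the sum collapses to
i = \left\lfloor \frac{x - t_0}{h} \right\rfloor,
\qquad
\alpha = \frac{x - t_i}{h},
\qquad
\mathrm{spline}(x) = (1 - \alpha)\, c_i + \alpha\, c_{i+1}
```

Since B_i(x) = 1 - α and B_{i+1}(x) = α on that interval, while every other hat function is zero there, the N-term sum reduces to an index computation and one interpolation.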

Roadmap

  • Update the package with cleaned-up, efficient Discrete Cosine Transform and parallel scan (prefix sum) reparameterizations. Both give the discrete second-difference penalty isotropic κ ~ O(1) conditioning, as opposed to κ ~ O(N^4) conditioning for the naive B-spline parameterization, where N is the number of control points. This matters only if you use regularization.
  • Proper baselines against MLPs and other KAN implementations for both forward and backward passes
  • Add a feature-major variant
  • Add an optimized Triton kernel
  • Clean up writing
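The prefix-sum idea from the first roadmap item can be illustrated with a minimal sketch. It assumes the reparameterization writes the control points c as a double cumulative sum of free parameters z; the helper names are hypothetical, and the package's eventual implementation (and the DCT variant, not shown) may differ. Under this assumption the interior second differences of c equal z[2:] exactly, so the second-difference penalty becomes a plain sum of squares in z.

```python
from itertools import accumulate


def double_prefix_sum(z):
    """Map free parameters z to control points c via two cumulative sums."""
    return list(accumulate(accumulate(z)))


def second_differences(c):
    """Discrete second difference c[k] - 2*c[k-1] + c[k-2] for k = 2..N-1."""
    return [c[k] - 2 * c[k - 1] + c[k - 2] for k in range(2, len(c))]


def second_difference_penalty(c):
    """Smoothness penalty on control points: sum of squared second differences."""
    return sum(d * d for d in second_differences(c))


z = [0.5, -1.0, 2.0, 0.0, -3.0, 1.5]
c = double_prefix_sum(z)
# In z-space the penalty is sum(z[k]**2 for k >= 2): its Hessian is 2 * I on the
# penalized coordinates, so every penalized direction is scaled equally.
print(second_differences(c))          # equals z[2:]
print(second_difference_penalty(c))   # equals sum of squares of z[2:]
```

The first two entries of z set the spline's offset and slope and are left unpenalized, which matches the fact that a second-difference penalty does not penalize straight lines.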

LICENSE

This project is licensed under the Apache License 2.0.



Download files

Download the file for your platform.

Source Distribution

kanditioned-1.0.1.tar.gz (8.4 kB)

Uploaded Source

Built Distribution


kanditioned-1.0.1-py3-none-any.whl (9.3 kB)

Uploaded Python 3

File details

Details for the file kanditioned-1.0.1.tar.gz.

File metadata

  • Download URL: kanditioned-1.0.1.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for kanditioned-1.0.1.tar.gz
Algorithm Hash digest
SHA256 a9dfdfdb9d04ba79141dcf6e58b0e77f9a1e04efcce6758afd3aa191db45c632
MD5 bdf2dad02498d73b230240dfa093fe02
BLAKE2b-256 415195ee942a2e78620bff473a5c27d7fa5d3c1e13399a1fb7fbec0f4b036b94


File details

Details for the file kanditioned-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: kanditioned-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for kanditioned-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 bf331852855862d443cd42005f49bbc1153cc58da39236f51894a7351fbdfba9
MD5 09b9ee88cc4e6587444cc02ab342b3d5
BLAKE2b-256 a1cba2d65f9c087405610bdc476fc878448bc5431759509c1251206723f097ea

