
Fast, Conditioned KAN

Reason this release was yanked:

Bug

Project description

KANditioned: Fast, Conditioned Training of KANs via Lookup Interpolation and Discrete Cosine Transform

Install

pip install kanditioned

Usage

It is highly recommended to use this layer with torch.compile, which provides significant speedups, and to place a normalization layer before each KANLayer.

from kanditioned.kan_layer import KANLayer

layer = KANLayer(in_features=3, out_features=3, init="random_normal", num_control_points=8)

layer.visualize_all_mappings(save_path="kan_mappings.png")

Args:

in_features (int) – size of each input sample
out_features (int) – size of each output sample
init (str) – initialization method:
    "random_normal": the slope of each spline is drawn from a normal distribution, then normalized so that each "neuron" has unit "weight" norm.
    "identity": identity mapping (requires in_features == out_features). At initialization, the layer's output equals its input.
    "zero": all splines are initialized to zero.
num_control_points (int) – number of uniformly spaced control points per input feature. Defaults to 32.
spline_width (float) – width of the spline's domain, [-spline_width / 2, spline_width / 2]. Defaults to 4.0.
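One reading of the three init options can be sketched in plain Python. This is an illustration under assumptions, not the package's code: each spline is assumed to start as a straight line through the origin (value = slope × control-point position), the control points are assumed to span [-spline_width / 2, spline_width / 2] including both endpoints, and a layer output is assumed to be the usual KAN sum of one spline per input feature. The function names and the nested-list layout [out][in][control point] are hypothetical; the package's actual tensor layout may differ.

```python
import math
import random
from typing import List, Optional


def control_point_positions(num_control_points: int, spline_width: float) -> List[float]:
    """Uniformly spaced positions spanning [-spline_width / 2, spline_width / 2] (assumed)."""
    h = spline_width / (num_control_points - 1)
    return [-spline_width / 2 + k * h for k in range(num_control_points)]


def init_control_points(
    in_features: int,
    out_features: int,
    init: str,
    num_control_points: int = 32,
    spline_width: float = 4.0,
    rng: Optional[random.Random] = None,
) -> List[List[List[float]]]:
    """Hypothetical initializer returning control points indexed [out][in][k]."""
    rng = rng or random.Random()
    if init == "zero":
        slopes = [[0.0] * in_features for _ in range(out_features)]
    elif init == "identity":
        if in_features != out_features:
            raise ValueError("identity init requires in_features == out_features")
        # Only the diagonal spline (input j -> output j) is the line y = x.
        slopes = [[1.0 if i == j else 0.0 for i in range(in_features)]
                  for j in range(out_features)]
    elif init == "random_normal":
        slopes = []
        for _ in range(out_features):
            row = [rng.gauss(0.0, 1.0) for _ in range(in_features)]
            norm = math.sqrt(sum(s * s for s in row))
            slopes.append([s / norm for s in row])  # unit L2 norm per output neuron
    else:
        raise ValueError(f"unknown init: {init!r}")
    grid = control_point_positions(num_control_points, spline_width)
    # Each spline starts as a straight line through the origin: value = slope * position.
    return [[[s * p for p in grid] for s in row] for row in slopes]
```

Under the assumed sum-over-inputs forward pass, the "identity" option makes output j equal to input j anywhere inside the spline domain, because only the diagonal spline is nonzero and it is the line y = x.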

Methods:

visualize_all_mappings(save_path=path [optional]) – plots the shape of each spline together with its corresponding input and output feature.

How This Works

Training is accelerated by orders of magnitude by exploiting the structure of the linear (C⁰) B-spline (see Fig. 1) with uniformly spaced control points. Because the intervals are uniform, evaluating spline(x) reduces to a constant-time index calculation, followed by looking up the two neighboring control points and linearly interpolating between them. This contrasts with the summation over all basis functions typically used to evaluate splines: it reduces the computation required, and because only two control points are read per evaluation, cost scales effectively sublinearly along the control-point dimension.
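The lookup described above can be sketched in plain Python for a single spline. This is a minimal sketch under assumptions, not the package's implementation: control points are assumed to sit at uniformly spaced positions covering [-spline_width / 2, spline_width / 2] including both endpoints, inputs outside the domain are assumed to be clamped to it, and the function names are hypothetical. The summation version is included only to check that the constant-time lookup agrees with summing every linear B-spline (hat) basis function; a real layer would vectorize the lookup with tensor indexing.

```python
import math
from typing import Sequence, Tuple


def _grid(control_points: Sequence[float], spline_width: float) -> Tuple[float, float]:
    """Return (left edge, spacing) of the assumed uniform control-point grid."""
    return -spline_width / 2, spline_width / (len(control_points) - 1)


def spline_by_lookup(x: float, control_points: Sequence[float], spline_width: float) -> float:
    """Constant-time evaluation: one index computation, two lookups, one lerp."""
    left, h = _grid(control_points, spline_width)
    x = min(max(x, left), -left)                          # clamp to the domain (assumption)
    t = (x - left) / h                                    # position in units of grid spacing
    i = min(int(math.floor(t)), len(control_points) - 2)  # index of the left neighbor
    frac = t - i                                          # lies in [0, 1]
    return (1.0 - frac) * control_points[i] + frac * control_points[i + 1]


def spline_by_summation(x: float, control_points: Sequence[float], spline_width: float) -> float:
    """Reference evaluation: sum over every hat basis function (O(N) per call)."""
    left, h = _grid(control_points, spline_width)
    x = min(max(x, left), -left)
    total = 0.0
    for k, c in enumerate(control_points):
        center = left + k * h
        # Hat basis: 1 at its own control point, falling linearly to 0 one spacing away.
        total += c * max(0.0, 1.0 - abs(x - center) / h)
    return total


cps = [0.0, 1.0, 4.0, 9.0]  # 4 control points on [-1.5, 1.5], spacing 1.0
print(spline_by_lookup(0.0, cps, spline_width=3.0))     # 2.5
print(spline_by_summation(0.0, cps, spline_width=3.0))  # 2.5
```

spline_by_lookup reads two control points per call regardless of how many there are, while spline_by_summation visits all of them; that difference is the source of the speedup described above.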


Figure 1. Linear B-spline example (each triangle-shaped function is one basis function).

Roadmap

  • Update the package with cleaned-up, efficient Discrete Cosine Transform (DCT) and parallel scan (prefix sum) parameterizations.
    • Both give isotropic, O(1) conditioning of the discrete second-difference penalty, compared with O(N^4) conditioning for the naive B-spline parameterization. This matters only if you use that regularization.
    • A linearDCT variant may be added first. Although it is O(N^2), it is more parallel and better optimized on GPU for small N, since it is essentially a matmul whose weight is a DCT matrix.
  • Proper baselines against MLPs and other KAN implementations for both forward and backward passes
  • Add a feature-major variant
  • Add an optimized Triton kernel
  • Polish writing
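The two reparameterizations named in the roadmap can be sketched in plain Python. This is an illustration under stated assumptions, not the package's implementation: the scan variant is assumed to build control points as a double prefix sum of free parameters, the DCT variant is assumed to use the orthonormal DCT-III (inverse DCT-II) basis, and all function names are hypothetical. The scan part shows why the second-difference penalty becomes isotropic under that assumption: the second differences of the control points are exactly the parameters theta[2:], so the penalty is a plain sum of squares (the first two parameters only fix the starting value and starting slope, which the penalty does not see). The DCT part only illustrates that a linearDCT-style layer is an O(N^2) matmul with a DCT weight matrix; it does not reproduce the conditioning argument for the DCT case.

```python
import math
from itertools import accumulate
from typing import List, Sequence


def control_points_from_scan(theta: Sequence[float]) -> List[float]:
    """Hypothetical scan parameterization: two prefix sums applied in sequence."""
    return list(accumulate(accumulate(theta)))


def second_differences(c: Sequence[float]) -> List[float]:
    """Discrete second difference c[k+2] - 2*c[k+1] + c[k]."""
    return [c[k + 2] - 2 * c[k + 1] + c[k] for k in range(len(c) - 2)]


def second_difference_penalty(c: Sequence[float]) -> float:
    """Sum of squared second differences (a smoothness regularizer)."""
    return sum(d * d for d in second_differences(c))


def inverse_dct_matrix(n: int) -> List[List[float]]:
    """Orthonormal DCT-III matrix (assumed convention): control_points = M @ coeffs."""
    return [
        [
            (math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n))
            * math.cos(math.pi * (2 * row + 1) * k / (2 * n))
            for k in range(n)
        ]
        for row in range(n)
    ]


def matvec(m: Sequence[Sequence[float]], v: Sequence[float]) -> List[float]:
    """Dense O(N^2) matrix-vector product, the core of a linearDCT-style evaluation."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


theta = [1.0, -2.0, 3.0, 0.0, 5.0]
c = control_points_from_scan(theta)
print(c)                             # [1.0, 0.0, 2.0, 4.0, 11.0]
print(second_differences(c))         # [3.0, 0.0, 5.0], i.e. theta[2:]
print(second_difference_penalty(c))  # 34.0 = 3^2 + 0^2 + 5^2
print(matvec(inverse_dct_matrix(4), [2.0, 0.0, 0.0, 0.0]))  # [1.0, 1.0, 1.0, 1.0]
```

In the scan form the penalty's Hessian with respect to theta[2:] is a multiple of the identity, which is what "isotropic" conditioning means here; a penalty written directly on the control points instead inherits the conditioning of the second-difference operator.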

LICENSE

This project is licensed under the Apache License 2.0.



Download files


Source Distribution

kanditioned-1.0.2.tar.gz (8.4 kB)

Built Distribution


kanditioned-1.0.2-py3-none-any.whl (9.3 kB)

File details

Details for the file kanditioned-1.0.2.tar.gz.

File metadata

  • Download URL: kanditioned-1.0.2.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for kanditioned-1.0.2.tar.gz
Algorithm Hash digest
SHA256 e117b722528613e5548a80d7489cba7039d69948d2daa69832a4766b17e06708
MD5 4f1df619c36cea47dc3162a7423db6c9
BLAKE2b-256 31c4d190b67093432b23886ffe9d8c3d8056b14e96b8bbdb05ae653fae7535ea


File details

Details for the file kanditioned-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: kanditioned-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for kanditioned-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a725ba653b792fab9a686b1b251a7638a856751ef3cf47aeb3189eb6dda44dba
MD5 9a324a714fb25f9c189bc382e5b7cb44
BLAKE2b-256 0d085f055de6f105ce73332791ea939009c9cfaf325ee261607c8827557820c6

