Skip to main content

Fast, Conditioned KAN

Project description

KANditioned: Fast, Conditioned Training of Kolmogorov-Arnold Networks (KANs) via Lookup Interpolation and Discrete Cosine Transform

Training is accelerated by orders of magnitude through exploiting the structure of the linear (C⁰) B-spline (see Fig. 1) with uniformly spaced control points. Because the intervals are uniform, evaluating spline(x) reduces to a constant-time index calculation, followed by looking up the two relevant control points and linearly interpolating between them. This contrasts with the typical summation over basis functions typically seen in splines, reducing the amount of computation required and enabling effective sublinear scaling across the control points dimension.

Install

pip install kanditioned

Usage

[!IMPORTANT]
It is highly recommended to use this layer with torch.compile, which will provide very significant speedups (Triton kernel is coming sometimes later, but I found torch.compile to provide very satisfactory performance), in addition to a normalization layer before each KANLayer.

from kanditioned.kan_layer import KANLayer

layer = KANLayer(in_features=3, out_features=3, init="random_normal", num_control_points=8)
layer.visualize_all_mappings(save_path="kan_mappings.png")

Arguments

  • in_features (int)
    Size of each input sample.
  • out_features (int)
    Size of each output sample.
  • init (str) – initialization method:
    • "random_normal": Slopes drawn from a normal distribution, then normalized so each “neuron” has unit weight norm.
    • "identity": Identity mapping (requires in_features == out_features). Output initially equals input.
    • "zero": All splines initialized to zero.
  • num_control_points (int, default = 32)
    Number of uniformly spaced control points per input feature.
  • spline_width (float, default = 4.0)
    Domain of the spline: [-spline_width / 2, spline_width / 2]. Beyond that, the spline will linearly extrapolate.
  • impl (str, default = embedding_bag) Note: F.embedding_bag implementation appears to be much faster when used for inference with torch.compile enabled or when used for inference/training without torch.compile. However, F.embedding appears to be somewhat faster than F.embedding_bag when used for training with torch.compile enabled. Experiment with either implementation as necessary to achieve peak performance.
    • "embedding_bag": Implement the layer using F.embedding_bag.
    • "embedding": Implement the layer using F.embedding.

Methods:

visualize_all_mappings(save_path=path[optional]) - this will plot out the shape of each spline and its corresponding input and output feature

Figure

Linear B-spline example

Figure 1. Linear B-spline example (each triangle-like shape is a basis):

Roadmap

  • Use F.embedding_bag
  • Update doc for variant and other new parameters introduced
  • Update package with cleaned up, efficient Discrete Cosine Transform (with rank-2 correction) and parallel scan (prefix sum) parameterizations.
    • Both provide isotropic O(1) condition scaling for the discrete second difference penalty, as opposed to O(N^4) conditioning for the naive B-spline parameterization. This only matters if you care about regularization.
    • May add linearDCT variant first. Although it's O(N^2), it's more parallelized and optimized on GPU for small N since it's essentially a matmul with weight being a DCT matrix
  • Proper baselines against MLP and various other KAN implementations on backward and forward passes
  • Add sorting on indices and unsorting as an option (potentially radix sort, which is common optimization on embedding) to improve computational time through global memory "coalesced" access
  • Add in feature-major input variant
  • May change to either unfold or as_strided (slight performance improvement)
  • Run benchmarks and further optimize memory locality
    • Feature-major input variant versus batch-major input variant
    • Interleaved indices [l1, u1, l2, u2, ...] versus stacked indices [l1, l2, ..., u1, u2, ...]
  • Add optimized Triton kernel
  • Update visualize_all_mappings method to something like .plot with option for plotting everything
  • Research adding Legendre polynomials parameterization
    • Preliminary: does not seem to offer much benefits or have isotropic penalty conditioning
  • Polish writing

Open To Collaborators. Contributions Are Welcomed!

LICENSE

This project is licensed under the Apache License 2.0.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

kanditioned-1.0.3.tar.gz (9.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

kanditioned-1.0.3-py3-none-any.whl (10.7 kB view details)

Uploaded Python 3

File details

Details for the file kanditioned-1.0.3.tar.gz.

File metadata

  • Download URL: kanditioned-1.0.3.tar.gz
  • Upload date:
  • Size: 9.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for kanditioned-1.0.3.tar.gz
Algorithm Hash digest
SHA256 25b645d47a39240d3f1be313add139d579fc2cadbfc872bda7d2b12e2a768d4a
MD5 a5a26a959b07a24da9207990a479ca7d
BLAKE2b-256 44e535ed8a9797fe791db06812463ab89e404b78890393103ac224f7977df40f

See more details on using hashes here.

File details

Details for the file kanditioned-1.0.3-py3-none-any.whl.

File metadata

  • Download URL: kanditioned-1.0.3-py3-none-any.whl
  • Upload date:
  • Size: 10.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.11

File hashes

Hashes for kanditioned-1.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 a0a6dcca94485f3812a6b84f44f9d7b007d3fbb6a82f89202322d72bafdf6dce
MD5 6e59b32d527795061769fa83a5d35164
BLAKE2b-256 d4a651d567e83e91b8ace8870d5fa30523daacff37b1890c0070b339f5fb82b1

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page