
Spiking Rational Kolmogorov-Arnold Network

Project description

🧠 SRKAN

Spiking Recurrent Kolmogorov-Arnold Networks

A bio-inspired neural architecture that fuses the temporal expressiveness of Spiking Neural Networks with the function-learning power of Kolmogorov-Arnold Networks.


Why SRKAN?

  • Standard SNNs suffer from surrogate gradient bias: the mismatch between forward spikes and backward gradients causes training instability in deep networks.
  • Standard KANs have no temporal memory: they process each timestep independently, making them blind to spike timing patterns.
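To make the SNN mismatch concrete, here is a minimal, framework-free sketch. The forward pass uses a hard threshold, whose true derivative is zero almost everywhere, while training substitutes a smooth surrogate. The fast-sigmoid surrogate and the `slope` value are generic assumptions chosen for illustration; this is not SRKAN's saltation gradient.

```python
def heaviside_spike(v: float, threshold: float = 1.0) -> float:
    """Forward pass: emit a spike (1.0) when the membrane reaches the threshold."""
    return 1.0 if v >= threshold else 0.0


def fast_sigmoid_surrogate(v: float, threshold: float = 1.0, slope: float = 10.0) -> float:
    """Backward pass: smooth stand-in for dS/dv, peaked at the threshold (assumed form)."""
    return 1.0 / (1.0 + slope * abs(v - threshold)) ** 2


def finite_difference(v: float, threshold: float = 1.0, eps: float = 1e-3) -> float:
    """Central-difference derivative of the true (hard-threshold) forward pass."""
    return (heaviside_spike(v + eps, threshold) - heaviside_spike(v - eps, threshold)) / (2 * eps)
```

At `v = 0.5` the true derivative is 0 but the surrogate returns 1/36; accumulated across neurons, timesteps, and layers, that gap between what the forward pass computes and what the backward pass assumes is the kind of bias the bullet refers to.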

SRKAN solves both problems simultaneously:

  • Saltation gradient: a biologically motivated surrogate that tracks sub-threshold membrane dynamics, reducing gradient mismatch
  • Dendritic KAN gate: replaces linear synaptic weights with learnable B-spline functions, letting each synapse learn its own nonlinear transfer curve
  • AdaptiveLIF + EMA deadband homeostasis: prevents firing-rate collapse in deep stacks (the key training stability fix)
  • Chunked parallel scan: O(√T) memory complexity for long spike trains, enabling 28 ms inference on a standard GPU
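The chunked-scan idea can be sketched on the simplest recurrence an LIF layer contains, leaky integration `v[t] = decay * v[t-1] + x[t]`. Each chunk is scanned from a zero state and then corrected with a single state carried over from the previous chunk, so only one value crosses each chunk boundary. This is a stdlib sketch under that assumption: the function names are hypothetical, the inner loop is sequential for clarity (a real implementation would scan within a chunk in parallel), and it is not the srkan `ChunkedScan` source.

```python
from typing import List


def sequential_scan(x: List[float], decay: float) -> List[float]:
    """Reference leaky integration: v[t] = decay * v[t-1] + x[t], starting from v = 0."""
    out: List[float] = []
    v = 0.0
    for xt in x:
        v = decay * v + xt
        out.append(v)
    return out


def chunked_scan(x: List[float], decay: float, chunk_size: int) -> List[float]:
    """Same recurrence, processed chunk by chunk.

    Inside a chunk, v[start + k] = local[k] + decay**(k + 1) * carry, where local is
    the chunk scanned from a zero state and carry is the last state of the
    previous chunk.
    """
    out: List[float] = []
    carry = 0.0
    for start in range(0, len(x), chunk_size):
        local = sequential_scan(x[start:start + chunk_size], decay)
        chunk = [lv + decay ** (k + 1) * carry for k, lv in enumerate(local)]
        out.extend(chunk)
        carry = chunk[-1]  # the only state handed to the next chunk
    return out
```

Because chunks interact only through `carry`, a training loop can keep one carry per chunk and recompute a chunk's activations during backprop, which is the usual route to a sub-linear memory bound.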

Installation

pip install srkan

Requirements: Python ≥ 3.9, PyTorch ≥ 2.0, pykan

Quick Start

import torch
from srkan import SRKAN

# Build model: 700 input channels (SHD cochlear), 20 classes
model = SRKAN(
    n_in       = 700,
    n_hidden   = 256,
    n_out      = 20,
    n_layers   = 4,
    T          = 100,       # timesteps
    grid       = 5,         # B-spline grid resolution
    chunk_size = 10         # chunked scan window
)

# Input: [Batch, Time, Channels] binary spike tensor
x = torch.randint(0, 2, (32, 100, 700)).float()

# Placeholder class labels, one per sample in the batch
targets = torch.randint(0, 20, (32,))

# Classification logits via membrane voltage readout
logits = model.readout_membrane(x).mean(dim=1)   # [B, n_out]
loss   = torch.nn.functional.cross_entropy(logits, targets)

Note: Use model.readout_membrane(x).mean(dim=1) for classification rather than model(x). The membrane voltage is differentiable everywhere and preserves the full sub-threshold timing signal.
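For one unbatched sample, the readout reduces to averaging the `[T, n_out]` membrane trace over time and applying softmax cross-entropy. The stdlib sketch below mirrors `.mean(dim=1)` followed by `cross_entropy`; the function names are hypothetical and not part of the srkan API.

```python
import math
from typing import List


def readout_logits(membrane: List[List[float]]) -> List[float]:
    """Average a [T, n_out] membrane trace over the time axis -> [n_out] logits."""
    T = len(membrane)
    n_out = len(membrane[0])
    return [sum(step[c] for step in membrane) / T for c in range(n_out)]


def cross_entropy(logits: List[float], target: int) -> float:
    """Numerically stable softmax cross-entropy for a single sample."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]
```

Because every timestep's real-valued voltage enters the average, a sub-threshold change at any step still moves the logits, whereas a spike count would ignore it until a threshold crossing occurs.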

Architecture

Input Spikes [B, T, n_in]
        │
        ▼
  ┌─────────────────────────────────┐
  │   Pre-projection  (Linear)      │  784 → 64 (reduces KAN OOM risk)
  └─────────────────────────────────┘
        │
        ▼
  ┌─────────────────────────────────┐  ┐
  │  Modal Synapse  (B-spline KAN)  │  │
  │  Dendritic Gate (gated signal)  │  │  × n_layers
  │  AdaptiveLIF   (spike + EMA)    │  │
  │  Chunked Scan  (O(√T) memory)   │  │
  └─────────────────────────────────┘  ┘
        │
        ▼
  Readout Layer → Membrane Voltage [B, T, n_out]
        │  .mean(dim=1)
        ▼
  Logits [B, n_out]

Key Components

| Component | Role | Innovation |
|---|---|---|
| ModalSynapse | B-spline synaptic transfer | Replaces linear weights with learnable nonlinear curves |
| DendriticGate | Gated signal mixing | Separates fast/slow dendritic compartments |
| SaltationGradient | Surrogate for backprop | Tracks sub-threshold dynamics, reduces bias vs. rectangle surrogate |
| AdaptiveLIF | Leaky Integrate-and-Fire | EMA-based threshold adaptation with deadband to prevent collapse |
| ChunkedScan | Temporal recurrence | Parallel scan in O(√T) memory vs. O(T) for naive RNN |
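The ModalSynapse row describes each synapse as a learnable B-spline curve, φ(x) = Σᵢ cᵢ Bᵢ(x). A minimal sketch of one such curve, assuming cubic splines (k = 3) on a uniform grid of `grid` intervals over [-1, 1], extended by k knots on each side, which gives `grid + k` learnable coefficients. The knot layout, the domain, and the function names are assumptions for illustration, not taken from the srkan source.

```python
from typing import List


def uniform_knots(grid: int, k: int, lo: float = -1.0, hi: float = 1.0) -> List[float]:
    """`grid` equal intervals on [lo, hi], extended by k extra knots on each side."""
    h = (hi - lo) / grid
    return [lo + (i - k) * h for i in range(grid + 2 * k + 1)]


def bspline_basis(x: float, knots: List[float], k: int) -> List[float]:
    """Cox-de Boor recursion: the len(knots) - k - 1 order-k basis values at x."""
    # Order 0: indicator of the knot interval containing x.
    B = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for p in range(1, k + 1):
        # Each pass blends neighbouring lower-order bases into one higher-order basis.
        B = [
            (x - knots[i]) / (knots[i + p] - knots[i]) * B[i]
            + (knots[i + p + 1] - x) / (knots[i + p + 1] - knots[i + 1]) * B[i + 1]
            for i in range(len(B) - 1)
        ]
    return B


def spline_synapse(x: float, coeffs: List[float], grid: int = 5, k: int = 3) -> float:
    """phi(x) = sum_i c_i * B_i(x): one synapse's learnable transfer curve."""
    basis = bspline_basis(x, uniform_knots(grid, k), k)
    if len(basis) != len(coeffs):
        raise ValueError(f"expected {len(basis)} coefficients (grid + k), got {len(coeffs)}")
    return sum(c * b for c, b in zip(coeffs, basis))
```

Inside [-1, 1) the basis values are non-negative and sum to 1, so the coefficients act as local control points: nudging one `cᵢ` reshapes the curve only near its knots, which is what lets each synapse learn its own nonlinearity.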

Benchmarks

Evaluated on the Spiking Heidelberg Digits (SHD) dataset: 700-channel cochlear spike recordings, 20 spoken digit classes, T=100 timesteps.

Accuracy

| Model | Test Accuracy | Parameters |
|---|---|---|
| KAN (feed-forward) | 18.9% | 132K |
| MLP | 49.6% | 18.2M |
| RNN | 70.2% | 520K |
| SRKAN (ours) | 83.4% | 34.8K |

SRKAN achieves +33.8 points over the MLP with more than 500× fewer parameters.
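Both headline figures follow directly from the accuracy table; the parameter ratio works out to roughly 523×:

```latex
\begin{aligned}
83.4\% - 49.6\% &= 33.8 \text{ points} \\
\frac{18.2\,\text{M}}{34.8\,\text{K}} = \frac{18\,200\,000}{34\,800} &\approx 523
\end{aligned}
```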

Training Stability (Homeostasis Ablation)

| Variant | Best Acc | Final Acc | Behavior |
|---|---|---|---|
| Naive homeostasis | 50.6% | 31.6% | 19-pt collapse ❌ |
| EMA + deadband (SRKAN) | 52.8% | 48.8% | Stable ✅ |

The EMA deadband homeostasis is the critical stability fix: without it, deep SNN stacks collapse during training regardless of surrogate choice.
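One plausible reading of "EMA + deadband" homeostasis, written as a single stdlib neuron: the firing rate is smoothed with an exponential moving average, and the threshold moves only when that estimate leaves a band around a target rate. The class name, the soft reset, the threshold floor, and every hyperparameter are assumptions for illustration; the srkan `AdaptiveLIF` may be implemented differently.

```python
from dataclasses import dataclass, field


@dataclass
class AdaptiveLIFSketch:
    """One leaky integrate-and-fire neuron with EMA-deadband threshold homeostasis.

    Hypothetical illustration: every default below is an assumption, not an SRKAN value.
    """
    beta: float = 0.9            # membrane leak factor per timestep
    threshold: float = 1.0       # adaptive firing threshold
    target_rate: float = 0.1     # desired spikes per timestep
    deadband: float = 0.05       # no adaptation while |rate_ema - target_rate| <= deadband
    ema_decay: float = 0.9       # smoothing factor for the firing-rate estimate
    eta: float = 0.05            # threshold step size outside the deadband
    min_threshold: float = 0.1   # floor so the threshold can never collapse to zero
    v: float = field(default=0.0, init=False)
    rate_ema: float = field(default=0.0, init=False)

    def __post_init__(self) -> None:
        self.rate_ema = self.target_rate  # start inside the deadband

    def step(self, current: float) -> int:
        """Advance one timestep and return 1 if the neuron spiked, else 0."""
        self.v = self.beta * self.v + current
        spike = 1 if self.v >= self.threshold else 0
        if spike:
            self.v -= self.threshold  # soft reset: keep the overshoot
        # Exponential moving average of the firing rate.
        self.rate_ema = self.ema_decay * self.rate_ema + (1.0 - self.ema_decay) * spike
        error = self.rate_ema - self.target_rate
        if error > self.deadband:       # firing too often: raise the threshold
            self.threshold += self.eta
        elif error < -self.deadband:    # firing too rarely: lower it, down to the floor
            self.threshold = max(self.min_threshold, self.threshold - self.eta)
        return spike
```

The deadband keeps the threshold still while the rate is roughly on target, so small rate fluctuations do not cause constant threshold updates. Whether `deadband=0.0` corresponds to the "naive homeostasis" row above is not stated in the benchmark.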

Inference Speed

| Model | Latency (ms/sample) |
|---|---|
| KAN | 0.31 |
| RNN | 0.09 |
| MLP | 0.06 |
| SRKAN | 0.28 |

API Reference

SRKAN(n_in, n_hidden, n_out, n_layers, T, grid, chunk_size)

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_in | int | required | Input channels (e.g. 700 for SHD) |
| n_hidden | int | 256 | Hidden layer width |
| n_out | int | required | Number of output classes |
| n_layers | int | 4 | Number of SRKAN blocks |
| T | int | 100 | Number of timesteps |
| grid | int | 5 | B-spline grid resolution |
| chunk_size | int | 10 | Chunked scan window (√T recommended) |
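Why √T: suppose the scan keeps one carried state per chunk and recomputes a single chunk's activations during backprop. It then holds about T/C + C states per neuron, which is minimized at C = √T (giving 2√T, versus T for storing every step). That memory model is an assumption about how the chunked scan is implemented, and the helpers below are hypothetical, not part of the srkan API.

```python
import math


def recommended_chunk_size(T: int) -> int:
    """Integer chunk length closest to sqrt(T), following the '√T recommended' guidance."""
    return max(1, round(math.sqrt(T)))


def checkpoint_footprint(T: int, chunk_size: int) -> int:
    """States held at once under the assumed scheme: one carry per chunk
    plus one chunk of recomputed activations."""
    return math.ceil(T / chunk_size) + chunk_size
```

For the default T = 100 this gives chunk_size = 10 and a footprint of 20 states, matching the table's default.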

Methods

model.forward(x)             # → binary spikes [B, T, n_out]; hidden layers only
model.readout_membrane(x)    # → membrane voltage [B, T, n_out]; use for classification

Training Example (SHD)

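import torch
from torch.utils.data import DataLoader, TensorDataset
from srkan import SRKAN

model     = SRKAN(n_in=700, n_hidden=256, n_out=20, n_layers=4, T=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Placeholder data so the loop runs end to end; replace with a real SHD loader
spikes_all   = torch.randint(0, 2, (256, 100, 700)).float()   # [N, T, 700]
labels_all   = torch.randint(0, 20, (256,))                    # [N]
train_loader = DataLoader(TensorDataset(spikes_all, labels_all), batch_size=32, shuffle=True)

model.train()
for spikes, labels in train_loader:           # spikes: [B, T, 700]
    logits = model.readout_membrane(spikes).mean(dim=1)   # [B, 20]
    loss   = torch.nn.functional.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()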

Citation

If SRKAN is useful in your research, please cite:

@software{srkan2026,
  title   = {SRKAN: Spiking Recurrent Kolmogorov-Arnold Networks},
  year    = {2026},
  url     = {https://pypi.org/project/srkan/},
  note    = {PyPI package}
}

License

MIT: free for academic and commercial use.

Made with 🧬 and way too little sleep



Download files


Source Distribution

srkan-0.1.1.tar.gz (11.5 kB)

Uploaded Source

Built Distribution


srkan-0.1.1-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file srkan-0.1.1.tar.gz.

File metadata

  • Download URL: srkan-0.1.1.tar.gz
  • Upload date:
  • Size: 11.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.9

File hashes

Hashes for srkan-0.1.1.tar.gz
Algorithm Hash digest
SHA256 8fdf8a5587bc5fe856618af885e1d2eb7241dc99e9390cef4aada85767f746ce
MD5 ae95efe519ac02c13ce8bf7da2048515
BLAKE2b-256 9f4ff94fb409d7fb3864c1b5d05f5f87cfb50dbdeb42fc28768ff4bb1bad5584


File details

Details for the file srkan-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: srkan-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.9

File hashes

Hashes for srkan-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 79b39db253a0b5331db2ccacc2e6cae6d3636321732b0b14dfd8fb0911a35bd9
MD5 e765187f8a741028267373b6491eeff8
BLAKE2b-256 3bb4c38ecbe8d56c31db8d412cb489049b42acabcfccd67311779bb26079b709

