Spiking Rational Kolmogorov-Arnold Network
SRKAN
Spiking Recurrent Kolmogorov-Arnold Networks
A bio-inspired neural architecture that fuses the temporal expressiveness of Spiking Neural Networks with the function-learning power of Kolmogorov-Arnold Networks.
Why SRKAN?
- Standard SNNs suffer from surrogate gradient bias: the mismatch between forward spikes and backward gradients causes training instability in deep networks.
- Standard KANs have no temporal memory: they process each timestep independently, making them blind to spike-timing patterns.
SRKAN solves both problems simultaneously:
- Saltation gradient: a biologically motivated surrogate that tracks sub-threshold membrane dynamics, reducing gradient mismatch
- Dendritic KAN gate: replaces linear synaptic weights with learnable B-spline functions, letting each synapse learn its own nonlinear transfer curve
- AdaptiveLIF + EMA deadband homeostasis: prevents firing-rate collapse in deep stacks (the key training-stability fix)
- Chunked parallel scan: O(√T) memory complexity for long spike trains, enabling 28 ms inference on a standard GPU
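The surrogate-gradient problem from the first bullet is easiest to see in code. Below is a minimal sketch of a Heaviside spike whose backward pass uses the rectangle surrogate, the baseline that SRKAN's SaltationGradient is compared against in the Key Components table. It is not SRKAN's saltation surrogate, whose exact form is not documented here; `SpikeRect`, the threshold of 1.0 and the window width of 0.5 are illustrative assumptions.

```python
import torch


class SpikeRect(torch.autograd.Function):
    """Heaviside spike forward, rectangle surrogate backward (illustrative baseline)."""

    @staticmethod
    def forward(ctx, v, threshold, width):
        ctx.save_for_backward(v)
        ctx.threshold, ctx.width = threshold, width
        return (v >= threshold).to(v.dtype)  # binary spikes: 1 where v reaches threshold

    @staticmethod
    def backward(ctx, grad_out):
        (v,) = ctx.saved_tensors
        # Treat dS/dv as 1/width inside a window of the given width centred on the
        # threshold and 0 elsewhere; the true Heaviside derivative is 0 almost everywhere.
        window = ((v - ctx.threshold).abs() < ctx.width / 2).to(v.dtype)
        return grad_out * window / ctx.width, None, None


v = torch.tensor([0.2, 0.9, 1.1, 2.0], requires_grad=True)  # membrane potentials
spikes = SpikeRect.apply(v, 1.0, 0.5)
spikes.sum().backward()
print(spikes.tolist())  # [0.0, 0.0, 1.0, 1.0]
print(v.grad.tolist())  # [0.0, 2.0, 2.0, 0.0]
```

Only neurons whose potential lies inside the window receive any gradient, and the neuron at 2.0 fires but gets none; this forward/backward disagreement is the mismatch that a surrogate tracking sub-threshold dynamics is meant to reduce.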
Installation
```bash
pip install srkan
```
Requirements: Python ≥ 3.9, PyTorch ≥ 2.0, pykan
Quick Start
```python
import torch
from srkan import SRKAN

# Build model: 700 input channels (SHD cochlear), 20 classes
model = SRKAN(
    n_in=700,
    n_hidden=256,
    n_out=20,
    n_layers=4,
    T=100,          # timesteps
    grid=5,         # B-spline grid resolution
    chunk_size=10,  # chunked scan window
)

# Input: [Batch, Time, Channels] binary spike tensor
x = torch.randint(0, 2, (32, 100, 700)).float()
targets = torch.randint(0, 20, (32,))  # dummy class labels for this example

# Classification logits via membrane voltage readout
logits = model.readout_membrane(x).mean(dim=1)  # [B, n_out]
loss = torch.nn.functional.cross_entropy(logits, targets)
```
Note: Use `model.readout_membrane(x).mean(dim=1)` for classification, not `model(x)`. The membrane voltage is differentiable everywhere and preserves the full sub-threshold timing signal.
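Why a membrane readout is differentiable everywhere can be illustrated with a non-spiking leaky-integrator readout, a common SNN readout design: it integrates input current but never thresholds or resets, so gradients reach its weights without any surrogate. This is an assumption about how such a readout might work, not SRKAN's readout layer; `LeakyIntegratorReadout` and `beta=0.9` are hypothetical.

```python
import torch
import torch.nn as nn


class LeakyIntegratorReadout(nn.Module):
    """Hypothetical readout: v[t] = beta * v[t-1] + W x[t], with no threshold and no reset."""

    def __init__(self, n_in: int, n_out: int, beta: float = 0.9):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)
        self.beta = beta

    def forward(self, spikes: torch.Tensor) -> torch.Tensor:
        # spikes: [B, T, n_in] -> membrane voltage: [B, T, n_out]
        current = self.fc(spikes)
        v = torch.zeros_like(current[:, 0])
        trace = []
        for t in range(current.shape[1]):
            v = self.beta * v + current[:, t]  # smooth in the weights: no spike nonlinearity
            trace.append(v)
        return torch.stack(trace, dim=1)


torch.manual_seed(0)
readout = LeakyIntegratorReadout(n_in=256, n_out=20)
hidden_spikes = torch.randint(0, 2, (32, 100, 256)).float()
targets = torch.randint(0, 20, (32,))

membrane = readout(hidden_spikes)  # [32, 100, 20]
logits = membrane.mean(dim=1)      # [32, 20], averaged over time as in the Quick Start
loss = nn.functional.cross_entropy(logits, targets)
loss.backward()                    # gradients reach readout.fc directly
```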
Architecture
```
Input Spikes [B, T, n_in]
            │
            ▼
┌─────────────────────────────────┐
│ Pre-projection (Linear)         │  784 → 64 (reduces KAN OOM risk)
└─────────────────────────────────┘
            │
            ▼
┌─────────────────────────────────┐ ┐
│ Modal Synapse (B-spline KAN)    │ │
│ Dendritic Gate (gated signal)   │ │ × n_layers
│ AdaptiveLIF (spike + EMA)       │ │
│ Chunked Scan (O(√T) memory)     │ │
└─────────────────────────────────┘ ┘
            │
            ▼
Readout Layer → Membrane Voltage [B, T, n_out]
            │  .mean(dim=1)
            ▼
Logits [B, n_out]
```
Key Components
| Component | Role | Innovation |
|---|---|---|
| ModalSynapse | B-spline synaptic transfer | Replaces linear weights with learnable nonlinear curves |
| DendriticGate | Gated signal mixing | Separates fast/slow dendritic compartments |
| SaltationGradient | Surrogate for backprop | Tracks sub-threshold dynamics, reduces bias vs. rectangle surrogate |
| AdaptiveLIF | Leaky Integrate-and-Fire | EMA-based threshold adaptation with deadband to prevent collapse |
| ChunkedScan | Temporal recurrence | Parallel scan in O(√T) memory vs. O(T) for a naive RNN |
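To make "replaces linear weights with learnable nonlinear curves" concrete, here is a minimal sketch of the general KAN idea: every input-output edge owns a learnable cubic B-spline, and each output sums its edge functions. It assumes a uniform grid over a fixed input range and omits pykan's SiLU base term and grid updates. `BSplineEdgeLayer` and its arguments are hypothetical and are not SRKAN's ModalSynapse; `grid` is assumed to mean the number of grid intervals, loosely mirroring SRKAN's `grid=5`.

```python
import torch
import torch.nn as nn


class BSplineEdgeLayer(nn.Module):
    """Hypothetical KAN-style layer: every edge i -> j applies its own learnable
    cubic B-spline phi_ji, and output j is the sum of phi_ji(x_i) over inputs i."""

    def __init__(self, n_in: int, n_out: int, grid: int = 5, k: int = 3,
                 x_range: tuple = (-1.0, 1.0)):
        super().__init__()
        self.k = k
        lo, hi = x_range
        h = (hi - lo) / grid
        # Uniform knot vector extended by k knots on each side (grid + 2k + 1 knots).
        knots = torch.arange(-k, grid + k + 1, dtype=torch.float32) * h + lo
        self.register_buffer("knots", knots)
        # One coefficient per (output, input, basis function); there are grid + k bases.
        self.coef = nn.Parameter(0.1 * torch.randn(n_out, n_in, grid + k))

    def basis(self, x: torch.Tensor) -> torch.Tensor:
        """Cox-de Boor recursion: x [B, n_in] -> basis values [B, n_in, grid + k]."""
        t = self.knots
        x = x.unsqueeze(-1)
        b = ((x >= t[:-1]) & (x < t[1:])).to(x.dtype)  # degree-0 indicator functions
        for p in range(1, self.k + 1):
            left = (x - t[:-(p + 1)]) / (t[p:-1] - t[:-(p + 1)]) * b[..., :-1]
            right = (t[p + 1:] - x) / (t[p + 1:] - t[1:-p]) * b[..., 1:]
            b = left + right
        return b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y[b, j] = sum_i sum_g B_g(x[b, i]) * coef[j, i, g]
        return torch.einsum("big,jig->bj", self.basis(x), self.coef)


torch.manual_seed(0)
layer = BSplineEdgeLayer(n_in=3, n_out=2, grid=5)
x = torch.rand(4, 3) * 1.8 - 0.9  # inputs inside the spline range (-1, 1)
y = layer(x)                      # [4, 2]
```

Because the coefficients are per edge, each "synapse" learns its own transfer curve; inside the grid range the cubic basis functions sum to 1, so the curve is a smooth blend of its coefficients.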
Benchmarks
Evaluated on the Spiking Heidelberg Digits (SHD) dataset: 700-channel cochlear spike recordings, 20 spoken-digit classes, T = 100 timesteps.
Accuracy
| Model | Test Accuracy | Parameters |
|---|---|---|
| KAN (feed-forward) | 18.9% | 132K |
| MLP | 49.6% | 18.2M |
| RNN | 70.2% | 520K |
| SRKAN (ours) | 83.4% | 34.8K |
SRKAN achieves +33.8 points over the MLP with roughly 500× fewer parameters (34.8K vs. 18.2M).
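Both headline figures follow directly from the Accuracy table; the parameter ratio works out to about 523×, which the text rounds to 500×.

```python
# Figures taken from the Accuracy table above
mlp_acc, srkan_acc = 49.6, 83.4            # test accuracy, %
mlp_params, srkan_params = 18.2e6, 34.8e3  # parameter counts

gain = round(srkan_acc - mlp_acc, 1)  # accuracy gain in percentage points
ratio = mlp_params / srkan_params     # how many times fewer parameters SRKAN uses
print(gain, round(ratio))  # 33.8 523
```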
Training Stability (Homeostasis Ablation)
| Variant | Best Acc | Final Acc | Behavior |
|---|---|---|---|
| Naive homeostasis | 50.6% | 31.6% | Collapses (19-pt drop) |
| EMA + deadband (SRKAN) | 52.8% | 48.8% | Stable |
The EMA deadband homeostasis is the critical stability fix: without it, deep SNN stacks collapse during training regardless of surrogate choice.
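The README does not spell out the homeostasis update rule, so the following is only a minimal sketch of what "EMA + deadband" can mean: each neuron's firing rate is tracked with an exponential moving average, and its threshold moves by a fixed step only when that average leaves a band around a target rate. The class name, target rate, band width, step size and threshold floor are all hypothetical.

```python
import torch


class EMADeadbandThreshold:
    """Hypothetical homeostasis controller: track each neuron's firing rate with an
    exponential moving average (EMA) and move its threshold only when that rate
    leaves the deadband [target - band, target + band]."""

    def __init__(self, n: int, target: float = 0.05, band: float = 0.02,
                 momentum: float = 0.99, step: float = 0.01, v_th: float = 1.0):
        self.rate = torch.full((n,), target)  # EMA firing rate, initialised at target
        self.v_th = torch.full((n,), v_th)    # per-neuron firing threshold
        self.target, self.band = target, band
        self.momentum, self.step = momentum, step

    @torch.no_grad()
    def update(self, spikes: torch.Tensor) -> torch.Tensor:
        # spikes: [B, T, n] binary -> mean firing rate per neuron for this batch
        batch_rate = spikes.float().mean(dim=(0, 1))
        self.rate.mul_(self.momentum).add_((1 - self.momentum) * batch_rate)
        error = self.rate - self.target
        # Deadband: neurons with |error| <= band are left alone, so noise in the
        # rate estimate does not keep pushing thresholds back and forth.
        outside = (error.abs() > self.band).float()
        # Too active -> raise threshold; too quiet -> lower it (kept above a floor).
        self.v_th += self.step * torch.sign(error) * outside
        self.v_th.clamp_(min=0.1)
        return self.v_th


ctrl = EMADeadbandThreshold(n=3, momentum=0.0)  # momentum=0: EMA equals the batch rate
spikes = torch.zeros(8, 100, 3)
spikes[:, :5, 1] = 1.0    # neuron 1 fires at 5%: inside the deadband
spikes[:, :, 2] = 1.0     # neuron 2 fires every step: far above target
v_th = ctrl.update(spikes)  # neuron 0 (silent) lowered, 1 unchanged, 2 raised
```

The EMA smooths batch-to-batch noise, and the deadband stops the controller from reacting to small deviations at all; a naive rule that reacts to every deviation is one plausible route to the oscillation and collapse seen in the ablation.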
Inference Speed
| Model | Latency (ms/sample) |
|---|---|
| KAN | 0.31 |
| RNN | 0.09 |
| MLP | 0.06 |
| SRKAN | 0.28 |
API Reference
`SRKAN(n_in, n_hidden, n_out, n_layers, T, grid, chunk_size)`

| Parameter | Type | Default | Description |
|---|---|---|---|
| `n_in` | int | required | Input channels (e.g. 700 for SHD) |
| `n_hidden` | int | 256 | Hidden layer width |
| `n_out` | int | required | Number of output classes |
| `n_layers` | int | 4 | Number of SRKAN blocks |
| `T` | int | 100 | Number of timesteps |
| `grid` | int | 5 | B-spline grid resolution |
| `chunk_size` | int | 10 | Chunked scan window (√T recommended) |
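The √T recommendation for `chunk_size` balances two costs: about T/chunk_size boundary states kept between chunks, and about chunk_size steps recomputed inside one chunk during backward, so with T = 100 a window of 10 keeps both near 10. Below is a minimal sketch of a chunked scan, assuming a purely linear leaky recurrence v[t] = beta·v[t−1] + x[t] with no spiking or reset (a reset makes the recurrence nonlinear, so this closed form would not apply directly). Each chunk is solved in parallel with a lower-triangular decay matrix and wrapped in PyTorch gradient checkpointing. The function names and `beta` are hypothetical, and this is not SRKAN's ChunkedScan implementation.

```python
import math

import torch
from torch.utils.checkpoint import checkpoint


def scan_chunk(x_chunk: torch.Tensor, v0: torch.Tensor, beta: float):
    """Parallel (closed-form) leaky scan over one chunk.

    Solves v[t] = beta * v[t-1] + x[t] for every t in the chunk at once:
    v[t] = beta**(t+1) * v0 + sum_{s<=t} beta**(t-s) * x[s].
    x_chunk: [B, C, N], v0: [B, N] -> (v: [B, C, N], last state: [B, N])
    """
    C = x_chunk.shape[1]
    t = torch.arange(C, dtype=x_chunk.dtype, device=x_chunk.device)
    lag = t[:, None] - t[None, :]  # lag[t, s] = t - s
    decay = torch.where(lag >= 0, beta ** lag, torch.zeros_like(lag))  # lower-triangular
    v = torch.einsum("ts,bsn->btn", decay, x_chunk)
    v = v + (beta ** (t + 1))[None, :, None] * v0[:, None, :]
    return v, v[:, -1]


def chunked_leaky_scan(x: torch.Tensor, beta: float = 0.9, chunk_size: int = 0):
    """x: [B, T, N] -> v: [B, T, N], processed in chunks of about sqrt(T) steps.

    Each chunk runs under gradient checkpointing, so autograd keeps only the chunk
    inputs and the boundary state between chunks; the chunk's internals are
    recomputed during backward.
    """
    B, T, N = x.shape
    chunk_size = chunk_size or max(1, round(math.sqrt(T)))
    v_prev = x.new_zeros(B, N)
    outputs = []
    for start in range(0, T, chunk_size):
        v_chunk, v_prev = checkpoint(scan_chunk, x[:, start:start + chunk_size],
                                     v_prev, beta, use_reentrant=False)
        outputs.append(v_chunk)
    return torch.cat(outputs, dim=1)


torch.manual_seed(0)
x = torch.randn(2, 100, 4, requires_grad=True)  # [B, T, N] input currents
v = chunked_leaky_scan(x, beta=0.9)             # T=100 -> chunk_size 10
v.sum().backward()
```

The returned voltage trace is still [B, T, N]; what the chunking bounds is the intermediate state autograd must hold for backpropagation.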
Methods
```python
model.forward(x)           # → binary spikes [B, T, n_out] (hidden layers only)
model.readout_membrane(x)  # → membrane voltage [B, T, n_out]; use this for classification
```
Training Example (SHD)
```python
import torch
from srkan import SRKAN

model = SRKAN(n_in=700, n_hidden=256, n_out=20, n_layers=4, T=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# train_loader: any iterable of (spikes, labels) batches, e.g. a DataLoader over SHD
for spikes, labels in train_loader:  # spikes: [B, T, 700], labels: [B]
    logits = model.readout_membrane(spikes).mean(dim=1)  # [B, 20]
    loss = torch.nn.functional.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
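A matching evaluation loop can reuse the same membrane readout. This is a minimal sketch: `evaluate` is a hypothetical helper that relies only on the `readout_membrane` method documented above and assumes the loader yields `(spikes, labels)` batches like `train_loader`.

```python
import torch


@torch.no_grad()
def evaluate(model, loader) -> float:
    """Top-1 accuracy from time-averaged membrane voltage (hypothetical helper)."""
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    for spikes, labels in loader:  # spikes: [B, T, n_in], labels: [B]
        logits = model.readout_membrane(spikes).mean(dim=1)  # [B, n_out]
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    model.train(was_training)  # restore whatever mode the model was in
    return correct / max(total, 1)


# Usage (test_loader is hypothetical): test_acc = evaluate(model, test_loader)
```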
Citation
If SRKAN is useful in your research, please cite:
```bibtex
@software{srkan2026,
  title = {SRKAN: Spiking Recurrent Kolmogorov-Arnold Networks},
  year  = {2026},
  url   = {https://pypi.org/project/srkan/},
  note  = {PyPI package}
}
```
License
MIT: free for academic and commercial use.
Download files
Source Distribution
Built Distribution
File details
Details for the file srkan-0.1.1.tar.gz.
File metadata
- Download URL: srkan-0.1.1.tar.gz
- Upload date:
- Size: 11.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8fdf8a5587bc5fe856618af885e1d2eb7241dc99e9390cef4aada85767f746ce |
| MD5 | ae95efe519ac02c13ce8bf7da2048515 |
| BLAKE2b-256 | 9f4ff94fb409d7fb3864c1b5d05f5f87cfb50dbdeb42fc28768ff4bb1bad5584 |
File details
Details for the file srkan-0.1.1-py3-none-any.whl.
File metadata
- Download URL: srkan-0.1.1-py3-none-any.whl
- Upload date:
- Size: 9.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 79b39db253a0b5331db2ccacc2e6cae6d3636321732b0b14dfd8fb0911a35bd9 |
| MD5 | e765187f8a741028267373b6491eeff8 |
| BLAKE2b-256 | 3bb4c38ecbe8d56c31db8d412cb489049b42acabcfccd67311779bb26079b709 |
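To check a downloaded file against the SHA256 digests listed above, a few lines of standard-library Python are enough. This sketch hard-codes the sdist digest from the table; `sha256_of` and `EXPECTED_SDIST_SHA256` are illustrative names.

```python
import hashlib

# SHA256 digest listed above for srkan-0.1.1.tar.gz
EXPECTED_SDIST_SHA256 = "8fdf8a5587bc5fe856618af885e1d2eb7241dc99e9390cef4aada85767f746ce"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in 1 MiB blocks to bound memory use."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()


# Usage, after downloading the sdist into the current directory:
# assert sha256_of("srkan-0.1.1.tar.gz") == EXPECTED_SDIST_SHA256
```

pip can perform the same check itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).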