# mlx-snn

The first Spiking Neural Network library built natively on Apple MLX.
mlx-snn brings SNN research to Apple Silicon. It provides spiking neuron models, surrogate gradient training, and spike encoding — all implemented with MLX for unified memory and lazy evaluation on M-series chips.
## Installation

```bash
pip install mlx-snn
```

Requires Python 3.9+ and Apple Silicon (M1/M2/M3/M4).
## Quick Start

```python
import mlx.core as mx
import mlx.nn as nn
import mlxsnn

# Build a spiking network: a linear layer feeding a leaky integrate-and-fire neuron
fc = nn.Linear(784, 10)
lif = mlxsnn.Leaky(beta=0.95, threshold=1.0)

# Encode input as a spike train and run over time
spikes_in = mlxsnn.rate_encode(mx.random.uniform(shape=(8, 784)), num_steps=25)
state = lif.init_state(batch_size=8, features=10)
for t in range(25):
    spk, state = lif(fc(spikes_in[t]), state)
print("Output membrane:", state["mem"].shape)  # (8, 10)
```
## Features

### Neuron Models
| Model | Status | Description |
|---|---|---|
| Leaky (LIF) | v0.1 | Leaky Integrate-and-Fire with configurable decay |
| IF | v0.1 | Integrate-and-Fire (non-leaky) |
| Izhikevich | v0.2 | 2D dynamics with RS/IB/CH/FS presets |
| Adaptive LIF | v0.2 | LIF with adaptive threshold |
| Synaptic | v0.2 | Conductance-based dual-state LIF |
| Alpha | v0.2 | Dual-exponential synaptic model |
| Liquid State Machine | coming soon | Reservoir computing with spiking neurons |
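The Leaky (LIF) row above follows the standard leaky integrate-and-fire update: decay the membrane, integrate input, fire on threshold crossing, then reset. A minimal NumPy sketch of those dynamics (illustrative only, not a call into mlx-snn; the names `lif_step`, `beta`, and `threshold` mirror the table but the body here is a generic textbook LIF with soft reset):

```python
import numpy as np

def lif_step(x, mem, beta=0.95, threshold=1.0):
    """One leaky integrate-and-fire step: decay, integrate, fire, reset."""
    mem = beta * mem + x                          # leaky integration of input current
    spk = (mem >= threshold).astype(np.float32)   # fire where the threshold is crossed
    mem = mem - spk * threshold                   # soft reset: subtract threshold on spike
    return spk, mem

mem = np.zeros(3)
x = np.array([0.6, 0.2, 1.2])
spk, mem = lif_step(x, mem)
print(spk)  # [0. 0. 1.] : only the strongest input crosses threshold
```

With `beta < 1` the membrane leaks between inputs, so sub-threshold drive must arrive fast enough to accumulate; setting `beta = 1` recovers the non-leaky IF model from the table.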
### Surrogate Gradients
All neuron models support differentiable training via surrogate gradients:
- Fast Sigmoid — default, good balance of speed and accuracy
- Arctan — smoother gradient landscape
- Straight-Through Estimator — simplest, pass-through in a window
- Custom — plug in any smooth approximation
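The idea behind all of these: the forward pass uses a hard step (spike or no spike), whose derivative is zero almost everywhere, so the backward pass substitutes a smooth approximation. A NumPy sketch of the fast-sigmoid pair, assuming the common form `x / (1 + slope*|x|)` (illustrative, not the library's exact implementation):

```python
import numpy as np

def heaviside(x):
    """Forward pass of the spiking nonlinearity: a hard step."""
    return (x >= 0).astype(np.float32)

def fast_sigmoid_grad(x, slope=25.0):
    """Surrogate derivative used in place of the step's zero-a.e. gradient:
    d/dx [x / (1 + slope*|x|)] = 1 / (1 + slope*|x|)**2."""
    return 1.0 / (1.0 + slope * np.abs(x)) ** 2

x = np.array([-1.0, 0.0, 1.0])
print(heaviside(x))          # [0. 1. 1.]
print(fast_sigmoid_grad(x))  # peaks at x = 0, decays away from the threshold
```

The `slope` parameter trades gradient sharpness against vanishing gradients far from the threshold; larger slopes track the true step more closely but propagate less signal.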
### Spike Encoding
| Method | Status | Use Case |
|---|---|---|
| Rate (Poisson) | v0.1 | Static images, general-purpose |
| Latency (TTFS) | v0.1 | Energy-efficient, temporal coding |
| Delta Modulation | v0.2 | Temporal signals, change detection |
| EEG Encoder | v0.2 | EEG-to-spike with frequency band support |
| fMRI BOLD Encoder | coming soon | fMRI signal encoding with HRF handling |
### Training

- BPTT forward pass helper (`bptt_forward`)
- SNN loss functions: `rate_coding_loss`, `membrane_loss`, `mse_count_loss`
- Works with standard MLX optimizers (`mlx.optimizers.Adam`, etc.)
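Rate-coding losses typically apply cross-entropy to spike counts accumulated over time, so the most active output neuron is pushed toward the target class. A NumPy sketch of that idea (an assumption about the general technique, not the exact semantics of mlx-snn's `rate_coding_loss`):

```python
import numpy as np

def rate_coding_loss_sketch(spk_rec, targets):
    """Cross-entropy on spike counts summed over time.
    spk_rec: (num_steps, batch, classes); targets: (batch,) int labels."""
    counts = spk_rec.sum(axis=0)                          # spikes per class
    logits = counts - counts.max(axis=1, keepdims=True)   # stabilize softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

spk = np.zeros((25, 2, 3))
spk[:, 0, 1] = 1.0   # sample 0 fires class 1 on every step
spk[:, 1, 2] = 1.0   # sample 1 fires class 2 on every step
print(rate_coding_loss_sketch(spk, np.array([1, 2])))  # near zero: correct classes dominate
```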
## NIR Interoperability

NIR (Neuromorphic Intermediate Representation) enables cross-framework SNN model exchange. mlx-snn is the first MLX-native framework in the NIR ecosystem.

```bash
pip install mlx-snn[nir]
```
Export an mlx-snn model to NIR:

```python
import mlx.nn as nn
import mlxsnn, nir

layers = [
    ('fc1', nn.Linear(784, 128)),
    ('lif1', mlxsnn.Leaky(beta=0.9)),
    ('fc2', nn.Linear(128, 10)),
    ('lif2', mlxsnn.Leaky(beta=0.9)),
]
graph = mlxsnn.export_to_nir(layers)
nir.write('model.nir', graph)
```
Import a NIR model into mlx-snn:

```python
graph = nir.read('model.nir')
model = mlxsnn.import_from_nir(graph)
state = model.init_states(batch_size=32)
out, state = model(x, state)
```
Supported conversions: `nn.Linear` ↔ `nir.Affine`/`nir.Linear`, `Leaky` ↔ `nir.LIF`, `IF` ↔ `nir.IF`, `Synaptic` ↔ `nir.CubaLIF`.
## Migrating from snnTorch

mlx-snn is designed to feel familiar to snnTorch users:

```python
# snnTorch                     # mlx-snn
import snntorch as snn         # import mlxsnn
lif = snn.Leaky(beta=0.9)      # lif = mlxsnn.Leaky(beta=0.9)
spk, mem = lif(x, mem)         # spk, state = lif(x, state)
                               # state["mem"] == mem
```
Key differences:

- State is a dict, not separate tensors — plays well with MLX functional transforms
- No global hidden state — state is always explicit (pass in, get out)
- MLX arrays instead of PyTorch tensors — use `mx.array`, not `torch.Tensor`
- Surrogate gradients use the STE pattern with `mx.stop_gradient` — no `autograd.Function` needed
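The STE pattern in the last point can be sketched without MLX. Assuming the common formulation `out = stop_grad(hard - soft) + soft`, the forward value equals the hard spike while any gradient flows through the smooth path; here `stop_gradient` is a value-level stand-in for `mx.stop_gradient` and `surrogate` is an illustrative fast-sigmoid shape:

```python
import numpy as np

def stop_gradient(x):
    """Stand-in for mx.stop_gradient: identity on values; an autodiff
    framework would treat its output as a constant in the backward pass."""
    return x

def surrogate(x, slope=25.0):
    """Smooth stand-in for the spike (fast-sigmoid shape, range (0, 1))."""
    return 0.5 * (x / (1.0 + slope * np.abs(x)) + 1.0)

def spike_ste(x):
    hard = (x >= 0).astype(np.float64)   # actual spike in the forward pass
    soft = surrogate(x)                  # gradients flow through this path
    return stop_gradient(hard - soft) + soft

print(spike_ste(np.array([-0.5, 0.3])))  # forward values match the hard spike (0/1)
```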
## Architecture

```
mlxsnn/
├── neurons/     # SpikingNeuron base, Leaky, IF, Izhikevich, ALIF, Synaptic, Alpha
├── surrogate/   # fast_sigmoid, arctan, straight_through, custom
├── encoding/    # rate, latency, delta, EEG encoder
├── functional/  # Stateless pure functions (lif_step, if_step, fire, reset)
├── training/    # BPTT helpers, loss functions
└── utils/       # Visualization, metrics (coming soon)
```
## Roadmap

- v0.1 — LIF/IF neurons, surrogate gradients, rate/latency encoding, MNIST example
- v0.2 — Izhikevich, ALIF, Synaptic, Alpha neurons, EEG encoder, delta encoding
- v0.2.1 — Fix fast sigmoid surrogate to match snnTorch rational approximation (97%+ MNIST accuracy)
- v0.3 — NIR interoperability (export/import), cross-framework SNN model exchange
- v0.4 — Liquid State Machine, reservoir topology, `mx.compile` optimization
- v1.0 — Full docs, benchmarks, JOSS paper, numerical validation vs snnTorch
## Citation

If you use mlx-snn in your research, please cite:

```bibtex
@software{mlxsnn2025,
  title   = {mlx-snn: Spiking Neural Networks on Apple Silicon via MLX},
  author  = {Qin, Jiahao},
  year    = {2025},
  version = {0.3.0},
  url     = {https://github.com/D-ST-Sword/mlx-snn},
  note    = {https://pypi.org/project/mlx-snn/}
}
```
## License

MIT