
Spiking Neural Network library built natively on Apple MLX

Project description

mlx-snn

A general-purpose Spiking Neural Network library built on Apple MLX.

mlx-snn aims to provide an efficient, research-friendly SNN framework that leverages MLX's unified memory architecture and lazy evaluation. Whether you're exploring neuron dynamics, training classifiers with surrogate gradients, or exchanging models via NIR, mlx-snn offers a clean, Pythonic API that integrates naturally into the MLX ecosystem.

Python 3.9+ | License: GPL-3.0

Why mlx-snn?

  • MLX-native — All operations use mlx.core. No PyTorch/CUDA dependency. Runs on Apple Silicon with zero-copy unified memory.
  • Research-friendly — Explicit state dicts, composable surrogate gradients, and standard mlx.nn.Module patterns make it easy to experiment and extend.
  • Cross-framework — NIR support lets you import and export models to/from snnTorch, Norse, SpikingJelly, and neuromorphic hardware platforms.
  • Hardware tested — Currently validated on Apple M3 Max. Future Apple Silicon releases will be tested and supported as they become available.

Installation

pip install mlx-snn

Requires Python 3.9+ and Apple Silicon (M1/M2/M3/M4).
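Before installing, you can confirm that the interpreter meets these requirements with a stdlib-only check. This is a minimal sketch and not part of mlx-snn. The helper name `meets_requirements` is hypothetical, and the check assumes that a native Apple Silicon Python reports `Darwin`/`arm64`.

```python
import platform
import sys


def meets_requirements() -> bool:
    """Hypothetical helper: True on Python 3.9+ running natively on Apple Silicon macOS."""
    python_ok = sys.version_info >= (3, 9)
    # Native Apple Silicon Python reports system "Darwin" and machine "arm64";
    # an x86_64 interpreter running under Rosetta 2 reports "x86_64" instead.
    apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
    return python_ok and apple_silicon


print("mlx-snn requirements met:", meets_requirements())
```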

Quick Start

import mlx.core as mx
import mlx.nn as nn
import mlxsnn

# Build a spiking network
fc = nn.Linear(784, 10)
lif = mlxsnn.Leaky(beta=0.95, threshold=1.0)

# Encode input as spike train and run over time
spikes_in = mlxsnn.rate_encode(mx.random.uniform(shape=(8, 784)), num_steps=25)
state = lif.init_state(batch_size=8, features=10)

for t in range(25):
    spk, state = lif(fc(spikes_in[t]), state)

print("Output membrane:", state["mem"].shape)  # (8, 10)
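Conceptually, each call to `lif(x, state)` advances the membrane by one timestep and returns a new state. The pure-Python sketch below mirrors that state-in/state-out pattern for a single neuron. It assumes the common discrete LIF update: decay by `beta`, add the input, spike when the membrane exceeds `threshold`, then reset by subtracting the threshold. mlx-snn's exact reset rule may differ. `lif_step` and `run` are hypothetical names, not library functions. Setting `beta=1.0` turns the same update into the non-leaky IF neuron.

```python
from typing import Dict, List, Tuple


def lif_step(x: float, state: Dict[str, float], beta: float = 0.95,
             threshold: float = 1.0) -> Tuple[float, Dict[str, float]]:
    """One timestep of a single LIF neuron (hypothetical sketch, not mlx-snn code)."""
    mem = beta * state["mem"] + x            # leak, then integrate the input
    spk = 1.0 if mem > threshold else 0.0    # fire when the threshold is exceeded
    mem = mem - spk * threshold              # reset by subtraction (assumed rule)
    return spk, {"mem": mem}                 # fresh state dict; the input dict is untouched


def run(inputs: List[float], beta: float = 0.95) -> List[float]:
    """Feed a sequence of inputs through one neuron, threading state explicitly."""
    state = {"mem": 0.0}
    spikes = []
    for x in inputs:
        spk, state = lif_step(x, state, beta=beta)
        spikes.append(spk)
    return spikes


print(run([0.3] * 10))            # leaky: spikes at steps 3 and 7
print(run([0.3] * 9, beta=1.0))   # non-leaky IF: no decay between steps
```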

Features

Neuron Models

Model         Since  Description
Leaky (LIF)   v0.1   Leaky Integrate-and-Fire with configurable decay
IF            v0.1   Integrate-and-Fire (non-leaky)
Izhikevich    v0.2   2D dynamics with RS/IB/CH/FS presets
Adaptive LIF  v0.2   LIF with adaptive threshold
Synaptic      v0.2   Current-based dual-state LIF (synaptic current + membrane)
Alpha         v0.2   Dual-exponential synaptic model
RLeaky        v0.4   Recurrent LIF with learnable feedback weight
RSynaptic     v0.4   Recurrent Synaptic with learnable feedback weight
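The Izhikevich model replaces the single leaky membrane with two coupled variables: membrane potential v and recovery u. Their dynamics are dv/dt = 0.04v^2 + 5v + 140 - u + I and du/dt = a(bv - u). Whenever v reaches 30 mV, v is reset to c and u is incremented by d. The sketch below integrates these standard Izhikevich (2003) equations with forward Euler. The four (a, b, c, d) sets are the published regular-spiking, intrinsically-bursting, chattering and fast-spiking values. It is an assumption that mlx-snn's presets use exactly these numbers. The step size and the `simulate` helper are hypothetical.

```python
from typing import Dict, List, Tuple

# (a, b, c, d) from Izhikevich (2003); assumed to match the library's presets.
PRESETS: Dict[str, Tuple[float, float, float, float]] = {
    "RS": (0.02, 0.2, -65.0, 8.0),  # regular spiking
    "IB": (0.02, 0.2, -55.0, 4.0),  # intrinsically bursting
    "CH": (0.02, 0.2, -50.0, 2.0),  # chattering
    "FS": (0.10, 0.2, -65.0, 2.0),  # fast spiking
}


def simulate(preset: str, current: float, t_ms: float = 500.0,
             dt: float = 0.5) -> List[float]:
    """Return spike times (ms) for a constant input current, using forward Euler."""
    a, b, c, d = PRESETS[preset]
    v = -65.0      # membrane potential (mV)
    u = b * v      # recovery variable starts on its nullcline
    spike_times = []
    for step in range(int(t_ms / dt)):
        dv = 0.04 * v * v + 5.0 * v + 140.0 - u + current
        du = a * (b * v - u)
        v += dt * dv
        u += dt * du
        if v >= 30.0:                      # spike peak reached: record, then reset
            spike_times.append((step + 1) * dt)
            v = c
            u += d
    return spike_times


for name in PRESETS:
    print(name, len(simulate(name, current=10.0)), "spikes in 500 ms")
```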

Surrogate Gradients

All neuron models support differentiable training via surrogate gradients:

  • Fast Sigmoid — default, good balance of speed and accuracy
  • Arctan — smoother gradient landscape
  • Sigmoid — standard logistic sigmoid derivative
  • Triangular (Tent) — localized, compact support near threshold
  • Straight-Through Estimator — simplest, unit gradient everywhere
  • Custom — plug in any smooth approximation
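Each surrogate replaces the derivative of the Heaviside spike function with a bounded pseudo-derivative, evaluated at u = mem - threshold. The true derivative is zero almost everywhere. The sketch below gives commonly used forms of the pseudo-derivatives listed above. The hyperparameters (`k`, `alpha`, `width`) and the exact scalings are assumptions and may not match mlx-snn's defaults.

```python
import math
from typing import Callable, Dict


def fast_sigmoid_grad(u: float, k: float = 25.0) -> float:
    return 1.0 / (1.0 + k * abs(u)) ** 2


def arctan_grad(u: float, alpha: float = 2.0) -> float:
    return (alpha / 2.0) / (1.0 + (math.pi / 2.0 * alpha * u) ** 2)


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids OverflowError in math.exp).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_grad(u: float, k: float = 25.0) -> float:
    s = _sigmoid(k * u)
    return k * s * (1.0 - s)


def triangular_grad(u: float, width: float = 1.0) -> float:
    # Tent: peak 1 at threshold, exactly zero once |u| >= width (compact support).
    return max(0.0, 1.0 - abs(u) / width)


def straight_through_grad(u: float) -> float:
    return 1.0  # pass the upstream gradient through unchanged


SURROGATES: Dict[str, Callable[[float], float]] = {
    "fast_sigmoid": fast_sigmoid_grad,
    "arctan": arctan_grad,
    "sigmoid": sigmoid_grad,
    "triangular": triangular_grad,
    "straight_through": straight_through_grad,
}

for name, grad in SURROGATES.items():
    print(f"{name:17s}", [round(grad(u), 3) for u in (-1.0, -0.1, 0.0, 0.1, 1.0)])
```

The printed rows show the trade-off named in the list: fast sigmoid and sigmoid peak sharply at threshold, arctan decays more gently, the tent is exactly zero outside its support, and the straight-through estimator is flat.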

Spike Encoding

Method            Since  Use Case
Rate (Poisson)    v0.1   Static images, general-purpose
Latency (TTFS)    v0.1   Energy-efficient, temporal coding
Delta Modulation  v0.2   Temporal signals, change detection
EEG Encoder       v0.2   EEG-to-spike with frequency band support
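To make the first three encoders concrete, here is a pure-Python sketch for a single scalar feature. Rate coding draws an independent Bernoulli spike at each step, with probability equal to the normalized intensity; this is a discrete approximation of a Poisson process. Latency coding emits one spike whose time falls linearly as intensity rises. Delta modulation spikes only when the signal changes by more than a threshold. The linear latency mapping, the threshold semantics, and the function names are assumptions for illustration, not mlx-snn's implementation.

```python
import random
from typing import List


def bernoulli_rate_train(p: float, num_steps: int, rng: random.Random) -> List[int]:
    """Rate code: spike at each step with probability p in [0, 1]."""
    return [1 if rng.random() < p else 0 for _ in range(num_steps)]


def latency_train(x: float, num_steps: int) -> List[int]:
    """Time-to-first-spike code: brighter inputs (x near 1) fire earlier."""
    x = min(max(x, 0.0), 1.0)              # clamp to the normalized range
    train = [0] * num_steps
    if x > 0.0:                            # zero intensity never fires
        t = round((1.0 - x) * (num_steps - 1))  # linear mapping (assumed)
        train[t] = 1
    return train


def delta_train(signal: List[float], threshold: float) -> List[int]:
    """+1 / -1 when the signal rises / falls by more than threshold since the last sample."""
    out = [0]                              # no previous sample at t = 0
    for prev, cur in zip(signal, signal[1:]):
        diff = cur - prev
        out.append(1 if diff > threshold else (-1 if diff < -threshold else 0))
    return out


rng = random.Random(0)
print(bernoulli_rate_train(0.5, 10, rng))
print(latency_train(0.5, 25).index(1))              # spike time for mid intensity
print(delta_train([0.0, 0.5, 0.5, 0.1], threshold=0.2))
```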

Training & Loss Functions

  • BPTT forward pass helper (bptt_forward)
  • Loss functions: ce_rate_loss, ce_count_loss, mse_membrane_loss, membrane_loss, rate_coding_loss
  • Learnable parameters: learn_beta, learn_threshold, learn_V on all neurons
  • Works with standard MLX optimizers (mlx.optimizers.Adam, etc.)
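Spike-based cross-entropy losses differ mainly in what they feed to the softmax. The sketch below implements a count-style loss: output spikes are summed over time into per-class counts, and those counts serve as logits for cross-entropy against the target class. This shows the general technique only and is not mlx-snn's definition of `ce_count_loss`. The function name `count_cross_entropy` is hypothetical.

```python
import math
from typing import List


def count_cross_entropy(spk_rec: List[List[float]], target: int) -> float:
    """spk_rec[t][c] holds the output spike of class c at timestep t."""
    num_classes = len(spk_rec[0])
    counts = [sum(step[c] for step in spk_rec) for c in range(num_classes)]
    # Stable log-sum-exp: subtract the max count before exponentiating.
    m = max(counts)
    log_z = m + math.log(sum(math.exp(x - m) for x in counts))
    return log_z - counts[target]          # -log softmax(counts)[target]


spk_rec = [[1, 0], [1, 0], [1, 0]]        # class 0 fires at every step
print(count_cross_entropy(spk_rec, target=0))  # small loss: prediction matches
print(count_cross_entropy(spk_rec, target=1))  # large loss: wrong class fired
```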

NIR Interoperability

NIR (Neuromorphic Intermediate Representation) enables cross-framework SNN model exchange between simulators and neuromorphic hardware platforms.

pip install "mlx-snn[nir]"

Export an mlx-snn model to NIR:

import mlx.nn as nn
import mlxsnn, nir

layers = [
    ('fc1', nn.Linear(784, 128)),
    ('lif1', mlxsnn.Leaky(beta=0.9)),
    ('fc2', nn.Linear(128, 10)),
    ('lif2', mlxsnn.Leaky(beta=0.9)),
]
graph = mlxsnn.export_to_nir(layers)
nir.write('model.nir', graph)

Import a NIR model into mlx-snn:

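import mlx.core as mx
import mlxsnn, nir

graph = nir.read('model.nir')
model = mlxsnn.import_from_nir(graph)
state = model.init_states(batch_size=32)
x = mx.random.uniform(shape=(32, 784))  # example input batch (shape assumed to match the 784-input model above)
out, state = model(x, state)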

Supported conversions: nn.Linear <-> nir.Affine/nir.Linear, Leaky <-> nir.LIF, IF <-> nir.IF, Synaptic <-> nir.CubaLIF.
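A Leaky neuron is parameterized by a per-step decay `beta`. NIR's LIF node is instead a continuous-time model with a membrane time constant tau; the NIR specification writes it as tau * dv/dt = (v_leak - v) + r * I. Converting between the two therefore needs a timestep dt. The sketch below shows two standard discretizations. Which one mlx-snn applies, and what dt it uses, is not stated here, so both the choice and the 1 ms value are assumptions. The function names are hypothetical.

```python
import math


def tau_from_beta_exact(beta: float, dt: float) -> float:
    """Exact exponential decay: beta = exp(-dt / tau)  =>  tau = -dt / ln(beta)."""
    return -dt / math.log(beta)


def tau_from_beta_euler(beta: float, dt: float) -> float:
    """Forward-Euler decay: beta = 1 - dt / tau  =>  tau = dt / (1 - beta)."""
    return dt / (1.0 - beta)


def beta_from_tau_exact(tau: float, dt: float) -> float:
    """Inverse of tau_from_beta_exact."""
    return math.exp(-dt / tau)


dt = 1e-3     # 1 ms simulation step (assumed)
beta = 0.9
print(tau_from_beta_exact(beta, dt))   # ~9.49e-3 s
print(tau_from_beta_euler(beta, dt))   # ~1.0e-2 s
```

The two mappings agree as `beta` approaches 1 and drift apart as decay gets stronger. At `beta=0.9` they differ by about 5%, so an exporter and importer that assume different conventions will silently change the effective time constant.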

Benchmark Highlights

Experiments on MNIST (784-128-10 SNN, 25 timesteps, 5 seeds) compare mlx-snn on an Apple M3 Max with snnTorch on an NVIDIA V100. "acc." is accuracy; "time" is training time per epoch:

Configuration                mlx-snn acc.  snnTorch acc.  mlx-snn time  snnTorch time
Leaky (LIF)                  96.3%         97.3%          5.7 s/epoch   20.9 s/epoch
Synaptic                     94.4%         95.8%          6.1 s/epoch   25.2 s/epoch
RLeaky (V=0.1, learn)        91.6%         68.1%          6.8 s/epoch   25.7 s/epoch
RSynaptic (V=0.1, learn)     89.0%         52.2%          7.3 s/epoch   29.2 s/epoch
Fast Sigmoid surrogate       96.3%         96.7%          5.7 s/epoch   20.9 s/epoch
Triangular (Tent) surrogate  86.0%         50.8%          10.9 s/epoch  20.9 s/epoch

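Across these runs, mlx-snn on the M3 Max trains about 3.7-4.1x faster per epoch than snnTorch on the V100. This holds in every configuration except the triangular surrogate, where the speedup is about 1.9x. For Leaky, Synaptic, and the fast-sigmoid surrogate, accuracy stays within 1.4 percentage points of snnTorch. With learnable recurrent weights (V=0.1), RLeaky and RSynaptic reach 91.6% and 89.0%, versus 68.1% and 52.2% for snnTorch's default configurations.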
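The speedup figures follow directly from the per-epoch times in the benchmark table. The snippet below recomputes them by dividing snnTorch seconds by mlx-snn seconds; the numbers are copied from the table.

```python
# Per-epoch training time in seconds, copied from the benchmark table:
# (mlx-snn on M3 Max, snnTorch on V100)
EPOCH_TIMES = {
    "Leaky (LIF)": (5.7, 20.9),
    "Synaptic": (6.1, 25.2),
    "RLeaky (V=0.1, learn)": (6.8, 25.7),
    "RSynaptic (V=0.1, learn)": (7.3, 29.2),
    "Fast Sigmoid surrogate": (5.7, 20.9),
    "Triangular (Tent) surrogate": (10.9, 20.9),
}

# Speedup = snnTorch time / mlx-snn time for each configuration.
speedups = {name: snn_t / mlx_t for name, (mlx_t, snn_t) in EPOCH_TIMES.items()}
for name, s in speedups.items():
    print(f"{name:30s} {s:.2f}x")
```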

For full results, see our benchmarking paper and the experiments/ directory.

Migrating from snnTorch

mlx-snn is designed to feel familiar to snnTorch users:

# snnTorch                          # mlx-snn
import snntorch as snn              import mlxsnn
lif = snn.Leaky(beta=0.9)          lif = mlxsnn.Leaky(beta=0.9)
spk, mem = lif(x, mem)             spk, state = lif(x, state)
                                    # state["mem"] == mem

Key differences:

  • State is a dict, not separate tensors — plays well with MLX functional transforms
  • No global hidden state — state is always explicit (pass in, get out)
  • MLX arrays instead of PyTorch tensors — use mx.array, not torch.Tensor
  • Surrogate gradients use the STE pattern with mx.stop_gradient
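The straight-through pattern computes spike = s(u) + stop_gradient(H(u) - s(u)), where H is the Heaviside step and s is a smooth surrogate. The forward value therefore equals H(u), while the gradient flows only through s(u). The pure-Python sketch below demonstrates this with a tiny forward-mode dual-number class. `Dual`, `stop_gradient`, and the sigmoid slope `k` here are illustrative stand-ins and are not MLX or mlx-snn APIs. In MLX the real primitive is `mx.stop_gradient`.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A value plus its derivative with respect to the input (forward-mode AD)."""
    val: float
    grad: float

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.val + other.val, self.grad + other.grad)

    def __sub__(self, other: "Dual") -> "Dual":
        return Dual(self.val - other.val, self.grad - other.grad)


def stop_gradient(d: Dual) -> Dual:
    return Dual(d.val, 0.0)  # keep the value, block the derivative


def heaviside(d: Dual) -> Dual:
    return Dual(1.0 if d.val > 0.0 else 0.0, 0.0)  # true derivative is 0 almost everywhere


def sigmoid(d: Dual, k: float = 5.0) -> Dual:
    z = k * d.val
    if z >= 0:
        s = 1.0 / (1.0 + math.exp(-z))
    else:
        e = math.exp(z)
        s = e / (1.0 + e)
    return Dual(s, k * s * (1.0 - s) * d.grad)  # chain rule


def spike(u: Dual) -> Dual:
    surrogate = sigmoid(u)
    return surrogate + stop_gradient(heaviside(u) - surrogate)


u = Dual(0.2, 1.0)         # u = mem - threshold, seeded with du/du = 1
out = spike(u)
print(out.val, out.grad)   # value is (up to rounding) 1.0; gradient is the sigmoid slope
```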

Project Structure

mlxsnn/
├── neurons/       # SpikingNeuron base, Leaky, IF, Izhikevich, ALIF, Synaptic, Alpha, RLeaky, RSynaptic
├── surrogate/     # fast_sigmoid, arctan, sigmoid, triangular, straight_through, custom
├── encoding/      # rate, latency, delta, EEG encoder
├── functional/    # Stateless pure functions (lif_step, if_step, fire, reset)
├── training/      # BPTT helpers, loss functions
└── nir_*.py       # NIR export/import utilities

Roadmap

  • v0.1 — Core neurons (LIF, IF), surrogate gradients, rate/latency encoding
  • v0.2 — Extended neurons (Izhikevich, ALIF, Synaptic, Alpha), EEG encoder, delta encoding
  • v0.3 — NIR interoperability (export/import)
  • v0.4 — Recurrent neurons (RLeaky, RSynaptic), learnable thresholds, expanded surrogates and losses
  • v0.5 — Liquid State Machine, reservoir topology, mx.compile optimization
  • v1.0 — Full documentation, comprehensive benchmarks, neuromorphic dataset loaders

Publications

Citation

If you use mlx-snn in your research, please cite:

@misc{qin2026mlxsnn,
  title         = {mlx-snn: Spiking Neural Networks on Apple Silicon via {MLX}},
  author        = {Jiahao Qin},
  year          = {2026},
  eprint        = {2603.03529},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2603.03529}
}

Contributing

Contributions are welcome! Please open an issue or pull request on GitHub.

License

GPL-3.0



Download files


Source Distribution

mlx_snn-0.5.0.tar.gz (91.4 kB)

Uploaded Source

Built Distribution


mlx_snn-0.5.0-py3-none-any.whl (100.1 kB)

Uploaded Python 3

File details

Details for the file mlx_snn-0.5.0.tar.gz.

File metadata

  • Download URL: mlx_snn-0.5.0.tar.gz
  • Upload date:
  • Size: 91.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.18

File hashes

Hashes for mlx_snn-0.5.0.tar.gz
Algorithm Hash digest
SHA256 5442bd34c1c7304e84e10e69be62c036ab909de12cddf931f0f7ad4d2f1ee729
MD5 fadfae604f1a974e59c285f3630642d7
BLAKE2b-256 a05cf02a168bd2be2810e31409597a2e87f6b7e05d3aa977f845a6979d9fd7de


File details

Details for the file mlx_snn-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: mlx_snn-0.5.0-py3-none-any.whl
  • Upload date:
  • Size: 100.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.18

File hashes

Hashes for mlx_snn-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5534fb830ad79cec8be352316c8dfea01cc9738871f1964bf1b2c55c117ec020
MD5 69040abe6d264d3993003cb5cdd848d6
BLAKE2b-256 224534bee1f2496d7bd6d7966e66a16846afa14117a5c59aa1936eb7c635960a

