
Spiking Neural Network library built natively on Apple MLX


mlx-snn

The first Spiking Neural Network library built natively on Apple MLX.

mlx-snn brings SNN research to Apple Silicon. It provides spiking neuron models, surrogate gradient training, and spike encoding, all implemented in MLX to take advantage of unified memory and lazy evaluation on M-series chips.


Installation

pip install mlx-snn

Requires Python 3.9+ and Apple Silicon (M1/M2/M3/M4).

Quick Start

import mlx.core as mx
import mlx.nn as nn
import mlxsnn

# Build a spiking network
fc = nn.Linear(784, 10)
lif = mlxsnn.Leaky(beta=0.95, threshold=1.0)

# Encode input as spike train and run over time
spikes_in = mlxsnn.rate_encode(mx.random.uniform(shape=(8, 784)), num_steps=25)
state = lif.init_state(batch_size=8, features=10)

for t in range(25):
    spk, state = lif(fc(spikes_in[t]), state)

print("Output membrane:", state["mem"].shape)  # (8, 10)

Features

Neuron Models

Model                 Status       Description
Leaky (LIF)           v0.1         Leaky Integrate-and-Fire with configurable decay
IF                    v0.1         Integrate-and-Fire (non-leaky)
Izhikevich            v0.2         2D dynamics with RS/IB/CH/FS presets
Adaptive LIF          v0.2         LIF with adaptive threshold
Synaptic              v0.2         Current-based dual-state LIF (synaptic current + membrane potential)
Alpha                 v0.2         Dual-exponential synaptic model
Liquid State Machine  coming soon  Reservoir computing with spiking neurons

Surrogate Gradients

All neuron models support differentiable training via surrogate gradients:

  • Fast Sigmoid — default, good balance of speed and accuracy
  • Arctan — smoother gradient landscape
  • Straight-Through Estimator — simplest, pass-through in a window
  • Custom — plug in any smooth approximation

Spike Encoding

Method              Status       Use Case
Rate (Poisson)      v0.1         Static images, general-purpose
Latency (TTFS)      v0.1         Energy-efficient, temporal coding
Delta Modulation    v0.2         Temporal signals, change detection
EEG Encoder         v0.2         EEG-to-spike with frequency band support
fMRI BOLD Encoder   coming soon  fMRI signal encoding with HRF handling

Training

  • BPTT forward pass helper (bptt_forward)
  • SNN loss functions: rate_coding_loss, membrane_loss, mse_count_loss
  • Works with standard MLX optimizers (mlx.optimizers.Adam, etc.)

NIR Interoperability

NIR (Neuromorphic Intermediate Representation) enables cross-framework SNN model exchange. mlx-snn is the first MLX-native framework in the NIR ecosystem.

pip install "mlx-snn[nir]"

Export an mlx-snn model to NIR:

import mlx.nn as nn
import mlxsnn, nir

layers = [
    ('fc1', nn.Linear(784, 128)),
    ('lif1', mlxsnn.Leaky(beta=0.9)),
    ('fc2', nn.Linear(128, 10)),
    ('lif2', mlxsnn.Leaky(beta=0.9)),
]
graph = mlxsnn.export_to_nir(layers)
nir.write('model.nir', graph)

Import a NIR model into mlx-snn:

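import mlx.core as mx
import mlxsnn, nir

graph = nir.read('model.nir')
model = mlxsnn.import_from_nir(graph)
state = model.init_states(batch_size=32)
x = mx.random.uniform(shape=(32, 784))  # example input: batch of 32, 784 features
out, state = model(x, state)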

Supported conversions: nn.Linear ↔ nir.Affine / nir.Linear, Leaky ↔ nir.LIF, IF ↔ nir.IF, Synaptic ↔ nir.CubaLIF.
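Exporting a discrete `Leaky` neuron to `nir.LIF` means expressing `beta` as continuous-time parameters. NIR's LIF node is defined by tau * dv/dt = (v_leak - v) + r * I; a forward-Euler step of size dt with v_leak = 0 gives v <- (1 - dt/tau) * v + (dt/tau) * r * I, which matches `beta * v + I` when tau = dt / (1 - beta) and r = tau / dt. The plain-Python sketch below encodes that mapping. To our understanding snnTorch's exporter uses this form with a fixed dt; the dt value and the helper names here are assumptions, not mlx-snn's documented behavior.

```python
def leaky_to_nir_lif(beta, threshold, dt=1e-4):
    """Map a discrete Leaky neuron (v <- beta*v + I) to NIR LIF parameters.

    NIR LIF: tau * dv/dt = (v_leak - v) + r * I. One Euler step of size dt with
    v_leak = 0 gives v <- (1 - dt/tau) * v + (dt/tau) * r * I, which equals
    beta*v + I when tau = dt / (1 - beta) and r = tau / dt.
    """
    if not 0.0 <= beta < 1.0:
        raise ValueError("beta must lie in [0, 1); beta = 1 is the IF case")
    tau = dt / (1.0 - beta)
    return {"tau": tau, "r": tau / dt, "v_leak": 0.0, "v_threshold": threshold}

def nir_lif_to_leaky(tau, r, dt=1e-4):
    """Inverse mapping: returns (beta, input_gain); the gain is 1 for the params above."""
    return 1.0 - dt / tau, dt * r / tau

params = leaky_to_nir_lif(beta=0.9, threshold=1.0)
beta, gain = nir_lif_to_leaky(params["tau"], params["r"])
print(params)
print(beta, gain)  # approximately 0.9 and 1.0
```

Because the round trip only depends on dt cancelling out, any dt works as long as export and import agree on it.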

Migrating from snnTorch

mlx-snn is designed to feel familiar to snnTorch users:

# snnTorch                          # mlx-snn
import snntorch as snn              import mlxsnn
lif = snn.Leaky(beta=0.9)          lif = mlxsnn.Leaky(beta=0.9)
spk, mem = lif(x, mem)             spk, state = lif(x, state)
                                    # state["mem"] == mem

Key differences:

  • State is a dict, not separate tensors — plays well with MLX functional transforms
  • No global hidden state — state is always explicit (pass in, get out)
  • MLX arrays instead of PyTorch tensors — use mx.array, not torch.Tensor
  • Surrogate gradients use the STE pattern with mx.stop_gradient — no autograd.Function needed

Architecture

mlxsnn/
├── neurons/       # SpikingNeuron base, Leaky, IF, Izhikevich, ALIF, Synaptic, Alpha
├── surrogate/     # fast_sigmoid, arctan, straight_through, custom
├── encoding/      # rate, latency, delta, EEG encoder
├── functional/    # Stateless pure functions (lif_step, if_step, fire, reset)
├── training/      # BPTT helpers, loss functions
└── utils/         # Visualization, metrics (coming soon)

Roadmap

  • v0.1 — LIF/IF neurons, surrogate gradients, rate/latency encoding, MNIST example
  • v0.2 — Izhikevich, ALIF, Synaptic, Alpha neurons, EEG encoder, delta encoding
  • v0.2.1 — Fix fast sigmoid surrogate to match snnTorch rational approximation (97%+ MNIST accuracy)
  • v0.3 — NIR interoperability (export/import), cross-framework SNN model exchange
  • v0.4 — Liquid State Machine, reservoir topology, mx.compile optimization
  • v1.0 — Full docs, benchmarks, JOSS paper, numerical validation vs snnTorch

Citation

If you use mlx-snn in your research, please cite:

@software{mlxsnn2025,
  title   = {mlx-snn: Spiking Neural Networks on Apple Silicon via MLX},
  author  = {Qin, Jiahao},
  year    = {2025},
  version = {0.3.0},
  url     = {https://github.com/D-ST-Sword/mlx-snn},
  note    = {https://pypi.org/project/mlx-snn/}
}

License

MIT


Download files


Source Distribution

mlx_snn-0.3.0.tar.gz (35.2 kB)

Uploaded Source

Built Distribution


mlx_snn-0.3.0-py3-none-any.whl (38.6 kB)

Uploaded Python 3

File details

Details for the file mlx_snn-0.3.0.tar.gz.

File metadata

  • Download URL: mlx_snn-0.3.0.tar.gz
  • Size: 35.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.18

File hashes

Hashes for mlx_snn-0.3.0.tar.gz
Algorithm    Hash digest
SHA256       f494371dabda1ad03ed7d80573511cea0bc7fa9254db6fc8f1af684014d64efd
MD5          1d6aaa330d5675b7e831a66b81f4feae
BLAKE2b-256  c5cfef6376be8ebf72b19a4aa26b66b35f1fcee885a5f6f49ecb8d55b5fa7a6b


File details

Details for the file mlx_snn-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: mlx_snn-0.3.0-py3-none-any.whl
  • Size: 38.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.18

File hashes

Hashes for mlx_snn-0.3.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       ebb670f764266fe24e0c35ebcf340c412615559c86492444a11e5df2efd5d5f9
MD5          5c3f696d06708f2f10ad314c835fcaa7
BLAKE2b-256  f47866ae7753dfff584b469bdf2ab2b19cf027e7f69e8a2acc6d8da1f01eb2a3

