Neural robust state space models in PyTorch
neural-ssm
PyTorch implementations of state-space models (SSMs) with a built-in robustness certificate in the form of a tunable L2-bound. This is obtained by using:
- free parametrizations of L2-bounded linear dynamical systems
- Lipschitz-bounded static nonlinearities
The mathematical details are in:
- Free Parametrization of L2-bounded State Space Models
https://arxiv.org/abs/2503.23818
Installation
Install from PyPI:
pip install neural-ssm
Install the latest GitHub version:
pip install git+https://github.com/LeoMassai/neural-ssm.git
Architecture and robustness recipe
Let's look at what an SSM is in more detail.
Reading the figure from left to right:
- Input is projected by an encoder.
- A stack of SSL blocks is applied.
- Each block combines:
  - a dynamic core with one of several state-space parametrizations (`lru`, `l2n`, or `tv`)
  - a static nonlinearity (`LGLU`, `LMLP`, `GLU`, ...)
  - a residual connection.
- Output is projected by a decoder.
Main message: `l2n` and `tv`, when used with a Lipschitz-bounded nonlinearity such as `LGLU`, enable robust deep SSMs with a prescribed L2 bound.
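The L2 bound is an input-output energy guarantee: for every input sequence, the output's L2 norm is at most `gamma` times the input's. As a minimal illustration of what such a certificate means (a toy scalar LTI system, not the library's parametrization), the recurrence x_t = a·x_{t-1} + b·u_t, y_t = c·x_t with 0 < a < 1 has L2 gain at most |c·b| / (1 − a), and no random input exceeds it:

```python
import math
import random

def lti_response(a, b, c, u):
    """Scalar LTI system: x_t = a*x_{t-1} + b*u_t, y_t = c*x_t, x_0 = 0."""
    x, y = 0.0, []
    for u_t in u:
        x = a * x + b * u_t
        y.append(c * x)
    return y

def l2_norm(seq):
    return math.sqrt(sum(v * v for v in seq))

a, b, c = 0.9, 1.0, 0.5
gamma = abs(c * b) / (1.0 - abs(a))  # analytic L2-gain bound (5.0 here)

random.seed(0)
ratios = []
for _ in range(100):
    u = [random.gauss(0.0, 1.0) for _ in range(500)]
    y = lti_response(a, b, c, u)
    ratios.append(l2_norm(y) / l2_norm(u))

# No input sequence drives output energy past gamma times input energy.
assert max(ratios) <= gamma
```

The library's parametrizations make such a bound hold by construction for every value of the free parameters, so it survives training.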
Main parametrizations
- `lru`: inspired by "Resurrecting Linear Recurrences"; parametrizes all and only stable LTI systems.
- `l2n`: free parametrization of all and only LTI systems with a prescribed L2 bound.
- `tv`: free parametrization of a time-varying selective recurrent unit with a prescribed L2 bound (paper in preparation).
All these parametrizations support both forward execution modes:
- parallel scan via `mode="scan"` (typically very fast for long sequences)
- standard recurrence loop via `mode="loop"`
You select the mode at call time, e.g. `model(u, mode="scan")` or `model(u, mode="loop")`.
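Both modes compute the same recurrence; the scan mode exploits the fact that composing the affine maps x → a·x + b is associative, so prefixes can be combined tree-wise in parallel. Here is a minimal pure-Python sketch of that idea (illustrative only, not the library's `scan_utils` implementation):

```python
import random

def combine(p, q):
    # Associative composition of affine maps: apply p (x -> a1*x + b1),
    # then q (x -> a2*x + b2).
    (a1, b1), (a2, b2) = p, q
    return (a1 * a2, a2 * b1 + b2)

def prefix_scan(elems):
    # Inclusive prefix scan. The divide-and-conquer structure is what a
    # parallel scan exploits; this sketch just runs it sequentially.
    if len(elems) == 1:
        return list(elems)
    mid = len(elems) // 2
    left, right = prefix_scan(elems[:mid]), prefix_scan(elems[mid:])
    carry = left[-1]
    return left + [combine(carry, r) for r in right]

def loop_recurrence(a_seq, u_seq):
    # Plain sequential loop: x_t = a_t * x_{t-1} + u_t, starting from x_0 = 0.
    x, xs = 0.0, []
    for a_t, u_t in zip(a_seq, u_seq):
        x = a_t * x + u_t
        xs.append(x)
    return xs

random.seed(0)
T = 64
a_seq = [random.uniform(0.5, 0.99) for _ in range(T)]
u_seq = [random.gauss(0.0, 1.0) for _ in range(T)]

x_loop = loop_recurrence(a_seq, u_seq)
# With x_0 = 0, the state at step t is the additive part of the t-th prefix map.
x_scan = [b for (_, b) in prefix_scan(list(zip(a_seq, u_seq)))]

assert all(abs(p - q) < 1e-9 for p, q in zip(x_loop, x_scan))
```

Because the operator is associative, a real implementation can evaluate the prefix tree across the sequence dimension in O(log L) parallel steps instead of L sequential ones.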
Main SSM parameters
- `d_input`: input feature dimension.
- `d_output`: output feature dimension.
- `d_model`: latent model dimension used inside each SSL block.
- `d_state`: internal recurrent state dimension.
- `n_layers`: number of stacked SSL blocks.
- `param`: parametrization of the recurrent unit (`lru`, `l2n`, `tv`, ...).
- `ff`: static nonlinearity type (`GLU`, `MLP`, `LMLP`, `LGLU`, `TLIP`).
- `gamma`: desired L2 bound of the overall SSM. If `gamma=None`, it is trainable.
Where each component is in the code
- End-to-end wrapper (encoder, stack, decoder): `DeepSSM` in `src/neural_ssm/ssm/lru.py`
- Repeated SSM block (dynamic core + nonlinearity + residual): `SSL` in `src/neural_ssm/ssm/lru.py`
- Dynamic cores:
  - `lru` -> `LRU` in `src/neural_ssm/ssm/lru.py`
  - `l2n` -> `Block2x2DenseL2SSM` in `src/neural_ssm/ssm/lru.py`
  - `tv` -> `RobustMambaDiagSSM` in `src/neural_ssm/ssm/mamba.py`
- Static nonlinearities:
  - `GLU`, `MLP` in `src/neural_ssm/static_layers/generic_layers.py`
  - `LGLU`, `LMLP`, `TLIP` in `src/neural_ssm/static_layers/lipschitz_mlps.py`
- Parallel scan utilities: `src/neural_ssm/ssm/scan_utils.py`
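Why the Lipschitz-bounded layers matter for the certificate: the incremental L2 gain of a composition is at most the product of the blocks' Lipschitz constants, and a residual connection x + f(x) adds at most 1 to the constant of f. A quick numeric check with a stand-in 0.5-Lipschitz nonlinearity (illustrative only, not the `LGLU` implementation):

```python
import math
import random

LIP_F = 0.5  # Lipschitz constant of the nonlinearity below (tanh is 1-Lipschitz)

def residual_block(x):
    # Residual connection around a 0.5-Lipschitz static nonlinearity.
    return x + LIP_F * math.tanh(x)

random.seed(0)
worst = 0.0
for _ in range(10_000):
    x, y = random.uniform(-5.0, 5.0), random.uniform(-5.0, 5.0)
    if x != y:
        worst = max(worst, abs(residual_block(x) - residual_block(y)) / abs(x - y))

# The residual connection adds at most 1 to the nonlinearity's constant.
assert worst <= 1.0 + LIP_F
```

Multiplying such per-block bounds with the L2 bounds of the dynamic cores is what lets the stack guarantee an overall gain of `gamma`.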
Quick tutorial
For a complete, runnable training example on a nonlinear benchmark dataset, see:
Test_files/Tutorial_DeepSSM.py
Tensor shapes and forward outputs
- Input tensor shape is `u: (B, L, d_input)`, where `B` = batch size, `L` = sequence length, `d_input` = input dimension.
- Output tensor shape is `y: (B, L, d_output)`.
- `DeepSSM` returns two objects:
  - `y`: the model output sequence
  - `state`: a list of recurrent states (one tensor per SSL block), useful for stateful calls.
State initialization in forward:
- You can pass `state=` as a list with one initial state tensor for each SSL block.
- If `state` is not provided (`state=None`), internal recurrent states are initialized to zero.
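Returning the per-block states makes chunked, stateful processing exact: feeding a sequence in pieces while carrying the returned state gives the same outputs as one full pass. A toy scalar recurrence standing in for a block's state update (illustrative, not the `DeepSSM` code) shows the pattern:

```python
def run(u, x0=0.0, a=0.9):
    """Toy single-state recurrence standing in for one SSL block.

    Returns (output sequence, final state), mirroring DeepSSM's (y, state).
    """
    x, ys = x0, []
    for u_t in u:
        x = a * x + u_t
        ys.append(x)
    return ys, x

u = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5]

y_full, _ = run(u)          # state not provided -> zero-initialized
y1, s = run(u[:3])          # first chunk, zero initial state
y2, _ = run(u[3:], x0=s)    # second chunk resumes from the returned state

assert y_full == y1 + y2    # chunked processing matches the full pass
```

The same pattern applies to `DeepSSM`, except that `state` is a list with one tensor per SSL block rather than a single scalar.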
How to create and call a Deep SSM
Building and using the SSM is pretty easy:
```python
import torch
from neural_ssm import DeepSSM

model = DeepSSM(
    d_input=1,
    d_output=1,
    d_model=16,
    d_state=16,
    n_layers=4,
    param="tv",
    ff="LGLU",
    gamma=2.0,
)

u = torch.randn(8, 200, 1)        # (B, L, d_input)
y, state = model(u, mode="scan")  # zero-initialized internal states

# Stateful call: pass one state per SSL block
u_next = torch.randn(8, 200, 1)
y_next, state = model(u_next, state=state, mode="scan")
```
Top-level API
- `DeepSSM`, `SSMConfig`
- `LRU`, `L2RU`, `lruz`, `PureLRUR`, `SimpleRNN`
- static layers re-exported in `neural_ssm.layers`
Examples
Example and experiment scripts are available in `Test_files/`, including:
- `Test_files/Tutorial_DeepSSM.py`: minimal end-to-end DeepSSM training tutorial.
Citation
If you use this repository in research, please cite:
Free Parametrization of L2-bounded State Space Models
https://arxiv.org/abs/2503.23818
File details
Details for the file neural_ssm-0.30.tar.gz.
File metadata
- Download URL: neural_ssm-0.30.tar.gz
- Upload date:
- Size: 43.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `42582dc1902a6340e85d8a629f0bc90ebde3c8d691bac2b31fa9afec74d83917` |
| MD5 | `3fcd06083b445ffb8f09ad4f62cd48f0` |
| BLAKE2b-256 | `d8d1b6ac372a058f9b289e31e79ded97fe4c624b4f0b5cbf4c7993c11aa68715` |
File details
Details for the file neural_ssm-0.30-py3-none-any.whl.
File metadata
- Download URL: neural_ssm-0.30-py3-none-any.whl
- Upload date:
- Size: 44.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `bf5ff49c93ac80b705978089a1001fb77d44dde0f1da7bbce2b3c633f2b28f3e` |
| MD5 | `33aa86ccac196387d9207e7f580297c0` |
| BLAKE2b-256 | `224ad402c124e045671f4d784ff9385732671960ecc2eea2ace2961d94b37088` |