# neural-ssm

Robust neural state-space models in PyTorch.
PyTorch implementations of robust neural state-space models (SSMs), centered on:
- free parametrizations of L2-bounded linear dynamical systems
- Lipschitz-bounded static nonlinearities for robust deep SSM design
The mathematical details are in:
- Free Parametrization of L2-bounded State Space Models
https://arxiv.org/abs/2503.23818
## Installation

Install from PyPI:

```bash
pip install neural-ssm
```

Or install the latest version from GitHub:

```bash
pip install git+https://github.com/LeoMassai/neural-ssm.git
```
## Architecture and robustness recipe

The architecture (shown in the repository figure, read left to right) is:

- Input is projected by an encoder.
- A stack of SSL blocks is applied. Each block combines:
  - a dynamic core (`lru`, `l2n`, or `tv`),
  - a static nonlinearity (`LGLU`, `LMLP`, `GLU`, ...),
  - a residual connection.
- Output is projected by a decoder.

Main message: `l2n` and `tv`, when used with a Lipschitz-bounded nonlinearity such as `LGLU`, enable robust deep SSMs with a prescribed L2 bound.
## Main parametrizations

- `lru`: inspired by "Resurrecting Linear Recurrences"; an efficient and stable linear recurrent backbone.
- `l2n`: SSM with a prescribed L2 bound via free parametrization.
- `tv`: time-varying selective SSM with a prescribed L2 bound (paper in preparation).
All these parametrizations support two forward execution modes:

- parallel scan via `mode="scan"`
- standard recurrence loop via `mode="loop"`

You select the mode at call time, e.g. `model(u, mode="scan")` or `model(u, mode="loop")`.
## Main SSM parameters

- `d_input`: input feature dimension.
- `d_output`: output feature dimension.
- `d_model`: latent model dimension used inside each SSL block.
- `d_state`: internal recurrent state dimension.
- `n_layers`: number of stacked SSL blocks.
- `param`: parametrization of the recurrent unit (`lru`, `l2n`, `tv`, ...).
- `ff`: static nonlinearity type (`GLU`, `MLP`, `LMLP`, `LGLU`, `TLIP`).
- `gamma`: desired L2 bound of the overall SSM. If `gamma=None`, the bound is trainable.
## Where each component is in the code

- End-to-end wrapper (encoder, stack, decoder): `DeepSSM` in `src/neural_ssm/ssm/lru.py`
- Repeated SSM block (dynamic core + nonlinearity + residual): `SSL` in `src/neural_ssm/ssm/lru.py`
- Dynamic cores:
  - `lru` -> `LRU` in `src/neural_ssm/ssm/lru.py`
  - `l2n` -> `Block2x2DenseL2SSM` in `src/neural_ssm/ssm/lru.py`
  - `tv` -> `RobustMambaDiagSSM` in `src/neural_ssm/ssm/mamba.py`
- Static nonlinearities:
  - `GLU`, `MLP` in `src/neural_ssm/static_layers/generic_layers.py`
  - `LGLU`, `LMLP`, `TLIP` in `src/neural_ssm/static_layers/lipschitz_mlps.py`
- Parallel scan utilities: `src/neural_ssm/ssm/scan_utils.py`
## Quick tutorial

### Tensor shapes and forward outputs

- Input tensor shape is `u: (B, L, d_input)`, where:
  - `B` = batch size
  - `L` = sequence length
  - `d_input` = input dimension
- Output tensor shape is `y: (B, L, d_output)`.
- `DeepSSM` returns two objects:
  - `y`: the model output sequence
  - `state`: a list of recurrent states (one tensor per SSL block), useful for stateful calls.

State initialization in `forward`:

- You can pass `state=` as a list with one initial state tensor for each SSL block.
- If `state` is not provided (`state=None`), internal recurrent states are initialized to zero.
### How to create and call a Deep SSM

Building and using the SSM is straightforward:

```python
import torch
from neural_ssm import DeepSSM

model = DeepSSM(
    d_input=1,
    d_output=1,
    d_model=16,
    d_state=16,
    n_layers=4,
    param="tv",
    ff="LGLU",
    gamma=2.0,
)

u = torch.randn(8, 200, 1)        # (B, L, d_input)
y, state = model(u, mode="scan")  # zero-initialized internal states

# Stateful call: pass one state per SSL block
u_next = torch.randn(8, 200, 1)
y_next, state = model(u_next, state=state, mode="scan")
```
## Top-level API

- `DeepSSM`, `SSMConfig`
- `LRU`, `L2RU`, `lruz`, `PureLRUR`, `SimpleRNN`
- static layers re-exported in `neural_ssm.layers`
## Examples

Example and experiment scripts are available in `Test_files/`.
## Citation

If you use this repository in research, please cite:

Free Parametrization of L2-bounded State Space Models
https://arxiv.org/abs/2503.23818