Neural robust state space models in PyTorch

Project description

neural-ssm

PyTorch implementations of state-space models (SSMs) with a built-in robustness certificate in the form of a tunable L2-bound. This is obtained by using:

  • free parametrizations of L2-bounded linear dynamical systems
  • Lipschitz-bounded static nonlinearities

The mathematical details are in the paper Free Parametrization of L2-bounded State Space Models (arXiv:2503.23818), cited below.

Installation

Install from PyPI with pip:

pip install neural-ssm

Install the latest GitHub version:

pip install git+https://github.com/LeoMassai/neural-ssm.git

Architecture and robustness recipe

Let's look at the architecture in more detail.

[Figure: L2 DeepSSM architecture]

Reading the figure from left to right:

  1. Input is projected by a linear encoder.
  2. A stack of SSL blocks is applied.
  3. Each block combines:
    • a dynamic recurrent core with different state-space parametrizations (lru, l2n, or tv)
    • a static nonlinearity (LGLU, LMLP, GLU, ...)
    • a residual connection.
  4. Output is projected by a linear decoder.

Main message: l2n and tv, when used with a Lipschitz-bounded nonlinearity such as LGLU, enable robust deep SSMs with prescribed L2 bound.
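The intuition behind the recipe is that the L2 gain of layers composed in series is bounded by the product of the per-layer gains, so certifying each block certifies the whole stack. A minimal numeric sketch of this bookkeeping (illustrative only, not the library's internals):

```python
# Illustrative only: the end-to-end L2 gain of a cascade is bounded by the
# product of the individual layer gains (sub-multiplicativity of gains).
def cascade_gain_bound(layer_gains):
    """Upper bound on the L2 gain of layers applied in series."""
    bound = 1.0
    for g in layer_gains:
        bound *= g
    return bound

# e.g. 4 blocks, each certified with gain <= 2 ** 0.25, compose to gamma <= 2
gains = [2 ** 0.25] * 4
print(cascade_gain_bound(gains))  # -> 2.0 (up to floating-point rounding)
```

This is why pairing an L2-bounded dynamic core with a Lipschitz-bounded (i.e. gain-bounded) static nonlinearity yields a prescribed bound for the full model.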

Main parametrizations

  • lru: inspired by "Resurrecting Linear Recurrences"; parametrizes all and only stable LTI systems.
  • l2n: free parametrization of all and only LTI systems with a prescribed L2 bound.
  • tv: free parametrization of a time-varying selective recurrent unit with prescribed L2 bound (paper in preparation).
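To give a flavour of what "free parametrization" means here: unconstrained real parameters are mapped so that every choice lands inside the admissible set, so no projection or constrained optimization is needed during training. A toy sketch in the spirit of the lru-style eigenvalue reparametrization (assumed form for illustration, not the library's exact code):

```python
import cmath
import math

def stable_eigenvalue(nu, theta):
    """Map unconstrained reals (nu, theta) to an eigenvalue with |lambda| < 1.

    The magnitude exp(-exp(nu)) lies in (0, 1) for every real nu, so any
    parameter values yield a stable pole: the parametrization is 'free'.
    """
    return math.exp(-math.exp(nu)) * cmath.exp(1j * theta)

lam = stable_eigenvalue(0.3, 1.2)
print(abs(lam) < 1.0)  # True for any real nu, theta
```

The l2n and tv parametrizations extend this idea from stability to a prescribed L2 bound.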

All these parametrizations support both forward execution modes:

  • parallel scan via mode="scan" (typically very fast for long sequences)
  • standard recurrence loop via mode="loop"

You select the mode at call time, e.g. model(u, mode="scan") or model(u, mode="loop").
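The two modes should produce the same outputs: the scan mode exploits the associativity of the affine recurrence x_k = a·x_{k-1} + u_k. A self-contained sketch of why a prefix scan reproduces the sequential loop (plain Python, not the library's implementation):

```python
from itertools import accumulate

def combine(p, q):
    # Composition of two affine maps x -> a*x + b; this operation is
    # associative, which is what makes a parallel prefix scan valid.
    a1, b1 = p
    a2, b2 = q
    return (a2 * a1, a2 * b1 + b2)

def scan_recurrence(a, inputs):
    # Prefix-scan over affine maps; each prefix applied to x0 = 0 gives x_k.
    return [b for _, b in accumulate([(a, u) for u in inputs], combine)]

def loop_recurrence(a, inputs):
    # Standard sequential recurrence x_k = a * x_{k-1} + u_k.
    x, out = 0.0, []
    for u in inputs:
        x = a * x + u
        out.append(x)
    return out

u = [1.0, 2.0, -0.5, 3.0]
print(scan_recurrence(0.9, u) == loop_recurrence(0.9, u))  # True
```

In a real parallel scan the `combine` steps are evaluated tree-wise across the sequence, which is where the speedup on long sequences comes from.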

Main SSM parameters

The SSM model is implemented by the class DeepSSM, which takes a number of input parameters. The most important ones are:

  • d_input: input feature dimension.
  • d_output: output feature dimension.
  • d_model: latent model dimension used inside each SSL block.
  • d_state: internal recurrent state dimension.
  • n_layers: number of stacked SSL blocks.
  • param: parametrization of the recurrent unit (lru, l2n, tv, ...).
  • ff: static nonlinearity type, same for each SSL block (GLU, MLP, LMLP, LGLU, TLIP).
  • gamma: desired L2 bound of the overall SSM. If gamma=None, the bound is trainable.

Where each component is in the code

  • End-to-end wrapper (encoder, stack, decoder):
    DeepSSM in src/neural_ssm/ssm/lru.py
  • Repeated SSM block (dynamic core + nonlinearity + residual):
    SSL in src/neural_ssm/ssm/lru.py
  • Dynamic cores:
    • lru -> LRU in src/neural_ssm/ssm/lru.py
    • l2n -> Block2x2DenseL2SSM in src/neural_ssm/ssm/lru.py
    • tv -> RobustMambaDiagSSM in src/neural_ssm/ssm/mamba.py
  • Static nonlinearities:
    • GLU, MLP in src/neural_ssm/static_layers/generic_layers.py
    • LGLU, LMLP, TLIP in src/neural_ssm/static_layers/lipschitz_mlps.py
  • Parallel scan utilities:
    src/neural_ssm/ssm/scan_utils.py

Quick tutorial

For a complete, runnable training example on a nonlinear benchmark dataset, see: Test_files/Tutorial_DeepSSM.py

For a minimal Long Range Arena (LRA) benchmark script (ListOps) using DeepSSM, see: scripts/run_lra_listops.py (requires pip install datasets). The script defaults to the Hugging Face dataset lra-benchmarks with config listops.

Tensor shapes and forward outputs

  • Input tensor shape is u: (B, L, d_input) where:
    • B = batch size
    • L = sequence length
    • d_input = input dimension
  • Output tensor shape is y: (B, L, d_output).
  • DeepSSM returns two objects:
    • y: the model output sequence
    • state: a list of recurrent states (one tensor per SSL block), useful for stateful calls.

State handling in forward (DeepSSM and recurrent cores):

  • state: optional explicit initial state (or state list for DeepSSM).
  • reset_state: defaults to True.
    • True: start from zero state for this call.
    • False: reuse internal state if state is not provided.
  • detach_state: defaults to True.
    • True: internal state is detached at the end of the call (no cross-call BPTT).
    • False: internal state keeps gradient history (for true BPTT across calls).
  • reset(): manually resets internal state to zero.
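
A toy model of these call semantics (plain Python scalar recurrence mirroring the flags described above, not the library's code; gradient bookkeeping for detach_state is omitted):

```python
class StatefulCell:
    """Minimal scalar recurrence illustrating state / reset_state / reset()."""

    def __init__(self, a=0.5):
        self.a = a
        self._state = 0.0  # internal state carried across calls

    def reset(self):
        # Manually reset the internal state to zero.
        self._state = 0.0

    def forward(self, inputs, state=None, reset_state=True):
        if state is not None:
            x = state            # explicit initial state wins
        elif reset_state:
            x = 0.0              # start from zero state for this call
        else:
            x = self._state      # reuse internal state from the last call
        for u in inputs:
            x = self.a * x + u
        self._state = x
        return x

cell = StatefulCell()
y1 = cell.forward([1.0, 1.0])                # fresh state
y2 = cell.forward([0.0], reset_state=False)  # continues from y1
print(y2 == 0.5 * y1)  # True
```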

How to create and call a Deep SSM

Building and using the SSM is pretty easy:

import torch
from neural_ssm import DeepSSM

model = DeepSSM(
    d_input=1,
    d_output=1,
    d_model=16,
    d_state=16,
    n_layers=4,
    param="tv",
    ff="LGLU",
    gamma=2.0,
)

u = torch.randn(8, 200, 1)               # (B, L, d_input)
y, state = model(u, mode="scan")         # reset_state=True by default

# Stateful call: pass one state per SSL block
u_next = torch.randn(8, 200, 1)
y_next, state = model(u_next, state=state, mode="scan", reset_state=False)

# Stateful call without passing explicit state (reuse internal state)
y_stream, state = model(u_next, mode="scan", reset_state=False)

# True BPTT across calls: keep state in graph
y_bptt, state = model(u_next, mode="scan", reset_state=False, detach_state=False)

Top-level API

  • DeepSSM, SSMConfig
  • LRU, L2RU, lruz
  • static layers re-exported in neural_ssm.layers

Examples

Example and experiment scripts are available in Test_files/, including:

  • Test_files/Tutorial_DeepSSM.py: minimal end-to-end DeepSSM training tutorial.

Citation

If you use this repository in research, please cite:

Free Parametrization of L2-bounded State Space Models
https://arxiv.org/abs/2503.23818
