Neural robust state space models in PyTorch

Project description

neural-ssm

PyTorch implementations of robust neural state-space models (SSMs), centered on:

  • free parametrizations of L2-bounded linear dynamical systems
  • Lipschitz-bounded static nonlinearities for robust deep SSM design

The mathematical details are in the paper "Free Parametrization of L2-bounded State Space Models" (https://arxiv.org/abs/2503.23818).

Installation

Install from pip:

pip install neural-ssm

Install the latest GitHub version:

pip install git+https://github.com/LeoMassai/neural-ssm.git

Architecture and robustness recipe

L2 DeepSSM architecture

Reading the architecture from left to right:

  1. Input is projected by an encoder.
  2. A stack of SSL blocks is applied.
  3. Each block combines:
    • a dynamic core (lru, l2n, or tv)
    • a static nonlinearity (LGLU, LMLP, GLU, ...)
    • a residual connection.
  4. Output is projected by a decoder.

Main message: l2n and tv, when used with a Lipschitz-bounded nonlinearity such as LGLU, enable robust deep SSMs with a prescribed L2 bound.
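The recipe works because Lipschitz bounds compose: if each stage has a known bound, the cascade is bounded by their product, and a residual connection x + f(x) is bounded by 1 + Lip(f). A toy plain-Python sketch of this fact (illustrative only, not library code; the scalar functions below are made up for the demo):

```python
# Toy illustration: Lipschitz bounds compose.
# If f has Lipschitz constant Lf and g has Lg, then g(f(x)) has at most Lg*Lf,
# and a residual connection x + f(x) has at most 1 + Lf.
import math
import random

def lip_estimate(fn, n=10000, lo=-5.0, hi=5.0):
    """Numerically estimate a scalar function's Lipschitz constant
    by sampling secant slopes between random point pairs."""
    rng = random.Random(0)
    best = 0.0
    for _ in range(n):
        x, y = rng.uniform(lo, hi), rng.uniform(lo, hi)
        if x != y:
            best = max(best, abs(fn(x) - fn(y)) / abs(x - y))
    return best

f = lambda x: 0.5 * math.tanh(x)   # Lipschitz constant 0.5
g = lambda x: 2.0 * math.sin(x)    # Lipschitz constant 2.0

comp = lip_estimate(lambda x: g(f(x)))    # bounded by 2.0 * 0.5 = 1.0
resid = lip_estimate(lambda x: x + f(x))  # bounded by 1 + 0.5 = 1.5
```

This is why bounding each SSL block (dynamic core, nonlinearity, residual) yields a certified bound for the whole stack.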

Main parametrizations

  • lru: inspired by "Resurrecting Linear Recurrences"; efficient and stable linear recurrent backbone.
  • l2n: SSM with prescribed L2 bound via free parametrization.
  • tv: time-varying selective SSM with a prescribed L2 bound (paper in preparation).

All these parametrizations support both forward execution modes:

  • parallel scan via mode="scan"
  • standard recurrence loop via mode="loop"

You select the mode at call time, e.g. model(u, mode="scan") or model(u, mode="loop").
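Both modes compute the same recurrence; the scan mode rewrites each step as an associative operation so it can be evaluated tree-parallel. A minimal stdlib sketch of the idea on a scalar linear recurrence x_t = a*x_{t-1} + b*u_t (illustrative only, not the library's scan implementation):

```python
# "loop" mode: the plain sequential recurrence.
def run_loop(a, b, u, x0=0.0):
    xs, x = [], x0
    for ut in u:
        x = a * x + b * ut
        xs.append(x)
    return xs

# "scan" mode idea: each step is the affine map x -> A*x + B with (A, B) = (a, b*ut).
# Composing two steps is associative:
#   (A1, B1) then (A2, B2)  ==  (A2*A1, A2*B1 + B2)
# which is what allows a parallel prefix scan over the sequence.
def run_scan(a, b, u, x0=0.0):
    def combine(s1, s2):
        A1, B1 = s1
        A2, B2 = s2
        return (A2 * A1, A2 * B1 + B2)
    xs, acc = [], (1.0, 0.0)   # identity element of the affine-map monoid
    for ut in u:               # a real scan would evaluate this tree-parallel
        acc = combine(acc, (a, b * ut))
        A, B = acc
        xs.append(A * x0 + B)
    return xs

u = [0.5, -1.0, 2.0, 0.25]
loop_out = run_loop(0.9, 1.0, u)
scan_out = run_scan(0.9, 1.0, u)   # identical to loop_out
```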

Main SSM parameters

  • d_input: input feature dimension.
  • d_output: output feature dimension.
  • d_model: latent model dimension used inside each SSL block.
  • d_state: internal recurrent state dimension.
  • n_layers: number of stacked SSL blocks.
  • param: parametrization of the recurrent unit (lru, l2n, tv, ...).
  • ff: static nonlinearity type (GLU, MLP, LMLP, LGLU, TLIP).
  • gamma: desired L2 bound of the overall SSM. If gamma=None, the bound is trainable.
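What an L2 bound gamma means, illustrated on a toy scalar linear system whose exact gain is known in closed form (plain Python, not the library): for x_t = a*x_{t-1} + b*u_t, y_t = c*x_t with 0 < a < 1, the L2 gain is |c*b| / (1 - a), and output energy never exceeds gamma^2 times input energy.

```python
# Energy interpretation of an L2 (gain) bound on a scalar system.
import random

a, b, c = 0.8, 1.0, 1.0
gamma = abs(c * b) / (1 - a)   # exact L2 gain of this system (here 5.0)

rng = random.Random(0)
u = [rng.gauss(0.0, 1.0) for _ in range(500)]

x, y = 0.0, []
for ut in u:
    x = a * x + b * ut
    y.append(c * x)

energy_u = sum(v * v for v in u)
energy_y = sum(v * v for v in y)
assert energy_y <= gamma**2 * energy_u   # holds for every input signal
```

The parametrizations l2n and tv guarantee this inequality by construction for the chosen gamma, rather than checking it after training.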

Where each component is in the code

  • End-to-end wrapper (encoder, stack, decoder):
    DeepSSM in src/neural_ssm/ssm/lru.py
  • Repeated SSM block (dynamic core + nonlinearity + residual):
    SSL in src/neural_ssm/ssm/lru.py
  • Dynamic cores:
    • lru -> LRU in src/neural_ssm/ssm/lru.py
    • l2n -> Block2x2DenseL2SSM in src/neural_ssm/ssm/lru.py
    • tv -> RobustMambaDiagSSM in src/neural_ssm/ssm/mamba.py
  • Static nonlinearities:
    • GLU, MLP in src/neural_ssm/static_layers/generic_layers.py
    • LGLU, LMLP, TLIP in src/neural_ssm/static_layers/lipschitz_mlps.py
  • Parallel scan utilities:
    src/neural_ssm/ssm/scan_utils.py

Quick tutorial

Tensor shapes and forward outputs

  • Input tensor shape is u: (B, L, d_input) where:
    • B = batch size
    • L = sequence length
    • d_input = input dimension
  • Output tensor shape is y: (B, L, d_output).
  • DeepSSM returns two objects:
    • y: the model output sequence
    • state: a list of recurrent states (one tensor per SSL block), useful for stateful calls.

State initialization in forward:

  • You can pass state= as a list with one initial state tensor for each SSL block.
  • If state is not provided (state=None), internal recurrent states are initialized to zero.
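Carrying state between calls lets you process a long sequence in chunks. A toy scalar-recurrence sketch of the pattern (plain Python, not library code; DeepSSM's per-block state list works analogously):

```python
# Stateful chunked processing: running a recurrence over a full sequence
# equals running it over chunks while carrying each chunk's final state
# into the next call.
def run(u, a=0.9, b=1.0, state=None):
    x = 0.0 if state is None else state   # state=None -> zero initialization
    ys = []
    for ut in u:
        x = a * x + b * ut
        ys.append(x)
    return ys, x                          # outputs and final state

u = [1.0, -0.5, 2.0, 0.25, -1.0, 0.75]

y_full, _ = run(u)              # one pass over everything
y1, s = run(u[:3])              # first chunk, zero-initialized
y2, _ = run(u[3:], state=s)     # second chunk resumes from the carried state
# y1 + y2 matches y_full exactly
```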

How to create and call a Deep SSM

Building and using the SSM is straightforward:

import torch
from neural_ssm import DeepSSM

model = DeepSSM(
    d_input=1,
    d_output=1,
    d_model=16,
    d_state=16,
    n_layers=4,
    param="tv",
    ff="LGLU",
    gamma=2.0,
)

u = torch.randn(8, 200, 1)               # (B, L, d_input)
y, state = model(u, mode="scan")         # zero-initialized internal states

# Stateful call: pass one state per SSL block
u_next = torch.randn(8, 200, 1)
y_next, state = model(u_next, state=state, mode="scan")

Top-level API

  • DeepSSM, SSMConfig
  • LRU, L2RU, lruz, PureLRUR, SimpleRNN
  • static layers re-exported in neural_ssm.layers

Examples

Example and experiment scripts are available in Test_files/.

Citation

If you use this repository in research, please cite:

Free Parametrization of L2-bounded State Space Models
https://arxiv.org/abs/2503.23818
