
Computation-aware State-Space Models


arXiv

CASSM

[Figure: Lorenz dynamics]

Computation-aware state-space models for neural data.

CASSM provides PyTorch implementations of computation-aware filtering and smoothing models for high-dimensional neural time series. The package includes synthetic data utilities, model implementations, metrics, and a short tutorial notebook for training on Lorenz-generated spike trains.
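For readers new to state-space filtering, the textbook exact Kalman filter and Rauch-Tung-Striebel (RTS) smoother for a linear-Gaussian model are sketched below in NumPy. This is background only. It is not CASSM's computation-aware algorithm or API, and the function names (`kalman_filter`, `rts_smoother`), the linear-Gaussian model, and the simulated example are assumptions for illustration.

```python
import numpy as np


def kalman_filter(y, A, Q, H, R, m0, P0):
    """Exact Kalman filter for x_t = A x_{t-1} + q_t, y_t = H x_t + r_t.

    q_t ~ N(0, Q), r_t ~ N(0, R); (m0, P0) is the prior on the state before
    the first observation. Returns predicted and filtered moments per step.
    """
    T, d = len(y), len(m0)
    m_pred, P_pred = np.zeros((T, d)), np.zeros((T, d, d))
    m_filt, P_filt = np.zeros((T, d)), np.zeros((T, d, d))
    m, P = m0, P0
    for t in range(T):
        # Predict: push the previous posterior through the linear dynamics.
        m_p = A @ m
        P_p = A @ P @ A.T + Q
        # Update: condition on y_t via the Kalman gain K = P_p H^T S^{-1}.
        S = H @ P_p @ H.T + R
        K = np.linalg.solve(S, H @ P_p).T
        m = m_p + K @ (y[t] - H @ m_p)
        P = P_p - K @ S @ K.T
        m_pred[t], P_pred[t], m_filt[t], P_filt[t] = m_p, P_p, m, P
    return m_pred, P_pred, m_filt, P_filt


def rts_smoother(A, m_pred, P_pred, m_filt, P_filt):
    """Rauch-Tung-Striebel smoother: backward pass over the filtered moments."""
    m_s, P_s = m_filt.copy(), P_filt.copy()
    for t in range(len(m_filt) - 2, -1, -1):
        # Smoother gain G = P_filt[t] A^T P_pred[t+1]^{-1}.
        G = np.linalg.solve(P_pred[t + 1], A @ P_filt[t]).T
        m_s[t] = m_filt[t] + G @ (m_s[t + 1] - m_pred[t + 1])
        P_s[t] = P_filt[t] + G @ (P_s[t + 1] - P_pred[t + 1]) @ G.T
    return m_s, P_s


# Usage on a simulated 2-D rotating latent observed through 3 noisy channels.
rng = np.random.default_rng(0)
d, n, T = 2, 3, 50
A = np.array([[0.99, 0.1], [-0.1, 0.99]])
Q, R = 0.01 * np.eye(d), 0.1 * np.eye(n)
H = rng.normal(size=(n, d))
x, ys = np.zeros(d), []
for _ in range(T):
    x = A @ x + rng.multivariate_normal(np.zeros(d), Q)
    ys.append(H @ x + rng.multivariate_normal(np.zeros(n), R))
m_pred, P_pred, m_filt, P_filt = kalman_filter(np.array(ys), A, Q, H, R, np.zeros(d), np.eye(d))
m_smooth, P_smooth = rts_smoother(A, m_pred, P_pred, m_filt, P_filt)
```

Every step above relies on dense solves whose cost grows cubically with the state and observation dimensions; that kind of bottleneck is, broadly, what motivates approximate, computation-aware variants for high-dimensional neural recordings.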

Installation

Install from PyPI:

pip install cassm

Install the package in editable mode from the repository root:

pip install -e .

For development and tests, install the optional development dependencies:

pip install -e ".[dev]"
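To confirm which release is active after installing, you can query the installed distribution metadata with the standard library. This small sketch assumes only that the distribution is named `cassm`, as in the pip commands above; `installed_version` is a hypothetical helper, not part of CASSM.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version("cassm")
    print(f"cassm {v}" if v else "cassm is not installed")
```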

Quick Start

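import torch
from torch.utils.data import DataLoader, TensorDataset

from cassm.datasets.synthetic_data import LorenzData
from cassm.models import CASSM
from cassm.utils.preprocessing import smooth_firing_rate

# Use a GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate synthetic Lorenz spike trains. Only the training data, validation
# data, and validation ground truth are used in this example.
train_data, valid_data, _, valid_truth, _, _ = LorenzData(
    num_inits=4,
    neurons=120,
    num_trials=4,
    device="cpu",
    seed=2,
)

# Smooth the spike trains into firing-rate estimates and move them to the device.
train_data = smooth_firing_rate(train_data.numpy()).to(device)
valid_data = smooth_firing_rate(valid_data.numpy()).to(device)
valid_truth = valid_truth.to(device)

# Batch the training data; pair each validation batch with its ground truth.
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
test_loader = DataLoader(TensorDataset(valid_data, valid_truth), batch_size=4)

# Size the model from the data: shape[-1] is the number of neurons and
# shape[1] the number of timesteps.
model = CASSM(
    projection_dim=10,
    nneurons=train_data.shape[-1],
    timesteps=train_data.shape[1],
    dt=0.01,
    device=device,
    save_model=False,
).to(device)

# Train for a few epochs, evaluating against the validation ground truth.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
model.train_model(
    epochs=3,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    valid_truth=valid_truth,
    clip_value=300,
)

# Predict firing rates and the accompanying noise estimate for the validation data.
predicted_rate, predicted_noise = model.predict_rate(valid_data)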
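The Quick Start passes spike counts through `smooth_firing_rate` before training. Its exact kernel and width are not documented here, so the sketch below is only a hedged illustration of the general idea: generic Gaussian smoothing along the time axis, assuming arrays shaped (trials, timesteps, neurons) as the Quick Start's use of `shape[1]` and `shape[-1]` suggests. `gaussian_smooth_rates` and `sigma_bins` are hypothetical names, not part of CASSM.

```python
import numpy as np


def gaussian_smooth_rates(spikes, sigma_bins=2.0):
    """Smooth spike counts along axis 1 (time) with a normalized Gaussian kernel.

    Hypothetical helper; spikes is assumed to be (trials, timesteps, neurons).
    """
    spikes = np.asarray(spikes, dtype=float)
    half = int(np.ceil(3 * sigma_bins))  # truncate the kernel at +/- 3 sigma
    offsets = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (offsets / sigma_bins) ** 2)
    kernel /= kernel.sum()  # unit mass, so smoothing preserves total counts
    if spikes.shape[1] < kernel.size:
        raise ValueError("need at least as many timesteps as kernel taps")
    # mode="same" keeps the original number of timesteps (zero-padded edges).
    return np.apply_along_axis(lambda s: np.convolve(s, kernel, mode="same"), 1, spikes)


counts = np.random.default_rng(0).poisson(0.2, size=(4, 100, 120))
rates = gaussian_smooth_rates(counts, sigma_bins=2.0)
```

Because the edges are zero-padded, estimates in roughly the first and last 3 sigma bins are biased low; reflect-padding is a common alternative.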

See tutorials/cassm_lorenz_tutorial.ipynb for a fuller walkthrough.
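The tutorial trains on Lorenz-generated spike trains produced by `LorenzData`, whose internals are not described here. As a hedged sketch of how such benchmarks are commonly built, the code below integrates the Lorenz system, maps the standardized latent through a random log-linear readout, and draws Poisson counts. The forward-Euler integrator, the classic parameters (sigma=10, rho=28, beta=8/3), the readout scale, and the baseline rate are all assumptions; this is not `LorenzData`'s implementation, and `lorenz_trajectory`/`lorenz_spikes` are hypothetical names.

```python
import numpy as np


def lorenz_trajectory(x0, timesteps, dt=0.01, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Integrate the Lorenz system with forward Euler; returns (timesteps, 3)."""
    xs = np.empty((timesteps, 3))
    x = np.asarray(x0, dtype=float)
    for t in range(timesteps):
        dx = np.array([
            sigma * (x[1] - x[0]),
            x[0] * (rho - x[2]) - x[1],
            x[0] * x[1] - beta * x[2],
        ])
        x = x + dt * dx
        xs[t] = x
    return xs


def lorenz_spikes(num_trials, timesteps, neurons, dt=0.01, seed=0):
    """Poisson counts driven by a standardized Lorenz latent (hypothetical readout)."""
    rng = np.random.default_rng(seed)
    # Random initial conditions near the attractor (z offset of 25 is an assumption).
    latents = np.stack([
        lorenz_trajectory(rng.normal(size=3) + [0.0, 0.0, 25.0], timesteps, dt)
        for _ in range(num_trials)
    ])  # (trials, timesteps, 3)
    z = (latents - latents.mean(axis=(0, 1))) / latents.std(axis=(0, 1))
    C = rng.normal(scale=0.5, size=(3, neurons))  # random readout weights
    b = np.log(0.05)  # assumed baseline of about 0.05 expected counts per bin
    rates = np.exp(z @ C + b)  # (trials, timesteps, neurons)
    spikes = rng.poisson(rates)
    return spikes, rates, latents


spikes, rates, latents = lorenz_spikes(num_trials=4, timesteps=100, neurons=120, seed=2)
```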

Repository Layout

  • src/cassm: package source code
  • tests: package tests
  • tutorials: user-facing notebooks
  • pyproject.toml: package metadata and tooling configuration

Testing

With the development dependencies installed, run the test suite from the repository root:

pytest

License

CASSM is released under the MIT License. See LICENSE.

Download files

Source Distribution

cassm-0.1.0.tar.gz (333.8 kB)

Built Distribution

cassm-0.1.0-py3-none-any.whl (57.6 kB)

File details

Details for the file cassm-0.1.0.tar.gz.

File metadata

  • Download URL: cassm-0.1.0.tar.gz
  • Upload date:
  • Size: 333.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for cassm-0.1.0.tar.gz
Algorithm Hash digest
SHA256 30d1cb49406bdef714b0c75b53f663b7f509c6ccc4fd5d6ccfb13abb9bb25e9c
MD5 ec5b9f5522fb6752e2e59550d2309bee
BLAKE2b-256 02272e0e4b7af77ecc8a14fed091dacd108e9f3915647dd7b8ad57ef60f8a0ff


File details

Details for the file cassm-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: cassm-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 57.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for cassm-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 eef1887de0e63e2aa90f7db7efd2f596e0457f261df900e724cf8cc7439611f5
MD5 ed4b1ea89e18cc249ed46abcea7e37e0
BLAKE2b-256 8eb47de6c4d03461a4d708dffa61fdbaf4ad1ee76f440d019ddb325c6032d5cb

