
A PyTorch-based library and benchmark for fitting sensory neural responses with deep neural network models


deepSTRF

A PyTorch library for fitting sensory neural responses with deep neural network models

License: GPL v3

Contact: Ulysse Rançon — @urancon, ulysse.rancon@uni-goettingen.de


🧠 Overview

deepSTRF is a community-oriented library for system identification of sensory neurons — predicting trial-resolved neural responses (spikes, calcium fluorescence, EEG, intracellular potential, ...) from naturalistic stimuli with PyTorch models. It bundles:

  • Datasets. A growing zoo of publicly available recordings (auditory cortex, midbrain, songbird auditory pallium, EEG, ...) behind a single NeuralDataset API, with consistent NaN-sentinel handling for missing trials, optional auto-download via download=True, and built-in selection / concatenation utilities.
  • Models. A unified four-slot template (wav2spec → prefiltering → core → readout) with reference implementations of widely used encoders (Linear, 2D-CNN, StateNet, DNet, Transformer, NRF), plus the AdapTrans module for ON/OFF auditory adaptation.
  • Checkpoints. Pretrained checkpoints published on the Hugging Face Hub.
  • Metrics. NaN-aware Pearson, FVE, Schoppe-normalized correlation, signal/noise power, coherence, all functional and torch.compile-friendly.
  • Training utility. A thin, opt-in Fitter (~150 lines) for training with early stopping and best-checkpoint selection, built on top of the canonical PyTorch loop.
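To picture the four-slot template, here is a minimal sketch of an encoder that chains the four slots as plain PyTorch sub-modules. It is an illustration only, not deepSTRF's actual base class: the class name `FourSlotEncoder`, the log-magnitude STFT used as `wav2spec`, the 1×1 convolution prefilter, the GRU core, and the per-neuron linear readout are all hypothetical choices made for brevity.

```python
import torch
from torch import nn


class FourSlotEncoder(nn.Module):
    """Hypothetical encoder with a wav2spec -> prefiltering -> core -> readout layout."""

    def __init__(self, n_fft: int = 64, hop: int = 32, hidden: int = 16, n_neurons: int = 4):
        super().__init__()
        self.n_fft, self.hop = n_fft, hop
        n_freq = n_fft // 2 + 1
        # prefiltering: mixes frequency channels independently at each time step
        self.prefiltering = nn.Conv1d(n_freq, hidden, kernel_size=1)
        # core: a recurrent layer that integrates information over time
        self.core = nn.GRU(hidden, hidden, batch_first=True)
        # readout: one linear output per neuron
        self.readout = nn.Linear(hidden, n_neurons)

    def wav2spec(self, wav: torch.Tensor) -> torch.Tensor:
        # (B, samples) -> (B, F, T) log-magnitude spectrogram
        window = torch.hann_window(self.n_fft, device=wav.device)
        spec = torch.stft(wav, self.n_fft, hop_length=self.hop, window=window, return_complex=True)
        return torch.log1p(spec.abs())

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        x = self.wav2spec(wav)                  # (B, F, T)
        x = self.prefiltering(x)                # (B, H, T)
        x, _ = self.core(x.transpose(1, 2))     # (B, T, H)
        y = self.readout(x)                     # (B, T, N)
        return y.transpose(1, 2).unsqueeze(2)   # (B, N, R=1, T)


model = FourSlotEncoder()
out = model(torch.randn(2, 1024))               # two 1024-sample waveforms
```

The output layout `(B, N, R=1, T)` mirrors the prediction shape annotated in the Quickstart; keeping each slot a separate module makes it easy to swap, say, the GRU core for a Transformer without touching the other three.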

📖 Full documentation: deepstrf.readthedocs.io



⚡ Installation

deepSTRF requires Python ≥ 3.10. It is not yet on PyPI; install from source:

git clone https://github.com/urancon/deepSTRF
cd deepSTRF
pip install -e ".[dev]"      # or `pip install -e .` for runtime only

Optional extras: [docs], [allen] (Allen Brain Observatory tooling), [s4] (CUDA kernels for S4 layers), [eeg] (MNE for .fif parsing). Install them with the usual pip syntax, e.g. pip install -e ".[eeg]".

See the Installation guide for conda recipes and troubleshooting.


🚀 Quickstart

Load a published checkpoint and score it on the canonical NS1 ferret A1 dataset:

import torch
from torch.utils.data import DataLoader
from deepSTRF.datasets.audio.ns1 import NS1Dataset
from deepSTRF.models.audio import StateNet
from deepSTRF.metrics import corrcoef, normalized_corrcoef
from deepSTRF.utils.data import neural_collate

# 1) Load a dataset (auto-downloads to a local cache the first time).
ds = NS1Dataset(download=True, dt_ms=5)
loader = DataLoader(ds, batch_size=8, collate_fn=neural_collate)

# 2) Load a pretrained model from the Hugging Face Hub.
model = StateNet.from_pretrained("urancon/deepSTRF-statenet-gru-ns1").eval()

# 3) Score it. Each batch is a dict: 'stims', 'responses', 'valid_mask', 'stim_meta'.
batch     = next(iter(loader))
responses = batch['responses']                              # (B, N, R, T), NaN marks missing trials
with torch.no_grad():                                       # inference only: skip autograd bookkeeping
    pred  = model(batch['stims'])                           # (B, N, R=1, T)
psth      = responses.nanmean(dim=2, keepdim=True)          # trial-averaged response, ignoring NaN trials
cc        = corrcoef(pred, psth, reduction='mean')
cc_norm   = normalized_corrcoef(pred, responses, method='schoppe', reduction='mean')
print(f"CCraw = {cc:.3f}   CCnorm = {cc_norm:.3f}")
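To make the NaN-sentinel convention concrete, here is a NumPy sketch of what a NaN-aware PSTH and Pearson correlation compute. It is not the library's torch implementation: the helper name `nan_pearson` is hypothetical, and it assumes responses shaped `(N, R, T)` for one stimulus, with missing trials filled with NaN, correlating along time only over bins where both signals are finite.

```python
import numpy as np


def nan_pearson(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Pearson r along the last (time) axis, ignoring bins where either input is NaN."""
    valid = np.isfinite(pred) & np.isfinite(target)
    n = valid.sum(axis=-1, keepdims=True)                     # number of usable time bins
    p_mean = np.where(valid, pred, 0.0).sum(axis=-1, keepdims=True) / n
    t_mean = np.where(valid, target, 0.0).sum(axis=-1, keepdims=True) / n
    pc = np.where(valid, pred - p_mean, 0.0)                  # centered, zero outside valid bins
    tc = np.where(valid, target - t_mean, 0.0)
    cov = (pc * tc).sum(axis=-1)
    return cov / np.sqrt((pc ** 2).sum(axis=-1) * (tc ** 2).sum(axis=-1))


rng = np.random.default_rng(0)
signal = rng.normal(size=(3, 1, 100))                         # (N, 1, T) shared "true" response
responses = signal + 0.5 * rng.normal(size=(3, 5, 100))       # (N, R=5, T) noisy trials
responses[1, 4, :] = np.nan                                   # neuron 1 is missing its last trial

psth = np.nanmean(responses, axis=1, keepdims=True)           # (N, 1, T), averages available trials
pred = signal + 0.3 * rng.normal(size=signal.shape)           # stand-in for a model prediction
cc = nan_pearson(pred, psth)                                  # (N, 1), one r per neuron
print(cc.mean())
```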

🤖 Train a model

The opt-in Fitter wraps the canonical training loop (loss, early stopping, and best-checkpoint selection) in ~10 lines of user code:

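# Assumes `model` (e.g. from the Quickstart) and `train_loader` / `val_loader` DataLoaders built like `loader` above.
from torch.optim import Adam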
from deepSTRF.training import Fitter
from deepSTRF.metrics import mse_loss, normalized_corrcoef

fitter = Fitter(
    model=model,
    optimizer=Adam(model.parameters(), lr=1e-3),
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=mse_loss,
    val_metrics={
        'cc_norm': lambda pred, resp: normalized_corrcoef(pred, resp, method='schoppe', reduction='mean'),
    },
)
fitter.fit(num_epochs=50)

For custom loops (mixed precision, multi-GPU, curricula, ...), write the canonical three-line PyTorch loop documented in metrics_paradigm.md §7 yourself; thanks to the metrics API, the loss and each validation metric remain one-liners. See the Fitter docs for hooks and design rationale.
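For readers writing their own loop, the sketch below shows the general pattern a Fitter-style utility automates: minimize a loss, track a validation score every epoch, snapshot the best weights, and stop after `patience` epochs without improvement. It uses plain PyTorch on synthetic data; the NaN-masked MSE (`masked_mse`), the toy linear model, and the `patience` value are illustrative assumptions, not deepSTRF's `Fitter` internals or its `mse_loss`.

```python
import copy

import torch
from torch import nn


def masked_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE over entries where the target is not NaN (NaN marks missing data)."""
    valid = ~torch.isnan(target)
    return ((pred[valid] - target[valid]) ** 2).mean()


torch.manual_seed(0)
x_train, x_val = torch.randn(256, 8), torch.randn(64, 8)
w_true = torch.randn(8, 3)
y_train, y_val = x_train @ w_true, x_val @ w_true
y_train[::10, 0] = float("nan")                   # simulate missing entries

model = nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
best_loss, best_state, patience, bad_epochs = float("inf"), None, 5, 0

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    loss = masked_mse(model(x_train), y_train)     # forward + loss
    loss.backward()                                # backward
    optimizer.step()                               # parameter update

    model.eval()
    with torch.no_grad():
        val_loss = masked_mse(model(x_val), y_val).item()
    if val_loss < best_loss:                       # new best: snapshot the weights
        best_loss, best_state, bad_epochs = val_loss, copy.deepcopy(model.state_dict()), 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:                 # early stopping
            break

model.load_state_dict(best_state)                  # restore the best checkpoint
```

Deep-copying the `state_dict` matters: without it, the "best" snapshot would keep pointing at the live parameters and silently track later, worse epochs.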


📓 Tutorials

Runnable notebooks live under examples/ — each opens in Colab in one click.

| Notebook | Focus |
|---|---|
| crcns_aa_tutorial.ipynb | Start here. Load CRCNS AA1 / AA2 zebra finch data end-to-end. |
| explore_nat4.ipynb | Inspect NAT4 ferret A1 / PEG recordings. |
| dataset_concatenation.ipynb | Mix multiple datasets behind one DataLoader. |
| fit_ns1_statenet.ipynb | Fit StateNet on NS1 from scratch. |
| load_pretrained_statenet_ns1.ipynb | Reuse the published HF Hub checkpoint. |
| alice_eeg_tutorial.ipynb | EEG (Brodbeck 2023, "Alice"). |
| le_2025_baseline.ipynb | Zebra finch responses to occluded conspecific song. |
| strf_parameterizations_ns1.ipynb | Parametric Gaussian-mixture STRFs. |
| strf_gradmap_aa2.ipynb | Gradient-attribution receptive fields on AA2. |
| adaptrans_transformer_aa1.ipynb | AdapTrans + Transformer on AA1 Field L. |
| espejo_nat_nrf.ipynb | Network Receptive Field model on Espejo ferret A1. |
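Mixing datasets behind one DataLoader largely comes down to collating items whose trial counts differ. The sketch below uses PyTorch's `ConcatDataset` with a NaN-padding collate function; `ToyNeuralDataset` and `nan_pad_collate` are hypothetical, the batch keys only mirror the Quickstart's dict, and deepSTRF's own `neural_collate` may behave differently. It also assumes every dataset records the same number of neurons.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class ToyNeuralDataset(Dataset):
    """Hypothetical stand-in: each item is one stimulus with R repeated trials."""

    def __init__(self, n_items, n_neurons, n_trials, n_freq=8, n_time=20):
        g = torch.Generator().manual_seed(n_trials)
        self.stims = torch.randn(n_items, n_freq, n_time, generator=g)
        self.responses = torch.rand(n_items, n_neurons, n_trials, n_time, generator=g)

    def __len__(self):
        return len(self.stims)

    def __getitem__(self, i):
        return {"stims": self.stims[i], "responses": self.responses[i]}


def nan_pad_collate(items):
    """Stack items, padding the trial axis with NaN so different repeat counts share a batch."""
    max_r = max(it["responses"].shape[1] for it in items)
    padded = []
    for it in items:
        n, r, t = it["responses"].shape
        pad = torch.full((n, max_r - r, t), float("nan"))
        padded.append(torch.cat([it["responses"], pad], dim=1))
    responses = torch.stack(padded)                            # (B, N, R_max, T)
    return {
        "stims": torch.stack([it["stims"] for it in items]),
        "responses": responses,
        "valid_mask": ~torch.isnan(responses),                 # True where a real trial exists
    }


combined = ConcatDataset([ToyNeuralDataset(5, 3, n_trials=10), ToyNeuralDataset(4, 3, n_trials=2)])
loader = DataLoader(combined, batch_size=9, collate_fn=nan_pad_collate)
batch = next(iter(loader))
```

Because padding uses NaN rather than zeros, NaN-aware metrics such as a `nanmean` PSTH automatically ignore the fake trials.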

🏁 Benchmark

Current top model on each dataset. Want to claim the podium? Open a PR with a ready-to-deploy PyTorch class so that others can reproduce your results.

| Dataset | Model | Remarks | Params / neuron | CCraw / CCnorm [%] | Paper |
|---|---|---|---|---|---|
| NS1 | StateNet | GRU, pop | 30,465 | 55.6 / 75.1 | Rançon et al. |
| NAT4 A1 | StateNet | LSTM, pop | 40,271 | 46.6 / 65.1 | Rançon et al. |
| NAT4 PEG | Transformer | pop | 28,437 | 39.7 / 55.5 | Rançon et al. |
| AA1 Field L | StateNet | GRU, pop | 24,900 | – / 71.0 | Rançon et al. |
| AA1 MLd | StateNet | Mamba, pop | 32,334 | – / 73.4 | Rançon et al. |

Note. The three CRCNS AC1 datasets (Wehr, Asari A1, Asari MGB) support single-unit fitting only, and their results vary widely depending on response preprocessing (detrending, spikes vs. raw potential, ...). Their benchmark will be reported separately.
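The CCnorm values above are presumably the Schoppe-normalized correlation used in the Quickstart (Schoppe et al., 2016). For a prediction ŷ, trials y_1, ..., y_R, and PSTH ȳ, it reads CCnorm = Cov(ŷ, ȳ) / sqrt(Var(ŷ) · SP), where the signal power is SP = [Var(Σ_r y_r) − Σ_r Var(y_r)] / (R(R − 1)). A minimal NumPy sketch for a single neuron, assuming responses shaped `(R, T)` with no missing trials; the helper names are hypothetical, and the library's `normalized_corrcoef(..., method='schoppe')` additionally handles NaN sentinels and batching.

```python
import numpy as np


def signal_power(responses: np.ndarray) -> float:
    """Signal power of repeated trials, responses shaped (R, T) with R >= 2."""
    r = responses.shape[0]
    var_of_sum = np.var(responses.sum(axis=0))
    sum_of_vars = np.var(responses, axis=1).sum()
    return (var_of_sum - sum_of_vars) / (r * (r - 1))


def cc_norm_schoppe(pred: np.ndarray, responses: np.ndarray) -> float:
    """Normalized correlation: Cov(pred, PSTH) / sqrt(Var(pred) * SP)."""
    psth = responses.mean(axis=0)
    cov = np.mean((pred - pred.mean()) * (psth - psth.mean()))
    return cov / np.sqrt(np.var(pred) * signal_power(responses))


rng = np.random.default_rng(1)
true_signal = rng.normal(size=500)                           # noise-free response, T = 500 bins
trials = true_signal + rng.normal(size=(10, 500))            # R = 10 noisy repeats
cc_raw = np.corrcoef(true_signal, trials.mean(axis=0))[0, 1]
cc_norm = cc_norm_schoppe(true_signal, trials)               # near 1 for a perfect model
print(f"CCraw = {cc_raw:.3f}   CCnorm = {cc_norm:.3f}")
```

Even a perfect model gets a CCraw below 1 because the PSTH still contains trial-to-trial noise; dividing by the signal power corrects for that ceiling, which is why CCnorm is consistently higher than CCraw in the table.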


📚 Datasets included

deepSTRF wraps publicly available recordings; please cite the original authors when you use them. See each dataset's page in the docs for details and download instructions.


🚧 Status — audio-first

The first release of deepSTRF focuses on auditory datasets and models. A working video API is on the roadmap but is not yet shipped on develop — the deepSTRF.datasets.video and deepSTRF.models.video namespaces currently expose only their base classes (VideoNeuralDataset, VideoEncodingModel). Earlier draft loaders (Allen Ophys / Ecephys, CRCNS PVC1 / PVC11 / MT1 / MT2 / VIM2, MICrONS, UW Neural Data Challenge) live on the archive/video-api-v0 branch and will be revived once rewritten against the modernized base class.


💡 Contributing

Pull requests are welcome. The most useful contributions are new datasets, new model backbones, and pretrained checkpoints. Please open an issue first so we can scope the work together (most importantly, to confirm that the dataset's license allows redistribution).


📖 Citation

If deepSTRF is useful for your work, please cite the relevant paper:

@article{rancon2024pcb,
    title   = {A general model unifying the adaptive, transient and sustained properties of ON and OFF auditory neural responses},
    author  = {Rançon, Ulysse and Masquelier, Timothée and Cottereau, Benoit R.},
    journal = {PLOS Computational Biology},
    year    = {2024},
    volume  = {20},
    number  = {8},
    pages   = {1--32},
    doi     = {10.1371/journal.pcbi.1012288},
}

@article{rancon2025commbio,
    title   = {Temporal recurrence as a general mechanism to explain neural responses in the auditory system},
    author  = {Rançon, Ulysse and Masquelier, Timothée and Cottereau, Benoit R.},
    journal = {Communications Biology},
    year    = {2025},
    volume  = {8},
    number  = {1},
    pages   = {1456},
    doi     = {10.1038/s42003-025-08858-3},
}

A running list of papers that build on deepSTRF lives on the Publications docs page. PRs welcome to add yours.



Download files


Source Distribution

deepstrf-0.1.0.tar.gz (357.3 kB)

Uploaded Source

Built Distribution


deepstrf-0.1.0-py3-none-any.whl (317.2 kB)

Uploaded Python 3

File details

Details for the file deepstrf-0.1.0.tar.gz.

File metadata

  • Download URL: deepstrf-0.1.0.tar.gz
  • Upload date:
  • Size: 357.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for deepstrf-0.1.0.tar.gz
Algorithm Hash digest
SHA256 a79f2e26f555c328359238fb454feba7fa8ec263f04f549f0a670c26de50581f
MD5 682088a12d68c8262857c5340dce7057
BLAKE2b-256 b33688951700598ab60f6ad3eeed8d59943413a35e2552c082a602ed3abd7463


Provenance

The following attestation bundles were made for deepstrf-0.1.0.tar.gz:

Publisher: publish.yml on urancon/deepSTRF

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file deepstrf-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: deepstrf-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 317.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for deepstrf-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f9ec98c072256c9669d3c923f57b899dcad5b4269a006629f5bb5dc6d603df28
MD5 3e0d74b9b204151fc0357c466be81586
BLAKE2b-256 dd603500dc8242f82d8de82490d7eb67c99eab0b607ae9d405fab8fa5eebf7a8


Provenance

The following attestation bundles were made for deepstrf-0.1.0-py3-none-any.whl:

Publisher: publish.yml on urancon/deepSTRF

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
