
A package for deep learning models for neuroscience

Project description

torch_brain

Documentation | Join our Discord community


[!NOTE] We have merged temporaldata and brainsets into torch_brain. If you are migrating from v0.1.x, please see this migration guide.

torch_brain is an end-to-end framework for building deep learning models and training pipelines for neuroscience. It pairs a lightweight, time-based data format (plus tools to preprocess existing neural datasets into it) with PyTorch-compatible building blocks: datasets, samplers, nn.Modules, and models.

Features

  • Lazy, on-demand data loading that reads only the time-slices and attributes you request
  • Advanced samplers for arbitrary on-the-fly slicing of recordings
  • Multi-recording training across heterogeneous datasets
  • Support for arbitrary neural and behavioral modalities
  • Flexible collation strategies, including chaining and padding
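Padding and chaining are two common ways to batch samples whose lengths differ, such as trials of different durations. The NumPy sketch below illustrates the two ideas only. It is not torch_brain's own collate API, and the function names `pad_collate` and `chain_collate` are hypothetical.

```python
import numpy as np

def pad_collate(samples):
    """Pad variable-length (T_i, D) arrays to (B, T_max, D) and return a validity mask."""
    max_len = max(s.shape[0] for s in samples)
    dim = samples[0].shape[1]
    batch = np.zeros((len(samples), max_len, dim), dtype=samples[0].dtype)
    mask = np.zeros((len(samples), max_len), dtype=bool)
    for i, s in enumerate(samples):
        batch[i, : s.shape[0]] = s      # copy real timesteps
        mask[i, : s.shape[0]] = True    # mark them as valid; the rest is padding
    return batch, mask

def chain_collate(samples):
    """Concatenate (T_i, D) arrays along time into (sum T_i, D), plus a sample id per row."""
    batch = np.concatenate(samples, axis=0)
    batch_index = np.concatenate(
        [np.full(s.shape[0], i, dtype=np.int64) for i, s in enumerate(samples)]
    )
    return batch, batch_index

samples = [np.ones((3, 2)), np.ones((5, 2))]
padded, mask = pad_collate(samples)
chained, batch_index = chain_collate(samples)
```

Padding keeps a rectangular batch at the cost of wasted compute on padded steps. Chaining avoids that waste but requires the model to use `batch_index` to tell samples apart.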

Installation

torch_brain requires Python >= 3.10. To install a stable release:

pip install torch torch_brain

[!CAUTION] Until we release v0.2.0 on PyPI, you will have to install from GitHub. See the releases page for updates on releases.

pip install torch git+https://github.com/neuro-galaxy/torch_brain

[!TIP] If you only need torch_brain.data and the data-preparation pipelines, you can skip installing torch.

Latest development version:

To install the latest (unstable) development version from the main branch:

pip install git+https://github.com/neuro-galaxy/torch_brain

The data format

A recording is a Data object holding heterogeneous, time-aware modalities: regularly-sampled signals (LFP, EEG, etc.), irregular event streams (spikes), interval annotations (trials), and plain arrays.

import numpy as np
from torch_brain.data import Data, IrregularTimeSeries, RegularTimeSeries, Interval

data = Data(
    spikes=IrregularTimeSeries(                       # event stream
        timestamps=np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3]),
        unit_index=np.array([0, 0, 1, 0, 1, 2]),
        domain="auto",
    ),
    lfp=RegularTimeSeries(raw=np.zeros((1000, 3)), sampling_rate=250.0),  # 1000 samples = 4 s @ 250 Hz
    trials=Interval(start=np.array([0.0, 1.0, 2.0]), end=np.array([1.0, 2.0, 3.0])),  # annotations
    domain=Interval(0.0, 4.0),
)

The point of the format is that slicing is time-based and lazy. Every modality is sliced consistently by time, regardless of its sampling rate, and when a recording is read from disk, only the requested window and the attributes you actually access are loaded.

window = data.slice(1.0, 3.0)
# spikes -> the 3 events in [1, 3)   lfp -> 500 samples   trials -> 2 trials
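To make "sliced consistently by time" concrete, the NumPy sketch below reproduces the three results in the comment above for each kind of modality. It is an illustration under stated assumptions, not torch_brain's implementation: the helper names are hypothetical, events use a half-open `[start, end)` window, and intervals are kept if they overlap the window. torch_brain's exact boundary rules may differ.

```python
import numpy as np

def slice_events(timestamps, start, end):
    """Irregular stream: keep events with start <= t < end."""
    mask = (timestamps >= start) & (timestamps < end)
    return timestamps[mask]

def slice_regular(raw, sampling_rate, start, end, domain_start=0.0):
    """Regular signal: convert the time window into sample indices."""
    i0 = int(round((start - domain_start) * sampling_rate))
    i1 = int(round((end - domain_start) * sampling_rate))
    return raw[i0:i1]

def slice_intervals(starts, ends, start, end):
    """Annotations: keep intervals that overlap [start, end)."""
    mask = (ends > start) & (starts < end)
    return starts[mask], ends[mask]

spike_times = np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3])
lfp = np.zeros((1000, 3))
trial_start, trial_end = np.array([0.0, 1.0, 2.0]), np.array([1.0, 2.0, 3.0])

print(len(slice_events(spike_times, 1.0, 3.0)))                   # 3
print(slice_regular(lfp, 250.0, 1.0, 3.0).shape)                  # (500, 3)
print(len(slice_intervals(trial_start, trial_end, 1.0, 3.0)[0]))  # 2
```

The same `(start, end)` pair drives all three helpers, which is why a single time window is enough to index every modality at once.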

This is why a torch_brain Dataset is indexed by time, not by integer (see below).
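The lazy-reading half of this idea can be illustrated without torch_brain. The sketch below uses NumPy's `np.memmap` on a raw float32 file as a stand-in for torch_brain's own on-disk format, which it does not reproduce. The file layout, `read_window` helper, and channel count are assumptions for illustration.

```python
import os
import tempfile
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "lfp.dat")

# Write a (1000, 3) float32 signal (4 s at 250 Hz) to disk once.
signal = np.arange(3000, dtype=np.float32).reshape(1000, 3)
signal.tofile(path)

def read_window(path, start, end, num_channels=3, sampling_rate=250.0):
    """Map the file without reading it, then copy only the rows for [start, end)."""
    mm = np.memmap(path, dtype=np.float32, mode="r").reshape(-1, num_channels)
    i0 = int(round(start * sampling_rate))
    i1 = int(round(end * sampling_rate))
    return np.array(mm[i0:i1])  # only these rows are materialized in memory

window = read_window(path, 1.0, 3.0)  # 2 s window -> 500 samples
```

Because the time window maps directly to a byte range, a time-indexed request touches only the data it needs, whereas loading whole recordings and indexing by integer would read everything up front.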

Training pipelines

torch_brain leans on the standard PyTorch training loop; its main job is handling the data side. You define a Dataset (built on the time-slicing above) and a Sampler that decides which slices become samples. The DataLoader, model, and loop are ordinary PyTorch.

import torch
from torch.utils.data import DataLoader
from torch_brain.datasets import PeiPandarinathNLB2021, DatasetIndex
from torch_brain.samplers import TrialSampler
from torch_brain.utils import bin_spikes

# torch_brain ships loaders for many public datasets.
# Subclass one to define the two things specific to your task:
class MyDataset(PeiPandarinathNLB2021):
    # 1. WHICH windows count as samples (here, one per behavioral trial).
    def get_sampling_intervals(self):
        sampling_intervals = {}
        for rid in self.recording_ids:
            sampling_intervals[rid] = self.get_recording(rid).trials
        return sampling_intervals

    # 2. HOW one window becomes tensors.
    def __getitem__(self, index: DatasetIndex):
        # `index` is a DatasetIndex(recording_id, start, end) handed in by the sampler.
        # super().__getitem__(index) returns that time slice, with every
        # modality (.spikes, .hand.vel, ...) lazily cropped.
        data = super().__getitem__(index)

        # Only the attributes accessed below are read from disk into memory.
        X = bin_spikes(data.spikes, num_units=len(data.units), bin_size=0.05)
        Y = data.hand.vel
        return torch.from_numpy(X).float(), torch.from_numpy(Y).float()

dataset = MyDataset(root="data/processed", recording_ids=["jenkins_maze_train"])

# The sampler turns those intervals into per-sample DatasetIndex objects.
sampler = TrialSampler(sampling_intervals=dataset.get_sampling_intervals(), shuffle=True)
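# PyTorch's default collate_fn stacks samples, so if trials differ in length
# you will need a padding (or chaining) collate_fn here.
loader = DataLoader(dataset, sampler=sampler, batch_size=8)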

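# From here on it's plain PyTorch. `model` and `loss_fn` stand in for your own
# nn.Module and loss function; they are not defined in this snippet.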
for X, Y in loader:
    pred = model(X)
    loss = loss_fn(pred, Y)
    ...
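The `bin_spikes` call in `__getitem__` turns an event stream into a dense count matrix a model can consume. The NumPy sketch below shows what such binning computes under stated assumptions: the helper `bin_spike_counts` is hypothetical, its output is laid out as `(num_units, num_bins)`, and spikes outside `[start, end)` are dropped. The real `torch_brain.utils.bin_spikes` signature, output layout, and edge handling may differ, so check its documentation.

```python
import numpy as np

def bin_spike_counts(timestamps, unit_index, num_units, bin_size, start, end):
    """Count spikes per (unit, time bin) over [start, end)."""
    num_bins = int(round((end - start) / bin_size))
    bin_index = np.floor((timestamps - start) / bin_size).astype(np.int64)
    keep = (bin_index >= 0) & (bin_index < num_bins)   # drop out-of-window spikes
    counts = np.zeros((num_units, num_bins), dtype=np.int64)
    # np.add.at accumulates correctly even when (unit, bin) pairs repeat.
    np.add.at(counts, (unit_index[keep], bin_index[keep]), 1)
    return counts

timestamps = np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3])
unit_index = np.array([0, 0, 1, 0, 1, 2])
counts = bin_spike_counts(timestamps, unit_index, num_units=3,
                          bin_size=0.5, start=0.0, end=4.0)
```

Since the number of bins scales with window length, trials of different durations yield arrays of different widths, which is where padding or chaining collation comes in.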

The key idea: unlike a standard PyTorch Dataset, which is indexed by integers, a torch_brain Dataset is indexed by time slices and loads data lazily, so only the slice you ask for is read from disk. The Sampler decides what to load, the Dataset decides how, and everything downstream stays vanilla PyTorch.
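In PyTorch, a DataLoader passes each index produced by its sampler to `dataset[index]`, so nothing requires those indices to be integers. The dependency-free sketch below shows the sampler/dataset split described above. `WindowIndex`, `trial_indices`, and `TimeIndexedDataset` are hypothetical stand-ins, not torch_brain's `DatasetIndex`, `TrialSampler`, or `Dataset`.

```python
import random
from typing import Dict, List, NamedTuple, Tuple

class WindowIndex(NamedTuple):
    """Hypothetical stand-in for a (recording_id, start, end) index."""
    recording_id: str
    start: float
    end: float

def trial_indices(sampling_intervals: Dict[str, List[Tuple[float, float]]],
                  shuffle: bool = True, seed: int = 0) -> List[WindowIndex]:
    """The 'what': one index per interval, across all recordings, optionally shuffled."""
    indices = [WindowIndex(rid, s, e)
               for rid, intervals in sampling_intervals.items()
               for s, e in intervals]
    if shuffle:
        random.Random(seed).shuffle(indices)
    return indices

class TimeIndexedDataset:
    """The 'how': a toy dataset indexed by time windows rather than integers."""
    def __init__(self, recordings: Dict[str, List[float]]):
        self.recordings = recordings  # recording_id -> sorted event times

    def __getitem__(self, index: WindowIndex) -> List[float]:
        times = self.recordings[index.recording_id]
        return [t for t in times if index.start <= t < index.end]

intervals = {"rec_a": [(0.0, 1.0), (1.0, 2.0)], "rec_b": [(0.0, 0.5)]}
dataset = TimeIndexedDataset({"rec_a": [0.2, 1.5, 1.7], "rec_b": [0.1, 0.9]})
for idx in trial_indices(intervals):
    print(idx.recording_id, dataset[idx])
```

Keeping the window selection in the sampler means the same dataset can be reused with different sampling strategies (per trial, random fixed-length windows, and so on) without changing how a window is turned into tensors.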

See examples/ for simple and readable training implementations.

Contributing

Contributions are welcome! Get started with:

pip install -e ".[dev]"   # editable install with dev dependencies
pre-commit install        # formatting & lint hooks
pytest                    # run the test suite

See CONTRIBUTING.md for the full workflow and code-style guidelines.

Building the documentation

pip install -e ".[dev,docs]"
cd docs && make clean html

The built docs are placed in docs/build/html.



Download files


Source Distribution

torch_brain-0.2.0a0.tar.gz (3.8 MB)

Built Distribution


torch_brain-0.2.0a0-py3-none-any.whl (234.6 kB)

File details

Details for the file torch_brain-0.2.0a0.tar.gz.

File metadata

  • Download URL: torch_brain-0.2.0a0.tar.gz
  • Upload date:
  • Size: 3.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torch_brain-0.2.0a0.tar.gz
Algorithm Hash digest
SHA256 b7a1140ffaed0887614242495faab6ec447dfce1573c50df6c0075619c2532ab
MD5 814722d6486c76487becc8994d2ee301
BLAKE2b-256 eccee3c6d6c67b9a09b8c9a98a73b48387df05834005128e21706e25f30ebd51


Provenance

The following attestation bundles were made for torch_brain-0.2.0a0.tar.gz:

Publisher: publish.yml on neuro-galaxy/torch_brain

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torch_brain-0.2.0a0-py3-none-any.whl.

File metadata

  • Download URL: torch_brain-0.2.0a0-py3-none-any.whl
  • Upload date:
  • Size: 234.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torch_brain-0.2.0a0-py3-none-any.whl
Algorithm Hash digest
SHA256 321f4cec9026964f74b8727a9b480585c9b39460946b61c0cb68b075a55759a5
MD5 8c83b19e617ce95c7bd4569bb84ce289
BLAKE2b-256 448598f71009e34fe6a2897ab0900d4eb9930edc0dc2941bdb2011fd2b2b3940


Provenance

The following attestation bundles were made for torch_brain-0.2.0a0-py3-none-any.whl:

Publisher: publish.yml on neuro-galaxy/torch_brain

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
