A package for deep learning models for neuroscience
Project description
torch_brain
Documentation | Join our Discord community
[!NOTE] We have merged temporaldata and brainsets into torch_brain. If you are migrating from v0.1.x, please see this migration guide.
torch_brain is an end-to-end framework for building deep learning models
and training pipelines for neuroscience. It pairs a lightweight, time-based data
format (plus tools to preprocess existing neural datasets into it) with
PyTorch-compatible building blocks: datasets, samplers, nn.Modules, and
models.
Features
- Lazy, on-demand data loading that reads only the time-slices and attributes you request
- Advanced samplers for arbitrary on-the-fly slicing of recordings
- Multi-recording training across heterogeneous datasets
- Support for arbitrary neural and behavioral modalities
- Flexible collation strategies, including chaining and padding
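The last bullet names two collation strategies, padding and chaining, which differ in how variable-length samples are batched. The NumPy sketch below is a hypothetical illustration of the two ideas, not torch_brain's collate API; `pad_collate` and `chain_collate` are names made up for this example.

```python
import numpy as np

def pad_collate(samples):
    """Stack variable-length (T_i, D) arrays into (B, T_max, D) plus a boolean mask."""
    max_len = max(s.shape[0] for s in samples)
    dim = samples[0].shape[1]
    batch = np.zeros((len(samples), max_len, dim), dtype=samples[0].dtype)
    mask = np.zeros((len(samples), max_len), dtype=bool)
    for i, s in enumerate(samples):
        batch[i, : s.shape[0]] = s      # copy real data, leave the tail as zeros
        mask[i, : s.shape[0]] = True    # mark which positions hold real data
    return batch, mask

def chain_collate(samples):
    """Concatenate along time into (sum T_i, D) plus a per-row sample index."""
    batch = np.concatenate(samples, axis=0)
    batch_index = np.concatenate(
        [np.full(s.shape[0], i, dtype=np.int64) for i, s in enumerate(samples)]
    )
    return batch, batch_index

samples = [np.ones((3, 2)), np.ones((5, 2)), np.ones((1, 2))]
padded, mask = pad_collate(samples)
chained, batch_index = chain_collate(samples)
print(padded.shape, mask.sum(axis=1))  # (3, 5, 2) [3 5 1]
print(chained.shape, batch_index)      # (9, 2) [0 0 0 1 1 1 1 1 2]
```

Padding keeps the dense (batch, time, features) layout most layers expect, at the cost of computing on masked positions; chaining wastes nothing but requires downstream code to use the sample index to tell samples apart.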
Installation
torch_brain requires Python >= 3.10. To install a stable release:
pip install torch torch_brain
[!CAUTION] Until we release v0.2.0 on PyPI, you will have to install from GitHub. See the releases page for updates.
pip install torch git+https://github.com/neuro-galaxy/torch_brain
[!TIP] If you only need torch_brain.data and the data-preparation pipelines, you can skip installing torch.
Latest development version:
Install the latest (unstable) development version via the main branch:
pip install git+https://github.com/neuro-galaxy/torch_brain
The data format
A recording is a Data object holding heterogeneous, time-aware modalities:
regularly-sampled signals (LFP, EEG, etc.), irregular event streams (spikes),
interval annotations (trials), and plain arrays.
import numpy as np
from torch_brain.data import Data, IrregularTimeSeries, RegularTimeSeries, Interval
data = Data(
    spikes=IrregularTimeSeries(  # event stream
        timestamps=np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3]),
        unit_index=np.array([0, 0, 1, 0, 1, 2]),
        domain="auto",
    ),
    lfp=RegularTimeSeries(raw=np.zeros((1000, 3)), sampling_rate=250.0),  # 4 s @ 250 Hz
    trials=Interval(start=np.array([0.0, 1.0, 2.0]), end=np.array([1.0, 2.0, 3.0])),  # annotations
    domain=Interval(0.0, 4.0),
)
The point of the format is that slicing is time-based and lazy: every modality is sliced consistently despite their different sampling rates, and data is read lazily from disk, so only the requested window and attributes are loaded.
window = data.slice(1.0, 3.0)
# spikes -> the 3 events in [1, 3)
# lfp    -> 500 samples
# trials -> 2 trials
This is why a torch_brain Dataset is indexed by time, not by integer (see below).
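The consistent cross-modality slicing described above can be sketched in NumPy under a few assumptions: event timestamps are sorted, sample i of a regular signal sits at `domain_start + i / sampling_rate`, and an interval is kept when it overlaps the window (torch_brain's exact inclusion rules may differ). The helpers `slice_irregular`, `slice_regular`, and `slice_interval` are hypothetical, not torch_brain's implementation; they reproduce the counts quoted in the comments on the `data.slice(1.0, 3.0)` example.

```python
import numpy as np

def slice_irregular(timestamps, start, end):
    """Indices of events with start <= t < end; `timestamps` must be sorted."""
    lo = np.searchsorted(timestamps, start, side="left")
    hi = np.searchsorted(timestamps, end, side="left")
    return np.arange(lo, hi)

def slice_regular(num_samples, sampling_rate, start, end, domain_start=0.0):
    """Indices of samples (sample i at domain_start + i / sampling_rate) in [start, end)."""
    lo = int(np.ceil((start - domain_start) * sampling_rate))
    hi = int(np.ceil((end - domain_start) * sampling_rate))
    return np.arange(max(lo, 0), min(hi, num_samples))

def slice_interval(starts, ends, start, end):
    """Indices of intervals overlapping [start, end) (an assumed inclusion rule)."""
    return np.flatnonzero((starts < end) & (ends > start))

spike_times = np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3])
trial_starts = np.array([0.0, 1.0, 2.0])
trial_ends = np.array([1.0, 2.0, 3.0])

spikes_idx = slice_irregular(spike_times, 1.0, 3.0)
lfp_idx = slice_regular(1000, 250.0, 1.0, 3.0)
trials_idx = slice_interval(trial_starts, trial_ends, 1.0, 3.0)
print(len(spikes_idx), len(lfp_idx), len(trials_idx))  # 3 500 2
```

Because every modality is cropped by time rather than by array position, one `(start, end)` pair selects matching data from all of them, and each selection is a contiguous index range, which is what makes reading only that range from disk cheap.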
Training pipelines
torch_brain leans on the standard PyTorch training loop, and most of its job is
to handle the data side. You define a Dataset (built on the time-slicing
above) and a Sampler that decides which slices become samples. The
DataLoader, model, and loop are ordinary PyTorch.
import torch
from torch.utils.data import DataLoader
from torch_brain.datasets import PeiPandarinathNLB2021, DatasetIndex
from torch_brain.samplers import TrialSampler
from torch_brain.utils import bin_spikes

# torch_brain ships loaders for many public datasets.
# Subclass one to define the two things specific to your task:
class MyDataset(PeiPandarinathNLB2021):
    # 1. WHICH windows count as samples (here, one per behavioral trial).
    def get_sampling_intervals(self):
        sampling_intervals = {}
        for rid in self.recording_ids:
            sampling_intervals[rid] = self.get_recording(rid).trials
        return sampling_intervals

    # 2. HOW one window becomes tensors.
    def __getitem__(self, index: DatasetIndex):
        # `index` is a DatasetIndex(recording_id, start, end) handed in by the sampler.
        # super().__getitem__(...) returns that slice with every modality
        # (.spikes, .hand.vel, ...) lazily cropped; only the attributes actually
        # accessed are loaded from disk into memory.
        data = super().__getitem__(index)
        X = bin_spikes(data.spikes, num_units=len(data.units), bin_size=0.05)
        Y = data.hand.vel
        return torch.from_numpy(X).float(), torch.from_numpy(Y).float()

dataset = MyDataset(root="data/processed", recording_ids=["jenkins_maze_train"])

# The sampler turns those intervals into per-sample DatasetIndex objects.
sampler = TrialSampler(sampling_intervals=dataset.get_sampling_intervals(), shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)

# From here on it's plain PyTorch.
# `model` and `loss_fn` stand for your own nn.Module and loss function (not defined here).
for X, Y in loader:
    pred = model(X)
    loss = loss_fn(pred, Y)
    ...  # backward pass, optimizer step, logging, etc.
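In the loop above, TrialSampler turns each recording's trial intervals into per-sample indices that the DataLoader hands to `__getitem__`. The pure-Python stand-in below sketches that idea, assuming each index is a `(recording_id, start, end)` triple as the comment in `__getitem__` describes. `SimpleIndex` and `SimpleTrialSampler` are hypothetical names, not torch_brain's implementation. It relies only on the fact that PyTorch's `DataLoader` accepts any iterable of indices as `sampler`.

```python
import random
from typing import Dict, Iterator, List, NamedTuple, Tuple

class SimpleIndex(NamedTuple):
    """Hypothetical stand-in for a (recording_id, start, end) sample index."""
    recording_id: str
    start: float
    end: float

class SimpleTrialSampler:
    """Yields one SimpleIndex per interval, optionally reshuffled on every pass."""

    def __init__(self, sampling_intervals: Dict[str, List[Tuple[float, float]]],
                 shuffle: bool = True, seed: int = 0):
        # Flatten {recording_id: [(start, end), ...]} into a list of indices.
        self.indices = [
            SimpleIndex(rid, start, end)
            for rid, intervals in sampling_intervals.items()
            for start, end in intervals
        ]
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[SimpleIndex]:
        order = list(self.indices)
        if self.shuffle:
            self.rng.shuffle(order)  # a new order each epoch
        return iter(order)

    def __len__(self) -> int:
        return len(self.indices)

intervals = {"rec_a": [(0.0, 1.0), (1.0, 2.0)], "rec_b": [(0.5, 1.5)]}
toy_sampler = SimpleTrialSampler(intervals, shuffle=False)
for idx in toy_sampler:
    print(idx.recording_id, idx.start, idx.end)
```

Keeping the sampler separate from the dataset is what lets the same dataset be sampled per trial, with sliding windows, or with any other scheme, without touching `__getitem__`.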
The key idea: unlike a standard PyTorch Dataset indexed by integers, a
torch_brain Dataset is indexed by time-slices, and loads data lazily, so
only the slice you ask for is read from disk. A Sampler decides what to
load, the Dataset decides how, and everything downstream stays vanilla
PyTorch.
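This README does not show torch_brain's on-disk format or loader, so the sketch below uses a NumPy memory-mapped `.npy` file as a stand-in to illustrate the general technique behind reading only the requested slice from disk. `read_window` and `spike_times.npy` are hypothetical names; the real storage backend may work differently.

```python
import os
import tempfile
import numpy as np

def read_window(path, start, end):
    """Read only the spike timestamps in [start, end) from a sorted .npy file."""
    ts = np.load(path, mmap_mode="r")             # memory-mapped: no bulk read yet
    lo = np.searchsorted(ts, start, side="left")  # binary search touches few pages
    hi = np.searchsorted(ts, end, side="left")
    return np.array(ts[lo:hi])                    # copy just the requested range

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "spike_times.npy")
    np.save(path, np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3]))
    window = read_window(path, 1.0, 3.0)
    print(window)  # [2.1 2.2 2.3]
```

The operating system reads whole pages, so slightly more than the exact slice may be pulled in, but the cost scales with the window size rather than the recording length.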
See examples/ for simple and readable training implementations.
Contributing
Contributions are welcome! Get started with:
pip install -e ".[dev]" # editable install with dev dependencies
pre-commit install # formatting & lint hooks
pytest # run the test suite
See CONTRIBUTING.md for the full workflow and code-style guidelines.
Building the documentation
pip install -e ".[dev,docs]"
cd docs && make clean html
The built docs are placed in docs/build/html.
Download files
File details
Details for the file torch_brain-0.2.0a0.tar.gz.
File metadata
- Download URL: torch_brain-0.2.0a0.tar.gz
- Upload date:
- Size: 3.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b7a1140ffaed0887614242495faab6ec447dfce1573c50df6c0075619c2532ab |
| MD5 | 814722d6486c76487becc8994d2ee301 |
| BLAKE2b-256 | eccee3c6d6c67b9a09b8c9a98a73b48387df05834005128e21706e25f30ebd51 |
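To check a downloaded file against the SHA256 digest in the table, Python's standard `hashlib` is enough. This is a minimal sketch; `sha256_of_file` is a helper name introduced here, and the commented-out usage line assumes the sdist was saved under its original file name in the working directory. (`pip hash <file>` also prints a file's SHA256 digest.)

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

EXPECTED_SHA256 = "b7a1140ffaed0887614242495faab6ec447dfce1573c50df6c0075619c2532ab"

# Usage, assuming the sdist is in the current directory:
# assert sha256_of_file("torch_brain-0.2.0a0.tar.gz") == EXPECTED_SHA256
```

Reading in chunks keeps memory use constant regardless of file size.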
Provenance
The following attestation bundles were made for torch_brain-0.2.0a0.tar.gz:
Publisher: publish.yml on neuro-galaxy/torch_brain
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: torch_brain-0.2.0a0.tar.gz
- Subject digest: b7a1140ffaed0887614242495faab6ec447dfce1573c50df6c0075619c2532ab
- Sigstore transparency entry: 1871844851
- Sigstore integration time:
- Permalink: neuro-galaxy/torch_brain@0456ce008d0e04a01240befac4002fe4ccf03ecd
- Branch / Tag: refs/tags/v0.2.0a0
- Owner: https://github.com/neuro-galaxy
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@0456ce008d0e04a01240befac4002fe4ccf03ecd
- Trigger Event: push
File details
Details for the file torch_brain-0.2.0a0-py3-none-any.whl.
File metadata
- Download URL: torch_brain-0.2.0a0-py3-none-any.whl
- Upload date:
- Size: 234.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 321f4cec9026964f74b8727a9b480585c9b39460946b61c0cb68b075a55759a5 |
| MD5 | 8c83b19e617ce95c7bd4569bb84ce289 |
| BLAKE2b-256 | 448598f71009e34fe6a2897ab0900d4eb9930edc0dc2941bdb2011fd2b2b3940 |
Provenance
The following attestation bundles were made for torch_brain-0.2.0a0-py3-none-any.whl:
Publisher: publish.yml on neuro-galaxy/torch_brain
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: torch_brain-0.2.0a0-py3-none-any.whl
- Subject digest: 321f4cec9026964f74b8727a9b480585c9b39460946b61c0cb68b075a55759a5
- Sigstore transparency entry: 1871844938
- Sigstore integration time:
- Permalink: neuro-galaxy/torch_brain@0456ce008d0e04a01240befac4002fe4ccf03ecd
- Branch / Tag: refs/tags/v0.2.0a0
- Owner: https://github.com/neuro-galaxy
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@0456ce008d0e04a01240befac4002fe4ccf03ecd
- Trigger Event: push