
Trial-aware contrastive learning wrapper for CEBRA


TrialCEBRA


Trial-aware contrastive learning for CEBRA. Pass 3-D epoch-format data of shape (ntrial, ntime, nneuro), and trial boundaries are respected automatically.


Installation

Step 1: Install PyTorch from pytorch.org.

Step 2: Install TrialCEBRA:

pip install TrialCEBRA

Quick Start

import numpy as np
from trial_cebra import TrialCEBRA

X = np.random.randn(40, 50, 64).astype(np.float32)   # (ntrial, ntime, nneuro)
y = np.random.randn(40, 16).astype(np.float32)        # (ntrial, stim_dim)

model = TrialCEBRA(
    model_architecture = "offset10-model",
    conditional        = "delta",
    time_offsets       = 5,
    delta              = 0.3,
    output_dimension   = 3,
    max_iterations     = 1000,
    batch_size         = 512,
)

model.fit(X, y)
emb = model.transform(X)        # (ntrial, ntime, 3); shape preserved

Flat 2-D input of shape (N, nneuro) falls back to native CEBRA behavior, unchanged.
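For reference, the correspondence between the epoch layout and CEBRA's flat (N, nneuro) layout can be sketched with a hypothetical epochs_to_flat helper (not part of TrialCEBRA) that also records the trial and time index of every flat row:

```python
import numpy as np

def epochs_to_flat(X: np.ndarray):
    """Flatten (ntrial, ntime, nneuro) epochs to (ntrial*ntime, nneuro).

    Hypothetical helper, not TrialCEBRA API. Also returns the trial and
    time index of every flat row so trial boundaries stay known.
    """
    if X.ndim != 3:
        raise ValueError(f"expected 3-D epochs, got shape {X.shape}")
    ntrial, ntime, nneuro = X.shape
    X_flat = X.reshape(ntrial * ntime, nneuro)       # C order: trial-major rows
    trial_id = np.repeat(np.arange(ntrial), ntime)   # 0,0,...,0,1,1,...
    time_id = np.tile(np.arange(ntime), ntrial)      # 0,1,...,ntime-1,0,1,...
    return X_flat, trial_id, time_id

X = np.random.randn(40, 50, 64).astype(np.float32)
X_flat, trial_id, time_id = epochs_to_flat(X)
print(X_flat.shape)  # (2000, 64)
# Row i of X_flat is X[trial_id[i], time_id[i]]; e.g. row 123 is trial 2, time 23
assert np.array_equal(X_flat[123], X[trial_id[123], time_id[123]])
```

Native CEBRA given only X_flat would treat row 49 (last sample of trial 0) and row 50 (first sample of trial 1) as neighbouring time points; the trial_id array is what makes that boundary explicit.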


Conditionals

| conditional | Trial selection | Within-trial sampling | y shape |
|---|---|---|---|
| "time" | Uniform random | ±time_offsets window | not required |
| "delta" | Gaussian similarity on y | Uniform (free) | (ntrial, nd) or (ntrial, ntime, nd) |
| "time_delta" | Joint argmin across trials | ±time_offsets window | (ntrial, ntime, nd) |

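For the "time" and "time_delta" rows, a "±time_offsets window" can be read as drawing the positive time point uniformly near the reference time, inside the same trial. A minimal sketch of that reading with a hypothetical sample_time_positive function; clipping the window at trial edges and excluding the reference time itself are assumptions, and TrialCEBRA's sampler may handle edges differently:

```python
import numpy as np

def sample_time_positive(t_ref, ntime, time_offsets, rng):
    """Draw one positive time index per reference, within the same trial.

    Hypothetical sketch. Candidates are t_ref - time_offsets .. t_ref + time_offsets,
    clipped to [0, ntime - 1] (assumed) and excluding t_ref itself (assumed),
    each chosen with equal probability.
    """
    if time_offsets < 1 or ntime < 2:
        raise ValueError("need time_offsets >= 1 and ntime >= 2")
    t_ref = np.asarray(t_ref, dtype=np.int64)
    lo = np.maximum(t_ref - time_offsets, 0)
    hi = np.minimum(t_ref + time_offsets, ntime - 1)
    n_cand = hi - lo  # window size minus the reference itself
    draw = lo + rng.integers(0, n_cand, size=t_ref.shape)
    # Shift draws at or past t_ref up by one so the reference is skipped
    return np.where(draw >= t_ref, draw + 1, draw)

rng = np.random.default_rng(0)
t_ref = rng.integers(0, 50, size=8)   # reference time points in trials of length 50
t_pos = sample_time_positive(t_ref, ntime=50, time_offsets=5, rng=rng)
```

Because the window is clipped rather than wrapped, a positive never crosses into the neighbouring trial, which is the property that flat CEBRA input cannot guarantee.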
Pass a discrete integer label (e.g. y_disc of dtype int64) alongside a continuous label to enable class-conditional trial selection for "delta".
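The table describes "delta" trial selection only as Gaussian similarity on y. A minimal sketch, assuming a CEBRA-style delta rule (perturb the reference trial's label with Gaussian noise of standard deviation delta, then pick the trial whose label is nearest) and per-trial labels of shape (ntrial, nd); select_positive_trials and its arguments are hypothetical, and the package's actual sampler may differ:

```python
import numpy as np

def select_positive_trials(y_trial, ref_trials, delta, rng,
                           exclude_self=True, classes=None):
    """Pick one positive trial per reference trial (hypothetical sketch).

    Assumed rule: add N(0, delta**2) noise to the reference trial's label and
    take the trial whose label is nearest to that noisy query.
    y_trial:      (ntrial, nd) continuous per-trial labels.
    exclude_self: mimic sample_exclude_intrial=True (never return the reference).
    classes:      optional (ntrial,) integer labels; candidates are restricted
                  to trials of the same class (class-conditional selection).
    """
    y_trial = np.asarray(y_trial, dtype=np.float64)
    ref_trials = np.asarray(ref_trials)
    n = len(ref_trials)
    query = y_trial[ref_trials] + rng.normal(0.0, delta, size=(n, y_trial.shape[1]))
    # Squared Euclidean distance from each noisy query to every trial label
    dist = ((query[:, None, :] - y_trial[None, :, :]) ** 2).sum(axis=-1)
    if exclude_self:
        dist[np.arange(n), ref_trials] = np.inf
    if classes is not None:
        classes = np.asarray(classes)
        same_class = classes[ref_trials][:, None] == classes[None, :]
        dist = np.where(same_class, dist, np.inf)
    if np.isinf(dist).all(axis=1).any():
        raise ValueError("a reference trial has no valid candidate trial")
    return dist.argmin(axis=1)

rng = np.random.default_rng(0)
y = rng.normal(size=(40, 16))            # per-trial continuous labels
classes = np.arange(40) % 2              # hypothetical discrete labels
pos = select_positive_trials(y, np.arange(40), delta=0.3, rng=rng, classes=classes)
```

A larger delta widens the noise around the reference label, so less similar trials are picked as positives more often.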


Key Parameters

TrialCEBRA(
    conditional            = "delta",    # "time" | "delta" | "time_delta"
    time_offsets           = 10,         # half-width of within-trial time window
    delta                  = 0.1,        # Gaussian noise std for trial similarity
    sample_fix_trial       = False,      # True: fix trial pairing at init
    sample_exclude_intrial = True,       # True: positives always from a different trial
    sample_prior           = "balanced", # "balanced" (default) or "uniform"
    output_dimension       = 3,
    # ... all other cebra.CEBRA kwargs accepted
)

After fit, the trial-aware sampling distribution used during training is available as model.distribution_.
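The sample_fix_trial flag is described only as fixing the trial pairing at init. One plausible reading is that partner trials are drawn once and reused, instead of being redrawn for every batch. The sketch below illustrates that reading with a hypothetical TrialPairer class and uniform partner selection (as in the "time" conditional); it is an assumption, not TrialCEBRA's implementation:

```python
import numpy as np

class TrialPairer:
    """Hypothetical illustration of sample_fix_trial (not TrialCEBRA code).

    fixed=True:  one partner trial per trial is drawn at init and reused.
    fixed=False: fresh partners are drawn on every call.
    Partners are uniform over the *other* trials, so a trial is never its
    own partner (mirroring sample_exclude_intrial=True).
    """

    def __init__(self, ntrial, fixed, seed=0):
        if ntrial < 2:
            raise ValueError("need at least two trials")
        self.ntrial = ntrial
        self.fixed = fixed
        self.rng = np.random.default_rng(seed)
        self._pairs = self._draw() if fixed else None

    def _draw(self):
        # Shift each trial index by 1..ntrial-1 (mod ntrial): uniform over other trials
        shift = self.rng.integers(1, self.ntrial, size=self.ntrial)
        return (np.arange(self.ntrial) + shift) % self.ntrial

    def partners(self):
        return self._pairs if self.fixed else self._draw()

fixed = TrialPairer(ntrial=8, fixed=True)
assert np.array_equal(fixed.partners(), fixed.partners())   # same pairing every call
free = TrialPairer(ntrial=8, fixed=False)
first, second = free.partners(), free.partners()            # usually differ
```

Under this reading, a fixed pairing gives the model the same positive pairs throughout training, while redrawing exposes it to more trial combinations.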


Transform

transform() preserves input dimensionality:

emb = model.transform(X)          # (ntrial, ntime, 3) if X is (ntrial, ntime, nneuro)
emb = model.transform(X_flat)     # (N, 3)             if X_flat is (N, nneuro)
emb = model.transform_epochs(X)   # strict 3-D variant; raises if X.ndim != 3
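The shape bookkeeping behind transform() can be illustrated with a hypothetical apply_preserving_shape helper that flattens 3-D input, applies a 2-D map, and restores the (ntrial, ntime, d) layout. The random linear projection stands in for a trained encoder; a real offset model also needs time context at trial edges, which this sketch ignores:

```python
import numpy as np

def apply_preserving_shape(fn, X):
    """Apply a 2-D map (N, nneuro) -> (N, d) and keep the caller's layout.

    Hypothetical helper: 2-D input passes straight through; 3-D epochs are
    flattened trial-major, mapped, and reshaped to (ntrial, ntime, d).
    """
    X = np.asarray(X)
    if X.ndim == 2:
        return fn(X)
    if X.ndim != 3:
        raise ValueError(f"expected 2-D or 3-D input, got {X.ndim}-D")
    ntrial, ntime, nneuro = X.shape
    out = fn(X.reshape(ntrial * ntime, nneuro))
    return out.reshape(ntrial, ntime, out.shape[-1])

rng = np.random.default_rng(0)
W = rng.normal(size=(64, 3)).astype(np.float32)   # stand-in linear "encoder"

def encode(A):
    return A @ W

X = rng.normal(size=(40, 50, 64)).astype(np.float32)
print(apply_preserving_shape(encode, X).shape)                  # (40, 50, 3)
print(apply_preserving_shape(encode, X.reshape(-1, 64)).shape)  # (2000, 3)
```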

Metrics

All metric methods accept epoch-format (ntrial, ntime, nneuro) data directly:

loss = model.infonce_loss(X, y)
gof  = model.goodness_of_fit_score(X, y)
hist = model.goodness_of_fit_history()       # training curve, no X needed

# Consistency score: accepts 3-D embedding lists
emb1 = model.transform(X1)   # (ntrial, ntime, 3)
emb2 = model.transform(X2)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2], between="runs"
)
# between datasets: pass one label array per embedding, shaped (ntrial, ntime) or (ntrial*ntime,)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2],
    between="datasets",
    labels=[y1, y2],
    dataset_ids=["mouse1", "mouse2"],
)
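CEBRA reports goodness of fit in bits relative to chance level. Assuming TrialCEBRA's goodness_of_fit_score reuses CEBRA's single-session definition, with batch size B and InfoNCE loss L measured in nats:

```latex
\mathrm{GoF} = \frac{\ln B - \mathcal{L}_{\mathrm{InfoNCE}}}{\ln 2}
             = \log_2 B - \frac{\mathcal{L}_{\mathrm{InfoNCE}}}{\ln 2} \quad \text{bits}
```

With batch_size = 512 as in the Quick Start, log2(512) = 9, so the score is 9 - L / ln 2 bits, and a loss at chance level (L = ln 512) gives 0 bits.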

Decoders

CEBRA decoders (KNNDecoder, L1LinearRegressor) are standalone sklearn-compatible estimators that expect 2-D input, so flatten the embedding and the labels first:

import cebra

emb      = model.transform(X)                  # (ntrial, ntime, 3)
emb_flat = emb.reshape(-1, emb.shape[-1])       # (ntrial*ntime, 3)
y_flat   = y.reshape(-1)                        # (ntrial*ntime,); assumes per-timepoint y of shape (ntrial, ntime)

decoder = cebra.KNNDecoder()
decoder.fit(emb_flat, y_flat)
score = decoder.score(emb_flat, y_flat)
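Two points the snippet above leaves implicit: labels must be expanded to one row per time point in the same trial-major order as emb.reshape(-1, d), and scoring on the samples used for fitting gives an optimistic estimate. A minimal sketch with a hypothetical flatten_for_decoding helper (not part of TrialCEBRA or CEBRA) that holds out whole trials and also handles trial-level labels such as the (ntrial, 16) y from the Quick Start:

```python
import numpy as np

def flatten_for_decoding(emb, y, trials, trial_level):
    """Flatten selected trials of an epoch embedding plus matching labels.

    Hypothetical helper. emb: (ntrial, ntime, d); trials: integer trial indices.
    trial_level=True:  y is (ntrial, nd); each row is repeated ntime times.
    trial_level=False: y is (ntrial, ntime) or (ntrial, ntime, nd).
    Rows come out in trial-major order, matching emb.reshape(-1, d).
    """
    trials = np.asarray(trials)
    ntime, d = emb.shape[1], emb.shape[-1]
    emb_flat = emb[trials].reshape(-1, d)            # (len(trials)*ntime, d)
    y_sel = np.asarray(y)[trials]
    if trial_level:
        y_flat = np.repeat(y_sel, ntime, axis=0)     # one copy per time step
    else:
        y_flat = y_sel.reshape(len(trials) * ntime, *y_sel.shape[2:])
    return emb_flat, y_flat

# Hold out whole trials so test time points never share a trial with training ones
rng = np.random.default_rng(0)
emb = rng.normal(size=(40, 50, 3))    # stand-in for model.transform(X)
y = rng.normal(size=(40, 16))         # trial-level labels, as in the Quick Start
perm = rng.permutation(emb.shape[0])
train_trials, test_trials = perm[:32], perm[32:]
emb_tr, y_tr = flatten_for_decoding(emb, y, train_trials, trial_level=True)
emb_te, y_te = flatten_for_decoding(emb, y, test_trials, trial_level=True)
print(emb_tr.shape, y_tr.shape)       # (1600, 3) (1600, 16)
# Then, assuming the decoder accepts these labels:
# decoder.fit(emb_tr, y_tr); decoder.score(emb_te, y_te)
```

Splitting by trial rather than by time point matters because neighbouring time points in one trial are strongly correlated; a random per-sample split would let the decoder see near-duplicates of its test samples.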

Multi-session

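Pass X as a list of epoch arrays, one per session, and each label as a list of matching per-session arrays. Sessions may differ in ntrial, ntime, and nneuro: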

X = [
    np.random.randn(30, 100, 64).astype(np.float32),   # session 0
    np.random.randn(25,  80, 48).astype(np.float32),   # session 1
]
y_cont = [np.random.randn(30, 100, 16).astype(np.float32),
          np.random.randn(25,  80, 16).astype(np.float32)]
y_disc = [np.zeros((30, 100), dtype=np.int64),
          np.zeros((25,  80), dtype=np.int64)]

model = TrialCEBRA(conditional="delta", output_dimension=3, max_iterations=1000)
model.fit(X, y_disc, y_cont)

"delta" and "time_delta" are supported for multi-session; "time" raises NotImplementedError.

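Before calling fit on multi-session data, it can help to check that every label list lines up with X. A sketch with a hypothetical check_sessions helper, assuming per-timepoint labels of shape (ntrial_i, ntime_i) or (ntrial_i, ntime_i, nd) as in the example above; nneuro may differ between sessions, but each label's trailing feature shape is assumed to be the same in every session:

```python
import numpy as np

def check_sessions(X, *label_lists):
    """Validate multi-session epoch inputs (hypothetical helper).

    X:           list of arrays shaped (ntrial_i, ntime_i, nneuro_i).
    label_lists: lists with one per-timepoint label array per session,
                 shaped (ntrial_i, ntime_i) or (ntrial_i, ntime_i, nd).
    """
    for i, x in enumerate(X):
        if np.ndim(x) != 3:
            raise ValueError(f"session {i}: expected 3-D epochs, got shape {np.shape(x)}")
    for labels in label_lists:
        if len(labels) != len(X):
            raise ValueError(f"expected {len(X)} label arrays, got {len(labels)}")
        for i, (x, y) in enumerate(zip(X, labels)):
            if y.shape[:2] != x.shape[:2]:
                raise ValueError(
                    f"session {i}: labels {y.shape} do not match epochs {x.shape[:2]}")
        # Trailing feature shape, e.g. () or (16,), must agree across sessions
        feature_shapes = {y.shape[2:] for y in labels}
        if len(feature_shapes) > 1:
            raise ValueError(f"label feature shapes differ across sessions: {feature_shapes}")

X = [np.zeros((30, 100, 64), np.float32), np.zeros((25, 80, 48), np.float32)]
y_cont = [np.zeros((30, 100, 16), np.float32), np.zeros((25, 80, 16), np.float32)]
y_disc = [np.zeros((30, 100), np.int64), np.zeros((25, 80), np.int64)]
check_sessions(X, y_disc, y_cont)   # passes silently
```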

Contributing

uv sync --dev
uv run pre-commit install --hook-type pre-commit --hook-type pre-push
uv run pytest tests/ -v

To release, push a version tag: git tag vX.X.X && git push origin vX.X.X



Download files


Source Distribution

trialcebra-0.0.5.tar.gz (3.1 MB)

Uploaded Source

Built Distribution


trialcebra-0.0.5-py3-none-any.whl (33.7 kB)

Uploaded Python 3

File details

Details for the file trialcebra-0.0.5.tar.gz.

File metadata

  • Download URL: trialcebra-0.0.5.tar.gz
  • Upload date:
  • Size: 3.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8 {"installer":{"name":"uv","version":"0.11.8","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for trialcebra-0.0.5.tar.gz
Algorithm Hash digest
SHA256 3355e80853719a4db38c8948a9d17c3389752d0534d67bfeb1436a68366a9938
MD5 3bbd4dc15307f8db8c044800bd21cfa4
BLAKE2b-256 9bc9ff493e646cc6a7faac0eb928a9accd2e7662c7030c2e1c4c46b9c5278265


File details

Details for the file trialcebra-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: trialcebra-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 33.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8 {"installer":{"name":"uv","version":"0.11.8","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for trialcebra-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 39e34b5b46c4922cb188c0a12af9810b42dd5e1beb455b57498df1725ca6c030
MD5 1cc3afeed1a1381a4696bb5ff5d8ba08
BLAKE2b-256 94cb6ac5aa2dc1d79bf85e18858002747084c260ec40748a6b2c506f268a948a

