
Trial-aware contrastive learning wrapper for CEBRA


TrialCEBRA


Trial-aware contrastive learning for CEBRA. Pass 3-D epoch-format data of shape (ntrial, ntime, nneuro); trial boundaries are respected automatically.


Installation

Step 1: Install PyTorch from pytorch.org.

Step 2: Install TrialCEBRA from PyPI:

pip install TrialCEBRA

Quick Start

import numpy as np
from trial_cebra import TrialCEBRA

X = np.random.randn(40, 50, 64).astype(np.float32)   # (ntrial, ntime, nneuro)
y = np.random.randn(40, 16).astype(np.float32)        # (ntrial, stim_dim)

model = TrialCEBRA(
    model_architecture = "offset10-model",
    conditional        = "delta",
    time_offsets       = 5,
    delta              = 0.3,
    output_dimension   = 3,
    max_iterations     = 1000,
    batch_size         = 512,
)

model.fit(X, y)
emb = model.transform(X)        # (ntrial, ntime, 3)  — shape preserved

2-D flat input of shape (N, nneuro) falls back to unchanged native CEBRA behavior.
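To see why trial boundaries matter, the sketch below flattens epoch data the way a 2-D input would be laid out and measures how often a ±time_offsets neighbour in the flat array belongs to a different trial. flatten_epochs and cross_trial_fraction are hypothetical helpers written for illustration; they are not part of TrialCEBRA or CEBRA.

```python
import numpy as np


def flatten_epochs(X):
    """Flatten (ntrial, ntime, nneuro) to (ntrial*ntime, nneuro) and return per-row trial ids."""
    ntrial, ntime, nneuro = X.shape
    X_flat = X.reshape(ntrial * ntime, nneuro)       # trial-major row order
    trial_id = np.repeat(np.arange(ntrial), ntime)   # trial index of every flat row
    return X_flat, trial_id


def cross_trial_fraction(trial_id, time_offsets):
    """Fraction of flat neighbour pairs (i, i±k), 1 <= k <= time_offsets, that span two trials."""
    n = len(trial_id)
    crossing = total = 0
    for k in range(1, time_offsets + 1):
        differs = trial_id[:-k] != trial_id[k:]      # pair (i, i+k) crosses a boundary
        crossing += 2 * int(differs.sum())            # count both the +k and -k directions
        total += 2 * (n - k)
    return crossing / total


X = np.random.randn(40, 50, 64).astype(np.float32)
X_flat, trial_id = flatten_epochs(X)
frac = cross_trial_fraction(trial_id, time_offsets=5)
print(X_flat.shape)     # (2000, 64)
print(round(frac, 3))   # 0.059
```

With 40 trials of 50 samples, roughly 6% of the flat ±5 neighbours straddle two trials; an epoch-format input keeps that boundary information available to the sampler.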


Conditionals

| conditional | Trial selection | Within-trial time | y shape |
|---|---|---|---|
| "time" | Uniform random | ±time_offsets window | not required |
| "delta" | Gaussian similarity on y | Uniform (free) | (ntrial, nd) or (ntrial, ntime, nd) |
| "time_delta" | Joint argmin across trials | ±time_offsets window | (ntrial, ntime, nd) |

Pass a discrete integer label (e.g. y_disc of dtype int64) alongside a continuous label, as in model.fit(X, y_disc, y_cont), to enable class-conditional trial selection for "delta".
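One way to read the conditionals table and the class-conditional option is as a two-step recipe: choose a positive trial, then a time inside it. The sketch below illustrates that reading for trial-level labels of shape (ntrial, nd). It is an assumption, not TrialCEBRA's sampler: the "delta" step borrows the idea behind CEBRA's delta conditional (add Gaussian noise of scale delta to the anchor's label, then take the nearest sample), "time_delta" is not sketched, sample_exclude_intrial is ignored, and all function names are hypothetical.

```python
import numpy as np


def window_time(t, time_offsets, ntime, rng):
    """Uniform time in [t - time_offsets, t + time_offsets], clipped to the trial."""
    lo, hi = max(0, t - time_offsets), min(ntime - 1, t + time_offsets)
    return int(rng.integers(lo, hi + 1))


def positive_time(anchor_t, ntrial, ntime, time_offsets, rng):
    """'time': any trial uniformly at random, time near the anchor's time."""
    return int(rng.integers(ntrial)), window_time(anchor_t, time_offsets, ntime, rng)


def positive_delta(anchor_trial, y_cont, ntime, delta, rng, y_disc=None):
    """'delta': trial whose label is nearest to a noisy copy of the anchor's label.

    If y_disc (ntrial,) is given, only trials sharing the anchor's class are candidates.
    The within-trial time is free (uniform over the whole trial).
    """
    candidates = np.arange(len(y_cont))
    if y_disc is not None:
        candidates = np.flatnonzero(y_disc == y_disc[anchor_trial])
    query = y_cont[anchor_trial] + delta * rng.standard_normal(y_cont.shape[1])
    dist = np.linalg.norm(y_cont[candidates] - query, axis=1)
    return int(candidates[np.argmin(dist)]), int(rng.integers(ntime))


rng = np.random.default_rng(0)
y_cont = rng.standard_normal((40, 16))
y_disc = np.arange(40) % 2                     # two hypothetical classes

trial, t = positive_time(anchor_t=10, ntrial=40, ntime=50, time_offsets=5, rng=rng)
same, _ = positive_delta(3, y_cont, ntime=50, delta=0.0, rng=rng)  # delta=0: nearest is trial 3
picks = [positive_delta(3, y_cont, 50, 0.3, rng, y_disc)[0] for _ in range(100)]
```

With the discrete label supplied, every sampled partner of trial 3 comes from the same class, while the continuous label still decides which trial within that class is chosen.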


Key Parameters

from trial_cebra import TrialCEBRA

model = TrialCEBRA(
    conditional            = "delta",    # "time" | "delta" | "time_delta"
    time_offsets           = 10,         # half-width of within-trial time window
    delta                  = 0.1,        # Gaussian noise std for trial similarity
    sample_fix_trial       = False,      # True: fix trial pairing at init
    sample_exclude_intrial = True,       # True: positives always from a different trial
    sample_prior           = "balanced", # "balanced" (default) or "uniform"
    output_dimension       = 3,
    # ... all other cebra.CEBRA kwargs accepted
)
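The two boolean sampling flags can be illustrated for a uniform trial choice. In this hypothetical sketch, sample_exclude_intrial=True means the partner is drawn from the other ntrial - 1 trials, and sample_fix_trial=True means one partner per trial is drawn once and reused instead of being redrawn every iteration. The helper names are illustrative only, and the real sampler may differ, for example when partners are chosen by label similarity.

```python
import numpy as np


def draw_other_trial(anchor_trial, ntrial, rng):
    """Uniform draw over every trial except anchor_trial."""
    t = int(rng.integers(ntrial - 1))          # 0 .. ntrial-2
    return t + 1 if t >= anchor_trial else t   # skip over the anchor's own index


def partner_trials(ntrial, rng, exclude_intrial=True):
    """Draw one partner trial for each trial."""
    if exclude_intrial:
        return np.array([draw_other_trial(i, ntrial, rng) for i in range(ntrial)])
    return rng.integers(ntrial, size=ntrial)


rng = np.random.default_rng(0)
fixed = partner_trials(40, rng)                        # sample_fix_trial=True: drawn once at init
fresh = [partner_trials(40, rng) for _ in range(3)]    # sample_fix_trial=False: redrawn per iteration
```

Shifting the draw past the anchor index keeps the choice uniform over the remaining trials without rejection sampling.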

After fit(), the sampling distribution is accessible as model.distribution_.


Transform

transform() preserves input dimensionality:

emb = model.transform(X)          # (ntrial, ntime, 3) if X is (ntrial, ntime, nneuro)
emb = model.transform(X_flat)     # (N, 3)             if X_flat is (N, nneuro)
emb = model.transform_epochs(X)   # strict 3-D variant — raises if X.ndim != 3
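Shape preservation amounts to flattening epochs before encoding and restoring (ntrial, ntime) afterwards. The sketch below uses a random linear map as a stand-in encoder, an assumption made so it runs without a trained model. An offset model such as "offset10-model" also looks at neighbouring samples, which this per-row stand-in ignores, and how TrialCEBRA handles trial edges is not shown here.

```python
import numpy as np


def transform_preserving_shape(encode, X):
    """Apply a row-wise encoder (N, nneuro) -> (N, d) to 2-D or 3-D input."""
    if X.ndim == 2:
        return encode(X)                           # flat input: returned as (N, d)
    if X.ndim != 3:
        raise ValueError(f"expected 2-D or 3-D input, got {X.ndim}-D")
    ntrial, ntime, nneuro = X.shape
    emb_flat = encode(X.reshape(ntrial * ntime, nneuro))
    return emb_flat.reshape(ntrial, ntime, emb_flat.shape[-1])


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 3))                   # stand-in "encoder" weights


def encode(A):
    return A @ W


X = rng.standard_normal((40, 50, 64))
emb = transform_preserving_shape(encode, X)                        # (40, 50, 3)
emb_flat = transform_preserving_shape(encode, X.reshape(-1, 64))   # (2000, 3)
```

Because the reshape is trial-major in both directions, emb[i, t] is exactly the embedding of X[i, t].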

Metrics

All metric methods accept epoch-format (ntrial, ntime, nneuro) data directly:

loss = model.infonce_loss(X, y)
gof  = model.goodness_of_fit_score(X, y)
hist = model.goodness_of_fit_history()       # training curve, no X needed

# Consistency score: accepts 3-D embedding lists
emb1 = model.transform(X1)   # (ntrial, ntime, 3)
emb2 = model.transform(X2)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2], between="runs"
)
# between="datasets": each labels array may be (ntrial, ntime) or (ntrial*ntime,)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2],
    between="datasets",
    labels=[y1, y2],
    dataset_ids=["mouse1", "mouse2"],
)
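For between="runs", CEBRA's consistency score is, as far as we understand it, the R² of a linear regression that maps one embedding onto another. The sketch below mimics that idea on flattened 3-D embeddings with NumPy least squares. It pools variance across output dimensions, which may differ from how CEBRA averages, and the function names are hypothetical.

```python
import numpy as np


def linear_r2(source, target):
    """R^2 of the least-squares affine fit source -> target, pooled over output dims."""
    A = np.hstack([source, np.ones((len(source), 1))])     # intercept column
    coef, *_ = np.linalg.lstsq(A, target, rcond=None)
    ss_res = ((target - A @ coef) ** 2).sum()
    ss_tot = ((target - target.mean(axis=0)) ** 2).sum()
    return 1.0 - ss_res / ss_tot


def runs_consistency(emb1, emb2):
    """Flatten (ntrial, ntime, d) embeddings and score both directions."""
    e1 = emb1.reshape(-1, emb1.shape[-1])
    e2 = emb2.reshape(-1, emb2.shape[-1])
    return linear_r2(e1, e2), linear_r2(e2, e1)


rng = np.random.default_rng(0)
emb1 = rng.standard_normal((40, 50, 3))
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))          # random orthogonal transform
emb2 = emb1 @ Q + 0.5                                      # same geometry, different axes
noise = rng.standard_normal(emb1.shape)                    # unrelated "run"

same_geom = runs_consistency(emb1, emb2)    # both close to 1.0
unrelated = runs_consistency(emb1, noise)   # both close to 0.0
```

A linear map is used because two runs of the same model can legitimately differ by a rotation or reflection of the embedding space; a high R² means they agree up to such a transform.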

Decoders

CEBRA decoders (KNNDecoder, L1LinearRegressor) are standalone sklearn estimators that expect 2-D input. Flatten the embedding first:

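import numpy as np
import cebra

emb      = model.transform(X)                  # (ntrial, ntime, 3)
emb_flat = emb.reshape(-1, emb.shape[-1])       # (ntrial*ntime, 3), trial-major order

# Per-timepoint labels y of shape (ntrial, ntime): flatten in the same order
y_flat   = y.reshape(-1)                        # (ntrial*ntime,)
# Trial-level labels y of shape (ntrial, nd), as in Quick Start, must be repeated instead:
# y_flat = np.repeat(y, emb.shape[1], axis=0)   # (ntrial*ntime, nd)

decoder = cebra.KNNDecoder()
decoder.fit(emb_flat, y_flat)
score = decoder.score(emb_flat, y_flat)         # training-set score; hold out trials for a fair estimate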
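Scoring on the rows used for fitting tends to overstate decoding because samples from the same trial are correlated. A common remedy is to hold out whole trials before flattening. split_by_trial below is a hypothetical helper, not part of CEBRA or TrialCEBRA; per_timepoint says whether y is (ntrial, ntime[, nd]) or trial-level (ntrial[, nd]).

```python
import numpy as np


def split_by_trial(emb, y, test_fraction, rng, per_timepoint):
    """Hold out whole trials, then flatten each split to 2-D decoder input."""
    ntrial, ntime = emb.shape[:2]
    order = rng.permutation(ntrial)
    n_test = max(1, int(round(test_fraction * ntrial)))
    test_idx, train_idx = order[:n_test], order[n_test:]

    def flatten(idx):
        e = emb[idx].reshape(-1, emb.shape[-1])                    # (len(idx)*ntime, d)
        if per_timepoint:
            lab = y[idx].reshape(len(idx) * ntime, *y.shape[2:])   # same trial-major order
        else:
            lab = np.repeat(y[idx], ntime, axis=0)                 # copy trial label to each time
        return e, lab

    return flatten(train_idx), flatten(test_idx)


rng = np.random.default_rng(0)
emb = rng.standard_normal((40, 50, 3))      # stand-in for model.transform(X)
y = rng.standard_normal((40, 16))           # trial-level labels, as in Quick Start
(e_tr, y_tr), (e_te, y_te) = split_by_trial(emb, y, 0.25, rng, per_timepoint=False)
```

The resulting arrays can then be passed to decoder.fit(e_tr, y_tr) and decoder.score(e_te, y_te).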

Multi-session

Pass X as a list of epoch arrays (one per session):

X = [
    np.random.randn(30, 100, 64).astype(np.float32),   # session 0
    np.random.randn(25,  80, 48).astype(np.float32),   # session 1
]
y_cont = [np.random.randn(30, 100, 16).astype(np.float32),
          np.random.randn(25,  80, 16).astype(np.float32)]
y_disc = [np.zeros((30, 100), dtype=np.int64),
          np.zeros((25,  80), dtype=np.int64)]

model = TrialCEBRA(conditional="delta", output_dimension=3, max_iterations=1000)
model.fit(X, y_disc, y_cont)
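Before calling fit on a list of sessions, a quick shape check can catch misaligned labels. The helper below is a hypothetical pre-flight check, not a TrialCEBRA function. It assumes per-timepoint labels whose leading (ntrial, ntime) dimensions match each session's data, as in the example above, and it returns the per-session neuron counts, which may differ.

```python
import numpy as np


def check_multisession(X, *labels):
    """Raise if any per-session label array does not line up with its session's X."""
    for y in labels:
        if len(y) != len(X):
            raise ValueError(f"expected {len(X)} label arrays, got {len(y)}")
        for s, (x_s, y_s) in enumerate(zip(X, y)):
            if x_s.ndim != 3:
                raise ValueError(f"session {s}: X must be (ntrial, ntime, nneuro)")
            if y_s.shape[:2] != x_s.shape[:2]:
                raise ValueError(f"session {s}: labels {y_s.shape} vs data {x_s.shape}")
    return [x_s.shape[2] for x_s in X]      # neurons per session


X = [np.zeros((30, 100, 64), np.float32), np.zeros((25, 80, 48), np.float32)]
y_disc = [np.zeros((30, 100), np.int64), np.zeros((25, 80), np.int64)]
y_cont = [np.zeros((30, 100, 16), np.float32), np.zeros((25, 80, 16), np.float32)]
print(check_multisession(X, y_disc, y_cont))   # [64, 48]

try:
    check_multisession(X, [y_disc[0], np.zeros((25, 90), np.int64)])
except ValueError as err:
    print(err)   # session 1: labels (25, 90) vs data (25, 80, 48)
```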

"delta" and "time_delta" are supported for multi-session; "time" raises NotImplementedError.


Contributing

uv sync --dev
uv run pre-commit install --hook-type pre-commit --hook-type pre-push
uv run pytest tests/ -v

Release: git tag vX.X.X && git push origin vX.X.X.



Download files


Source Distribution

trialcebra-0.0.5a0.tar.gz (3.1 MB, source)

Built Distribution


trialcebra-0.0.5a0-py3-none-any.whl (33.8 kB, Python 3)

File details

Details for the file trialcebra-0.0.5a0.tar.gz.

File metadata

  • Download URL: trialcebra-0.0.5a0.tar.gz
  • Size: 3.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8 (publish, Ubuntu 24.04, CI)

File hashes

Hashes for trialcebra-0.0.5a0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 90758af319580d3ddf5033ac8447bfa541b9213d3dedba6726aba62d83e427bb |
| MD5 | 53d04f8287624956eba50029bb2b7eb0 |
| BLAKE2b-256 | 519d17d3e19bc5c3f0a5a8bbd2f48017d8c9649a8001a4607db284699a71c0a0 |


File details

Details for the file trialcebra-0.0.5a0-py3-none-any.whl.

File metadata

  • Download URL: trialcebra-0.0.5a0-py3-none-any.whl
  • Size: 33.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8 (publish, Ubuntu 24.04, CI)

File hashes

Hashes for trialcebra-0.0.5a0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8564efd5bca379de7c7811ab7993d60ce6e0b7332184006ef1f1d70325b8fc84 |
| MD5 | df4622f1a973262a9ac76fa827b85171 |
| BLAKE2b-256 | dadda226e5bd89fef83c936d8b2be2463277d076419a51e715cb26e7f1e73097 |

