
Trial-aware contrastive learning wrapper for CEBRA


TrialCEBRA


Trial-aware contrastive learning for CEBRA. Pass 3-D epoch-format data of shape (ntrial, ntime, nneuro) and trial boundaries are respected automatically.
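One way to picture the epoch format: a 3-D array is a stack of trials, and flattening it in C order gives trial-major rows whose trial and time indices can be recovered exactly. The sketch below is illustrative only. `flatten_epochs` is a hypothetical helper, not part of the TrialCEBRA API, and the package may track trial boundaries differently internally.

```python
import numpy as np


def flatten_epochs(X):
    """Flatten (ntrial, ntime, nneuro) epochs to (ntrial*ntime, nneuro) rows and
    return per-row trial and time indices, so trial boundaries stay recoverable."""
    X = np.asarray(X)
    if X.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {X.shape}")
    ntrial, ntime, nneuro = X.shape
    flat = X.reshape(ntrial * ntime, nneuro)        # C order: all of trial 0, then trial 1, ...
    trial_id = np.repeat(np.arange(ntrial), ntime)  # 0, 0, ..., 1, 1, ...
    time_id = np.tile(np.arange(ntime), ntrial)     # 0, 1, ..., ntime-1, 0, 1, ...
    return flat, trial_id, time_id


X = np.random.randn(40, 50, 64).astype(np.float32)
flat, trial_id, time_id = flatten_epochs(X)
print(flat.shape)                  # (2000, 64)
print(trial_id[57], time_id[57])   # 1 7  (57 = 1 * 50 + 7)
```

Any sampler that knows `trial_id` can then keep a positive inside the reference trial or force it into a different one.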


Installation

Step 1: Install PyTorch by following the instructions at pytorch.org.

Step 2: Install TrialCEBRA from PyPI:

pip install TrialCEBRA

Quick Start

import numpy as np
from trial_cebra import TrialCEBRA

X = np.random.randn(40, 50, 64).astype(np.float32)   # (ntrial, ntime, nneuro)
y = np.random.randn(40, 16).astype(np.float32)        # (ntrial, stim_dim)

model = TrialCEBRA(
    model_architecture = "offset10-model",
    conditional        = "delta",
    time_offsets       = 5,
    delta              = 0.3,
    output_dimension   = 3,
    max_iterations     = 1000,
    batch_size         = 512,
)

model.fit(X, y)
emb = model.transform(X)        # (ntrial, ntime, 3)  — shape preserved

2-D flat input of shape (N, nneuro) falls back to unmodified native CEBRA behavior.


Conditionals

conditional    Trial selection             Within-trial sampling   y shape
"time"         Uniform random              ±time_offsets window    not required
"delta"        Gaussian similarity on y    Uniform (free)          (ntrial, nd) or (ntrial, ntime, nd)
"time_delta"   Joint argmin across trials  ±time_offsets window    (ntrial, ntime, nd)

Pass a discrete integer label (e.g. y_disc of dtype int64) alongside a continuous label to enable class-conditional trial selection for "delta".
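The "delta" trial selection can be sketched after CEBRA's own delta conditional, which perturbs the reference label with Gaussian noise of standard deviation `delta` and picks the nearest sample; here it is applied to one label vector per trial. Assumptions not confirmed by this README: distances are Euclidean, excluding the reference trial mirrors `sample_exclude_intrial=True`, and class-conditional selection means only trials sharing the reference's discrete label are candidates. `select_trials_delta` is a hypothetical name.

```python
import numpy as np


def select_trials_delta(y_trial, ref, delta, rng, y_disc=None, exclude_self=True):
    """For each reference trial, add N(0, delta**2) noise to its label and return
    the index of the candidate trial whose label is nearest to that noisy query."""
    y_trial = np.asarray(y_trial, dtype=np.float64)             # (ntrial, nd)
    ref = np.asarray(ref)
    query = y_trial[ref] + rng.normal(0.0, delta, size=(len(ref), y_trial.shape[1]))
    # Squared Euclidean distance from every query to every trial label: (nref, ntrial)
    dist = ((query[:, None, :] - y_trial[None, :, :]) ** 2).sum(axis=-1)
    if exclude_self:
        dist[np.arange(len(ref)), ref] = np.inf                 # never pick the reference trial
    if y_disc is not None:
        y_disc = np.asarray(y_disc)
        same_class = y_disc[None, :] == y_disc[ref][:, None]    # (nref, ntrial)
        dist[~same_class] = np.inf                              # only same-class candidates
    if np.isinf(dist).all(axis=1).any():
        raise ValueError("some reference trial has no valid candidate trial")
    return dist.argmin(axis=1)


rng = np.random.default_rng(0)
y = rng.standard_normal((40, 16))      # one 16-D label per trial
y_disc = np.repeat([0, 1], 20)         # one class label per trial
ref = np.arange(40)
pos = select_trials_delta(y, ref, delta=0.3, rng=rng, y_disc=y_disc)
```

A larger `delta` lets the noisy query wander further from the reference label, so positives come from less similar trials.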


Key Parameters

from trial_cebra import TrialCEBRA

model = TrialCEBRA(
    conditional            = "delta",    # "time" | "delta" | "time_delta"
    time_offsets           = 10,         # half-width of within-trial time window
    delta                  = 0.1,        # Gaussian noise std for trial similarity
    sample_fix_trial       = False,      # True: fix trial pairing at init
    sample_exclude_intrial = True,       # True: positives always from a different trial
    sample_prior           = "balanced", # "balanced" (default) or "uniform"
    output_dimension       = 3,
    # ... all other cebra.CEBRA kwargs accepted
)
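The `time_offsets` comment describes a half-width. Below is a minimal sketch of drawing a within-trial positive time from that window. It assumes offsets are uniform over [-time_offsets, +time_offsets] and clipped at the trial edges; the package may resample or drop edge samples instead. `sample_window_times` is a hypothetical name, not TrialCEBRA API.

```python
import numpy as np


def sample_window_times(t_ref, time_offsets, ntime, rng):
    """Draw one positive time index per reference from the window
    [t_ref - time_offsets, t_ref + time_offsets], clipped to [0, ntime - 1]."""
    t_ref = np.asarray(t_ref)
    # Generator.integers excludes the upper bound, hence the + 1
    offsets = rng.integers(-time_offsets, time_offsets + 1, size=t_ref.shape)
    return np.clip(t_ref + offsets, 0, ntime - 1)


rng = np.random.default_rng(0)
t_ref = rng.integers(0, 50, size=1000)   # reference times within a 50-sample trial
t_pos = sample_window_times(t_ref, time_offsets=5, ntime=50, rng=rng)
```

Clipping slightly over-weights the first and last time points of each trial; rejection sampling avoids that bias at the cost of extra draws.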

After fitting, the sampling distribution is available as model.distribution_.
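The `sample_exclude_intrial` and `sample_fix_trial` comments suggest two choices when a partner trial is drawn uniformly at random, as in the "time" conditional's trial selection. The sketch below shows one reading of them. It excludes the reference trial by drawing from the other `ntrial - 1` trials and shifting, and either fixes the pairing once or redraws it per batch. `draw_partner_trials` is a hypothetical helper, and the package's actual sampler may behave differently.

```python
import numpy as np


def draw_partner_trials(ref, ntrial, rng, exclude_self=True):
    """Draw one partner trial per reference trial, uniformly at random.
    With exclude_self=True the reference trial itself is never returned."""
    ref = np.asarray(ref)
    if not exclude_self:
        return rng.integers(0, ntrial, size=ref.shape)
    # Draw from the ntrial - 1 other trials, then shift draws at or above the
    # reference index up by one, which skips the reference without rejection.
    draw = rng.integers(0, ntrial - 1, size=ref.shape)
    return draw + (draw >= ref)


rng = np.random.default_rng(0)
ntrial = 40

# One reading of sample_fix_trial=True: pair every trial once, reuse per batch
fixed_partner = draw_partner_trials(np.arange(ntrial), ntrial, rng)
batch_ref = rng.integers(0, ntrial, size=8)
fixed_pos = fixed_partner[batch_ref]

# sample_fix_trial=False: draw fresh partners for every batch
fresh_pos = draw_partner_trials(batch_ref, ntrial, rng)
```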


Transform

transform() preserves input dimensionality:

emb = model.transform(X)          # (ntrial, ntime, 3) if X is (ntrial, ntime, nneuro)
emb = model.transform(X_flat)     # (N, 3)             if X_flat is (N, nneuro)
emb = model.transform_epochs(X)   # strict 3-D variant — raises if X.ndim != 3
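Shape preservation amounts to bookkeeping around the encoder. This is a minimal sketch that assumes 3-D input is embedded trial by trial, so a model's temporal context window never mixes neighbouring trials. A random linear projection stands in for the trained model. `transform_preserving_shape` and `embed_fn` are hypothetical names.

```python
import numpy as np


def transform_preserving_shape(X, embed_fn):
    """Embed X with embed_fn, which maps (T, nneuro) -> (T, d), keeping the layout.
    3-D input is embedded one trial at a time so a model's temporal context never
    reaches across a trial boundary; 2-D input is embedded as a single block."""
    X = np.asarray(X)
    if X.ndim == 2:
        return embed_fn(X)                                          # (N, d)
    if X.ndim == 3:
        return np.stack([embed_fn(trial) for trial in X], axis=0)   # (ntrial, ntime, d)
    raise ValueError(f"expected 2-D or 3-D input, got shape {X.shape}")


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 3)).astype(np.float32)


def embed_fn(a):
    # Stand-in for a trained encoder: a fixed random linear projection
    return a @ W


X = rng.standard_normal((40, 50, 64)).astype(np.float32)
emb = transform_preserving_shape(X, embed_fn)                       # (40, 50, 3)
emb_flat = transform_preserving_shape(X.reshape(-1, 64), embed_fn)  # (2000, 3)
```

With a purely pointwise stand-in like this one, both paths give the same numbers. With a convolutional encoder they would differ near trial edges, which is why per-trial embedding matters.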

Metrics

All metric methods accept epoch-format (ntrial, ntime, nneuro) data directly:

loss = model.infonce_loss(X, y)
gof  = model.goodness_of_fit_score(X, y)
hist = model.goodness_of_fit_history()       # training curve, no X needed

# Consistency score: accepts 3-D embedding lists
emb1 = model.transform(X1)   # (ntrial, ntime, 3)
emb2 = model.transform(X2)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2], between="runs"
)
# between="datasets": labels are required, one per embedding, each (ntrial, ntime) or (ntrial*ntime,)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2],
    between="datasets",
    labels=[y1, y2],
    dataset_ids=["mouse1", "mouse2"],
)
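Labels for `between="datasets"` must line up with the flattened embedding point for point. The sketch below, built around a hypothetical `flatten_labels` helper, shows why the (ntrial, ntime) and (ntrial*ntime,) forms are interchangeable: a C-order reshape of both the labels and the embedding produces the same trial-major ordering. It is an illustration, not the package's validation code.

```python
import numpy as np


def flatten_labels(labels, ntrial, ntime):
    """Return per-timepoint labels as a 1-D array aligned row for row with
    emb.reshape(-1, d); accepts shape (ntrial, ntime) or (ntrial*ntime,)."""
    labels = np.asarray(labels)
    if labels.shape == (ntrial, ntime):
        return labels.reshape(-1)
    if labels.shape == (ntrial * ntime,):
        return labels
    raise ValueError(f"labels of shape {labels.shape} do not match ({ntrial}, {ntime})")


rng = np.random.default_rng(0)
emb1 = rng.standard_normal((40, 50, 3))   # stand-in embedding
y1 = rng.standard_normal((40, 50))        # one label per (trial, time) point

emb1_flat = emb1.reshape(-1, 3)
lab_2d = flatten_labels(y1, 40, 50)
lab_1d = flatten_labels(y1.reshape(-1), 40, 50)
# Both forms give identical labels; row i of emb1_flat pairs with lab_2d[i]
```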

Decoders

CEBRA decoders (KNNDecoder, L1LinearRegressor) are standalone sklearn estimators that expect 2-D input. Flatten the embedding first:

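import cebra
import numpy as np

emb      = model.transform(X)                   # (ntrial, ntime, 3)
emb_flat = emb.reshape(-1, emb.shape[-1])       # (ntrial*ntime, 3)

# The target needs one row per embedded time point:
#   per-timepoint labels (ntrial, ntime):  y_flat = y.reshape(-1)
#   per-trial labels (ntrial, stim_dim), like the Quick Start y: repeat over time
y_flat   = np.repeat(y, emb.shape[1], axis=0)   # (ntrial*ntime, stim_dim)

decoder = cebra.KNNDecoder()
decoder.fit(emb_flat, y_flat)
score = decoder.score(emb_flat, y_flat)         # training-set score, optimistic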
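Scoring on the same time points used for fitting, or on a random split of time points, tends to be optimistic because neighbouring samples within a trial are strongly correlated. A common remedy is to hold out whole trials before flattening. This sketch uses a hypothetical `split_trials` helper and random arrays in place of a real embedding. The commented lines show where the CEBRA decoder would plug in.

```python
import numpy as np


def split_trials(emb, y, test_fraction, rng):
    """Hold out whole trials, then flatten each side to the 2-D layout that
    sklearn-style decoders expect.
    emb: (ntrial, ntime, d); y: per-timepoint labels (ntrial, ntime, ...)."""
    ntrial = emb.shape[0]
    order = rng.permutation(ntrial)
    n_test = max(1, int(round(test_fraction * ntrial)))
    test_idx, train_idx = order[:n_test], order[n_test:]

    def flat(a, idx):
        sub = a[idx]                               # selected trials only
        return sub.reshape(-1, *sub.shape[2:])     # (n_selected * ntime, ...)

    return (flat(emb, train_idx), flat(y, train_idx),
            flat(emb, test_idx), flat(y, test_idx))


rng = np.random.default_rng(0)
emb = rng.standard_normal((40, 50, 3))   # stand-in for model.transform(X)
y_time = rng.standard_normal((40, 50))   # per-timepoint target
emb_tr, y_tr, emb_te, y_te = split_trials(emb, y_time, 0.25, rng)

# With CEBRA installed, the decoder would then be fit and scored as:
# decoder = cebra.KNNDecoder()
# decoder.fit(emb_tr, y_tr)
# score = decoder.score(emb_te, y_te)
```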

Multi-session

Pass X as a list of epoch arrays (one per session):

import numpy as np
from trial_cebra import TrialCEBRA

X = [
    np.random.randn(30, 100, 64).astype(np.float32),   # session 0
    np.random.randn(25,  80, 48).astype(np.float32),   # session 1
]
y_cont = [np.random.randn(30, 100, 16).astype(np.float32),
          np.random.randn(25,  80, 16).astype(np.float32)]
y_disc = [np.zeros((30, 100), dtype=np.int64),
          np.zeros((25,  80), dtype=np.int64)]

model = TrialCEBRA(conditional="delta", output_dimension=3, max_iterations=1000)
model.fit(X, y_disc, y_cont)
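Before calling `fit` with several sessions, it can help to check that the lists line up. This is a minimal sketch that assumes per-timepoint labels as in the example above. It also assumes, as in CEBRA's multi-session mode, that label feature dimensions must agree across sessions while neuron counts may differ. `check_multisession` is hypothetical and not part of TrialCEBRA.

```python
import numpy as np


def check_multisession(X, *ys):
    """Check a multi-session input before fitting. X is a list of
    (ntrial, ntime, nneuro) arrays; each y in ys is a list of per-timepoint
    labels, one (ntrial, ntime, ...) array per session."""
    for i, x in enumerate(X):
        if x.ndim != 3:
            raise ValueError(f"session {i}: expected (ntrial, ntime, nneuro), got {x.shape}")
    for k, y in enumerate(ys):
        if len(y) != len(X):
            raise ValueError(f"label {k}: {len(y)} sessions, expected {len(X)}")
        for i, (x, y_s) in enumerate(zip(X, y)):
            if y_s.shape[:2] != x.shape[:2]:
                raise ValueError(f"label {k}, session {i}: {y_s.shape} vs {x.shape[:2]}")
        # Neuron counts may differ between sessions; label feature dims may not
        feature_dims = {y_s.shape[2:] for y_s in y}
        if len(feature_dims) != 1:
            raise ValueError(f"label {k}: feature dims differ across sessions: {feature_dims}")


X = [np.zeros((30, 100, 64), np.float32), np.zeros((25, 80, 48), np.float32)]
y_cont = [np.zeros((30, 100, 16), np.float32), np.zeros((25, 80, 16), np.float32)]
y_disc = [np.zeros((30, 100), np.int64), np.zeros((25, 80), np.int64)]
check_multisession(X, y_disc, y_cont)   # passes silently
```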

"delta" and "time_delta" are supported for multi-session; "time" raises NotImplementedError.


Contributing

uv sync --dev
uv run pre-commit install --hook-type pre-commit --hook-type pre-push
uv run pytest tests/ -v

Release: git tag vX.X.X && git push origin vX.X.X.



Download files


Source Distribution

trialcebra-0.0.6a0.tar.gz (3.1 MB)

Built Distribution


trialcebra-0.0.6a0-py3-none-any.whl (34.9 kB)

File details

Details for the file trialcebra-0.0.6a0.tar.gz.

File metadata

  • Download URL: trialcebra-0.0.6a0.tar.gz
  • Size: 3.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8 (CI, Ubuntu 24.04)

File hashes

Hashes for trialcebra-0.0.6a0.tar.gz
Algorithm Hash digest
SHA256 29fb0e5b6b44901dcd0f5c5411ee92fe68715b34becdbafe3ee9d11bf11b474f
MD5 e90fd1e669dcdfc638ddac866a6e780e
BLAKE2b-256 7352f1d568231e48c71982b4f5276e77dfdfa0eb53717e1c3e93adf307ba1dd7


File details

Details for the file trialcebra-0.0.6a0-py3-none-any.whl.

File metadata

  • Download URL: trialcebra-0.0.6a0-py3-none-any.whl
  • Size: 34.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8 (CI, Ubuntu 24.04)

File hashes

Hashes for trialcebra-0.0.6a0-py3-none-any.whl
Algorithm Hash digest
SHA256 486fa2797da29864e789b00438bf721bd5b625ab93a74c772c3e821941d744dd
MD5 6342ddd1a8525c99ba99b2e6b05c2a1b
BLAKE2b-256 e1a1845b1e588b5f4fa5029c03be726d02d0ba373a3c0d59190be31ad0e71a2f

