
Trial-aware contrastive learning wrapper for CEBRA


TrialCEBRA


Trial-aware contrastive learning for CEBRA. Pass 3-D epoch-format data of shape (ntrial, ntime, nneuro), and trial boundaries are respected automatically.


Installation

Step 1: Install PyTorch following the instructions at pytorch.org.

Step 2: Install TrialCEBRA from PyPI:

pip install TrialCEBRA

Quick Start

import numpy as np
from trial_cebra import TrialCEBRA

X = np.random.randn(40, 50, 64).astype(np.float32)   # (ntrial, ntime, nneuro)
y = np.random.randn(40, 16).astype(np.float32)        # (ntrial, stim_dim)

model = TrialCEBRA(
    model_architecture = "offset10-model",
    conditional        = "delta",
    time_offsets       = 5,
    delta              = 0.3,
    output_dimension   = 3,
    max_iterations     = 1000,
    batch_size         = 512,
)

model.fit(X, y)
emb = model.transform(X)        # (ntrial, ntime, 3)  — shape preserved

Flat 2-D input of shape (N, nneuro) falls back to the native, unmodified CEBRA behavior.
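If your recording is continuous, shape (T, nneuro), with known trial onsets, it has to be cut into epochs before it can be passed as 3-D input. The snippet below is a minimal NumPy sketch of one way to do that. make_epochs is a hypothetical helper, not part of TrialCEBRA; it assumes fixed-length trials and drops any trial that would run past the end of the recording.

```python
import numpy as np


def make_epochs(recording, onsets, ntime):
    """Cut a continuous (T, nneuro) recording into (ntrial, ntime, nneuro) epochs.

    Hypothetical helper: `onsets` are the sample indices where trials start.
    Trials that would run past the end of the recording are dropped.
    """
    recording = np.asarray(recording)
    starts = [int(s) for s in onsets if 0 <= s and s + ntime <= recording.shape[0]]
    if not starts:
        raise ValueError("no complete trial fits inside the recording")
    # One fixed-length window per trial, stacked along a new leading trial axis.
    return np.stack([recording[s:s + ntime] for s in starts]).astype(np.float32)


rng = np.random.default_rng(0)
recording = rng.normal(size=(1000, 64))
epochs = make_epochs(recording, onsets=[0, 100, 200, 980], ntime=50)
print(epochs.shape)  # (3, 50, 64): the trial starting at 980 would overrun and is dropped
```

Flattening the result with epochs.reshape(-1, nneuro) gives the 2-D (N, nneuro) form, in which the trial boundaries are no longer encoded in the array shape.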


Conditionals

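conditional   | Trial selection             | Within-trial time point | y shape
------------- | --------------------------- | ----------------------- | -----------------------------------
"time"        | Uniform random              | ±time_offsets window    | not required
"delta"       | Gaussian similarity on y    | Uniform (free)          | (ntrial, nd) or (ntrial, ntime, nd)
"time_delta"  | Joint argmin across trials  | ±time_offsets window    | (ntrial, ntime, nd)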
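One way to read the table is as a two-stage sampler: choose the positive's trial, then choose a time point inside it. Below is a minimal NumPy sketch of the "time" and "delta" rows under stated assumptions, not TrialCEBRA's actual code. sample_positive is hypothetical; the ±time_offsets window is clipped at the trial edges; "delta" is modeled on CEBRA's delta conditional (perturb the anchor trial's label with Gaussian noise of std delta, then take the nearest trial); labels are trial-level, shape (ntrial, nd); and sample_exclude_intrial is ignored. "time_delta" is not sketched.

```python
import numpy as np


def sample_positive(y_trial, anchor_trial, anchor_time, ntime, conditional,
                    time_offsets=5, delta=0.3, rng=None):
    """Return (trial, time) of one positive sample for one anchor (hypothetical)."""
    rng = np.random.default_rng() if rng is None else rng
    ntrial = y_trial.shape[0]
    if conditional == "time":
        # Trial: uniform random. Time: within ±time_offsets of the anchor, clipped.
        trial = int(rng.integers(ntrial))
        lo = max(0, anchor_time - time_offsets)
        hi = min(ntime - 1, anchor_time + time_offsets)
        time = int(rng.integers(lo, hi + 1))
    elif conditional == "delta":
        # Trial: perturb the anchor's label with Gaussian noise (std = delta)
        # and pick the trial whose label is nearest to the perturbed query.
        query = y_trial[anchor_trial] + rng.normal(0.0, delta, size=y_trial.shape[1])
        trial = int(np.argmin(np.linalg.norm(y_trial - query, axis=1)))
        # Time: free, i.e. uniform over the whole trial.
        time = int(rng.integers(ntime))
    else:
        raise ValueError(f"unsupported conditional: {conditional!r}")
    return trial, time


rng = np.random.default_rng(0)
y = rng.normal(size=(40, 16))  # trial-level labels, as in Quick Start
print(sample_positive(y, anchor_trial=3, anchor_time=10, ntime=50,
                      conditional="time", rng=rng))
```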

For "delta", you can also pass a discrete integer label (e.g. y_disc with dtype int64) alongside the continuous label, as in model.fit(X, y_disc, y_cont), to enable class-conditional trial selection.
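A minimal sketch of class-conditional trial selection, assuming it means that the "delta" nearest-trial search is restricted to trials whose discrete label matches the anchor's. nearest_trial_same_class is a hypothetical helper, and for simplicity it uses one discrete label per trial (the Multi-session example uses per-timepoint labels).

```python
import numpy as np


def nearest_trial_same_class(y_cont, y_disc, anchor_trial, delta=0.3, rng=None):
    """Pick a positive trial among trials sharing the anchor's discrete label (hypothetical)."""
    rng = np.random.default_rng() if rng is None else rng
    # Candidate trials: only those in the same class as the anchor.
    candidates = np.flatnonzero(y_disc == y_disc[anchor_trial])
    # Gaussian-perturbed query on the continuous label, as in the "delta" sketch.
    query = y_cont[anchor_trial] + rng.normal(0.0, delta, size=y_cont.shape[1])
    dists = np.linalg.norm(y_cont[candidates] - query, axis=1)
    return int(candidates[np.argmin(dists)])


rng = np.random.default_rng(1)
y_cont = rng.normal(size=(20, 4))   # (ntrial, nd) continuous labels
y_disc = np.array([0, 1] * 10)      # (ntrial,) discrete labels
print(nearest_trial_same_class(y_cont, y_disc, anchor_trial=0, rng=rng))
```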


Key Parameters

TrialCEBRA(
    conditional            = "delta",    # "time" | "delta" | "time_delta"
    time_offsets           = 10,         # half-width of within-trial time window
    delta                  = 0.1,        # Gaussian noise std for trial similarity
    sample_fix_trial       = False,      # True: fix trial pairing at init
    sample_exclude_intrial = True,       # True: positives always from a different trial
    sample_prior           = "balanced", # "balanced" (default) or "uniform"
    output_dimension       = 3,
    # ... all other cebra.CEBRA kwargs accepted
)
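The two trial-sampling flags can be illustrated with a small NumPy sketch. It assumes uniform partner selection (as in "time"), that sample_exclude_intrial=True means a partner trial is never the anchor's own trial, and that sample_fix_trial=True means each trial's partner is drawn once up front and reused. draw_partner_trials is hypothetical, and sample_prior is not modeled.

```python
import numpy as np


def draw_partner_trials(ntrial, anchors, exclude_intrial=True, rng=None):
    """Draw one partner trial per anchor trial, uniformly at random (hypothetical)."""
    rng = np.random.default_rng() if rng is None else rng
    anchors = np.asarray(anchors)
    if not exclude_intrial:
        return rng.integers(ntrial, size=anchors.shape)
    if ntrial < 2:
        raise ValueError("exclude_intrial needs at least 2 trials")
    # Draw from ntrial - 1 slots and shift values at or past the anchor up by one,
    # which is uniform over all trials except the anchor's own.
    draw = rng.integers(ntrial - 1, size=anchors.shape)
    return draw + (draw >= anchors)


rng = np.random.default_rng(0)
fixed_partner = draw_partner_trials(40, np.arange(40), rng=rng)   # once, at init
batch_anchors = rng.integers(40, size=512)
partners_fixed = fixed_partner[batch_anchors]                     # sample_fix_trial=True
partners_fresh = draw_partner_trials(40, batch_anchors, rng=rng)  # sample_fix_trial=False
```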

After fit(), the sampling distribution used during training is available as model.distribution_.


Transform

transform() preserves input dimensionality:

emb = model.transform(X)          # (ntrial, ntime, 3) if X is (ntrial, ntime, nneuro)
emb = model.transform(X_flat)     # (N, 3)             if X_flat is (N, nneuro)
emb = model.transform_epochs(X)   # strict 3-D variant — raises if X.ndim != 3
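For 3-D input, one natural way to preserve shape is to embed each trial separately and restack the results. Working trial by trial, rather than on one flattened array, keeps a model with a temporal receptive field (such as the offset10-model used in Quick Start) from mixing samples across neighbouring trials. The sketch below assumes that approach and takes an arbitrary 2-D embedding function; transform_preserving_shape is hypothetical and may not match how TrialCEBRA handles trial edges internally.

```python
import numpy as np


def transform_preserving_shape(embed_2d, X):
    """Apply a 2-D embedding function to 2-D or 3-D input, keeping the trial layout."""
    X = np.asarray(X)
    if X.ndim == 2:
        return embed_2d(X)                  # (N, nneuro) -> (N, d)
    if X.ndim != 3:
        raise ValueError(f"expected 2-D or 3-D input, got ndim={X.ndim}")
    # Embed each (ntime, nneuro) trial on its own, then restack along the trial axis.
    return np.stack([embed_2d(trial) for trial in X])   # (ntrial, ntime, d)


rng = np.random.default_rng(0)
W = rng.normal(size=(64, 3))                 # stand-in linear "encoder"
X = rng.normal(size=(40, 50, 64))
emb = transform_preserving_shape(lambda a: a @ W, X)
print(emb.shape)  # (40, 50, 3)
```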

Metrics

All metric methods accept epoch-format (ntrial, ntime, nneuro) data directly:

loss = model.infonce_loss(X, y)
gof  = model.goodness_of_fit_score(X, y)
hist = model.goodness_of_fit_history()       # training curve, no X needed

# Consistency score: accepts 3-D embedding lists
emb1 = model.transform(X1)   # (ntrial, ntime, 3)
emb2 = model.transform(X2)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2], between="runs"
)
# between="datasets": pass one label array per embedding, shaped (ntrial, ntime) or (ntrial*ntime,)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2],
    between="datasets",
    labels=[y1, y2],
    dataset_ids=["mouse1", "mouse2"],
)
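To build intuition for what these metrics report, here is a NumPy sketch of three textbook quantities, under assumptions: a batch InfoNCE loss with cosine similarity and a temperature (the other positives in the batch act as negatives); its conversion to a goodness of fit in bits relative to the chance-level loss log(batch_size), which is my understanding of how CEBRA defines goodness of fit; and a between-runs consistency score as the R² of a linear map from one flattened embedding to another. All three functions are hypothetical and may differ in detail from CEBRA's implementations.

```python
import numpy as np


def infonce_sketch(ref, pos, temperature=1.0):
    """Batch InfoNCE: row i of `pos` is the positive for row i of `ref`."""
    ref = ref / np.linalg.norm(ref, axis=1, keepdims=True)
    pos = pos / np.linalg.norm(pos, axis=1, keepdims=True)
    logits = ref @ pos.T / temperature              # (B, B) scaled cosine similarities
    m = logits.max(axis=1, keepdims=True)           # stabilise the log-sum-exp
    log_norm = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return float(np.mean(log_norm - np.diag(logits)))


def goodness_of_fit_bits(loss, batch_size):
    """Improvement over the chance-level loss log(batch_size), in bits."""
    return (np.log(batch_size) - loss) / np.log(2)


def linear_consistency_r2(emb_a, emb_b):
    """R² (averaged over output dims) of a linear fit emb_a -> emb_b."""
    a = emb_a.reshape(-1, emb_a.shape[-1])          # (ntrial*ntime, d)
    b = emb_b.reshape(-1, emb_b.shape[-1])
    a1 = np.hstack([a, np.ones((a.shape[0], 1))])   # add an intercept column
    coef, *_ = np.linalg.lstsq(a1, b, rcond=None)
    ss_res = ((b - a1 @ coef) ** 2).sum(axis=0)
    ss_tot = ((b - b.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))
```

An uninformative embedding gives a loss of log(batch_size) and therefore 0 bits; a perfect one approaches log2(batch_size) bits.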

Decoders

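CEBRA decoders (KNNDecoder, L1LinearRegressor) are standalone sklearn estimators that expect 2-D input, so flatten the embedding, and the labels, to one row per time point first:

import cebra

emb      = model.transform(X)                  # (ntrial, ntime, 3)
emb_flat = emb.reshape(-1, emb.shape[-1])       # (ntrial*ntime, 3)
# Here y holds one label per time point, shape (ntrial, ntime).
# Trial-level labels of shape (ntrial, nd) must be repeated over time instead:
#     y_flat = np.repeat(y, emb.shape[1], axis=0)   # (ntrial*ntime, nd)
y_flat   = y.reshape(-1)                        # (ntrial*ntime,)

decoder = cebra.KNNDecoder()
decoder.fit(emb_flat, y_flat)
score = decoder.score(emb_flat, y_flat)        # scored on the training data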
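The reshape in the Decoders snippet only lines up with the embedding when y has one label per time point. Trial-level labels, such as the (ntrial, stim_dim) y from Quick Start, need to be repeated over time so that each embedding row gets its own trial's label. flatten_for_decoder is a hypothetical NumPy helper sketching that alignment; it assumes the default row-major (trial-major) flattening that reshape performs.

```python
import numpy as np


def flatten_for_decoder(emb, y):
    """Flatten a (ntrial, ntime, d) embedding and align labels row by row (hypothetical)."""
    emb = np.asarray(emb)
    y = np.asarray(y)
    ntrial, ntime, d = emb.shape
    emb_flat = emb.reshape(ntrial * ntime, d)
    if y.shape[:2] == (ntrial, ntime):
        # Per-timepoint labels: (ntrial, ntime) or (ntrial, ntime, nd).
        # Note: if nd == ntime, a (ntrial, nd) array also lands here.
        y_flat = y.reshape(ntrial * ntime, *y.shape[2:])
    elif y.shape[0] == ntrial and y.ndim <= 2:
        # Trial-level labels: (ntrial,) or (ntrial, nd); repeat each trial's row ntime times.
        y_flat = np.repeat(y, ntime, axis=0)
    else:
        raise ValueError(f"labels of shape {y.shape} do not match embedding {emb.shape}")
    return emb_flat, y_flat


rng = np.random.default_rng(0)
emb = rng.normal(size=(40, 50, 3))   # e.g. model.transform(X) from Quick Start
y = rng.normal(size=(40, 16))        # trial-level labels from Quick Start
emb_flat, y_flat = flatten_for_decoder(emb, y)
print(emb_flat.shape, y_flat.shape)  # (2000, 3) (2000, 16)
```

With it, the decoder call becomes decoder.fit(*flatten_for_decoder(emb, y)).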

Multi-session

Pass X as a list of epoch arrays (one per session):

X = [
    np.random.randn(30, 100, 64).astype(np.float32),   # session 0
    np.random.randn(25,  80, 48).astype(np.float32),   # session 1
]
y_cont = [np.random.randn(30, 100, 16).astype(np.float32),
          np.random.randn(25,  80, 16).astype(np.float32)]
y_disc = [np.zeros((30, 100), dtype=np.int64),
          np.zeros((25,  80), dtype=np.int64)]

model = TrialCEBRA(conditional="delta", output_dimension=3, max_iterations=1000)
model.fit(X, y_disc, y_cont)

"delta" and "time_delta" are supported for multi-session; "time" raises NotImplementedError.


Contributing

uv sync --dev
uv run pre-commit install --hook-type pre-commit --hook-type pre-push
uv run pytest tests/ -v

Release: git tag vX.X.X && git push origin vX.X.X.



Download files


Source Distribution

trialcebra-0.0.4.tar.gz (3.1 MB)

Uploaded Source

Built Distribution


trialcebra-0.0.4-py3-none-any.whl (33.7 kB)

Uploaded Python 3

File details

Details for the file trialcebra-0.0.4.tar.gz.

File metadata

  • Download URL: trialcebra-0.0.4.tar.gz
  • Size: 3.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8

File hashes

Hashes for trialcebra-0.0.4.tar.gz
Algorithm   | Hash digest
SHA256      | 7ca2c07538418ad2a42a3b70e94583f0f9cab24ee9c5ed014c61a011765b848f
MD5         | f26dd585dcd6290d69a15282f31f62c7
BLAKE2b-256 | d6359a023e482f5acde5477761a9d329bc2fcbd1a6a13e64371c0e33acf9f87a


File details

Details for the file trialcebra-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: trialcebra-0.0.4-py3-none-any.whl
  • Size: 33.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.11.8

File hashes

Hashes for trialcebra-0.0.4-py3-none-any.whl
Algorithm   | Hash digest
SHA256      | 3dadc9288427474d7c3ccd1b4b530ce0bc795567b9f8351e020d8781bfdf9ec7
MD5         | ae8a7037c30120976205a2af183e9357
BLAKE2b-256 | 762dd052d5612b44dd24432f10fcea733af54202341e84f07b4ab132d4d418ae

