Trial-aware contrastive learning wrapper for CEBRA
TrialCEBRA
Trial-aware contrastive learning for CEBRA. Pass 3-D epoch-format data of shape (ntrial, ntime, nneuro), and trial boundaries are respected automatically.
Installation
Step 1: Install PyTorch from pytorch.org.
Step 2: Install TrialCEBRA from PyPI:
pip install TrialCEBRA
Quick Start
import numpy as np
from trial_cebra import TrialCEBRA
X = np.random.randn(40, 50, 64).astype(np.float32) # (ntrial, ntime, nneuro)
y = np.random.randn(40, 16).astype(np.float32) # (ntrial, stim_dim)
model = TrialCEBRA(
    model_architecture = "offset10-model",
    conditional = "delta",
    time_offsets = 5,
    delta = 0.3,
    output_dimension = 3,
    max_iterations = 1000,
    batch_size = 512,
)
model.fit(X, y)
emb = model.transform(X) # (ntrial, ntime, 3); shape preserved
2-D flat input (N, nneuro) falls back to native CEBRA behavior unchanged.
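To make the two layouts concrete, the sketch below shows how a 3-D epoch array maps onto the flat (N, nneuro) layout and how each flat row can be traced back to its trial and time index. It uses NumPy only and illustrates the data layout, not TrialCEBRA's internal code; all variable names are chosen for this example.

```python
import numpy as np

ntrial, ntime, nneuro = 4, 5, 3
X = np.arange(ntrial * ntime * nneuro, dtype=np.float32).reshape(ntrial, ntime, nneuro)

# Flatten trial-major: flat row k holds trial k // ntime at time k % ntime.
X_flat = X.reshape(-1, nneuro)                    # (ntrial*ntime, nneuro)
trial_id = np.repeat(np.arange(ntrial), ntime)    # trial index of every flat row
time_id = np.tile(np.arange(ntime), ntrial)       # within-trial time index of every flat row

# Trial boundaries are the flat indices at which a new trial starts.
trial_starts = np.arange(ntrial) * ntime

k = 12
assert np.array_equal(X_flat[k], X[trial_id[k], time_id[k]])
```

In the flat layout the boundary information lives only in `trial_id`; the 3-D layout carries it in the array shape itself.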
Conditionals
| conditional | Trial selection | Within-trial | y shape |
|---|---|---|---|
| "time" | Uniform random | ±time_offsets window | not required |
| "delta" | Gaussian similarity on y | Uniform (free) | (ntrial, nd) or (ntrial, ntime, nd) |
| "time_delta" | Joint argmin across trials | ±time_offsets window | (ntrial, ntime, nd) |
Pass a discrete integer label (e.g. y_disc of dtype int64) alongside a continuous label to enable class-conditional trial selection for "delta".
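One way to read the "delta" conditional: for an anchor trial, perturb its label with Gaussian noise of standard deviation delta and take the trial whose label lies closest to the perturbed query. The sketch below implements that reading in NumPy, assuming per-trial labels of shape (ntrial, nd). `pick_delta_trial` is a hypothetical helper written for this example, and the actual sampler in TrialCEBRA may differ.

```python
import numpy as np

def pick_delta_trial(y, anchor, delta, rng, exclude_anchor=True):
    """Return the index of the trial whose label is nearest to a noisy copy of y[anchor]."""
    query = y[anchor] + rng.normal(scale=delta, size=y.shape[1])
    dist = np.linalg.norm(y - query, axis=1)   # distance from every trial label to the query
    if exclude_anchor:
        dist[anchor] = np.inf                  # never pair a trial with itself
    return int(np.argmin(dist))

rng = np.random.default_rng(0)
y = rng.standard_normal((40, 16)).astype(np.float32)   # (ntrial, stim_dim)
partner = pick_delta_trial(y, anchor=3, delta=0.3, rng=rng)
```

Under this reading, a larger delta moves the query further from the anchor label, so the selected partner trials become less label-specific.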
Key Parameters
TrialCEBRA(
    conditional = "delta", # "time" | "delta" | "time_delta"
    time_offsets = 10, # half-width of within-trial time window
    delta = 0.1, # Gaussian noise std for trial similarity
    sample_fix_trial = False, # True: fix trial pairing at init
    sample_exclude_intrial = True, # True: positives always from a different trial
    sample_prior = "balanced", # "balanced" (default) or "uniform"
    output_dimension = 3,
    # ... all other cebra.CEBRA kwargs accepted
)
After fit, the sampling distribution is accessible at model.distribution_.
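The time_offsets half-width can be pictured as follows: given an anchor at time index t, a positive sample is drawn from the indices within ±time_offsets of t inside the selected trial. The sketch below assumes the window is clipped to the valid range [0, ntime - 1] at trial edges; whether TrialCEBRA clips or handles edges differently is an assumption, and `sample_within_trial` is a hypothetical helper.

```python
import random

def sample_within_trial(t_anchor, ntime, time_offsets, rng):
    """Draw a time index within ±time_offsets of t_anchor, clipped to [0, ntime - 1]."""
    lo = max(0, t_anchor - time_offsets)
    hi = min(ntime - 1, t_anchor + time_offsets)
    return rng.randint(lo, hi)   # random.Random.randint includes both endpoints

rng = random.Random(0)
# An anchor near the trial start: the window [-8, 12] is clipped to [0, 12].
draws = [sample_within_trial(t_anchor=2, ntime=50, time_offsets=10, rng=rng)
         for _ in range(1000)]
```

Because the window never extends past index 0 or ntime - 1, a draw can never land in a neighbouring trial, which is the property the epoch format is meant to guarantee.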
Transform
transform() preserves input dimensionality:
emb = model.transform(X) # (ntrial, ntime, 3) if X is (ntrial, ntime, nneuro)
emb = model.transform(X_flat) # (N, 3) if X_flat is (N, nneuro)
emb = model.transform_epochs(X) # strict 3-D variant — raises if X.ndim != 3
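The shape-preserving behavior can be sketched as flatten, apply, reshape. In the example below a random linear map stands in for a trained encoder, and `apply_per_sample` is a hypothetical helper; it ignores the temporal receptive field of convolutional architectures such as offset10-model, so it shows only the shape bookkeeping, not how TrialCEBRA computes embeddings.

```python
import numpy as np

def apply_per_sample(X, fn):
    """Apply fn, mapping (N, nneuro) -> (N, d), to 2-D or 3-D input and keep the leading shape."""
    if X.ndim == 2:
        return fn(X)
    if X.ndim != 3:
        raise ValueError(f"expected 2-D or 3-D input, got ndim={X.ndim}")
    ntrial, ntime, nneuro = X.shape
    out = fn(X.reshape(ntrial * ntime, nneuro))   # treat every timepoint as one sample
    return out.reshape(ntrial, ntime, -1)         # restore the (ntrial, ntime, d) layout

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 3)).astype(np.float32)   # stand-in for a trained encoder

def project(a):
    return a @ W

X = rng.standard_normal((40, 50, 64)).astype(np.float32)
emb = apply_per_sample(X, project)                        # (40, 50, 3)
emb_flat = apply_per_sample(X.reshape(-1, 64), project)   # (2000, 3)
```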
Metrics
All metric methods accept epoch-format (ntrial, ntime, nneuro) data directly:
loss = model.infonce_loss(X, y)
gof = model.goodness_of_fit_score(X, y)
hist = model.goodness_of_fit_history() # training curve, no X needed
# Consistency score: accepts 3-D embedding lists
emb1 = model.transform(X1) # (ntrial, ntime, 3)
emb2 = model.transform(X2)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2], between="runs"
)
# between-datasets: pass labels=(ntrial, ntime) or (ntrial*ntime,)
scores, pairs, ids = TrialCEBRA.consistency_score(
    [emb1, emb2],
    between="datasets",
    labels=[y1, y2],
    dataset_ids=["mouse1", "mouse2"],
)
Decoders
CEBRA decoders (KNNDecoder, L1LinearRegressor) are standalone sklearn estimators that expect 2-D input. Flatten the embedding first:
import cebra
emb = model.transform(X) # (ntrial, ntime, 3)
emb_flat = emb.reshape(-1, emb.shape[-1]) # (ntrial*ntime, 3)
y_flat = y.reshape(-1) # (ntrial*ntime,); assumes y holds one label per timepoint, shape (ntrial, ntime)
decoder = cebra.KNNDecoder()
decoder.fit(emb_flat, y_flat)
score = decoder.score(emb_flat, y_flat)
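When labels are given once per trial rather than once per timepoint, they need to be repeated along time so that they line up with the flattened embedding. The sketch below also holds out whole trials for testing, so that test timepoints never share a trial with training timepoints. It uses NumPy only, with a random array standing in for model.transform(X); `flatten_split` and the 30/10 split are choices made for this example.

```python
import numpy as np

ntrial, ntime, dim = 40, 50, 3
rng = np.random.default_rng(0)
emb = rng.standard_normal((ntrial, ntime, dim)).astype(np.float32)  # stand-in for model.transform(X)
y_trial = rng.integers(0, 4, size=ntrial)                           # one class label per trial

# Split by trial, not by timepoint, to avoid leakage between train and test.
perm = rng.permutation(ntrial)
train_trials, test_trials = perm[:30], perm[30:]

def flatten_split(trials):
    """Return the flattened embedding and per-timepoint labels for the given trials."""
    e = emb[trials].reshape(-1, dim)           # (len(trials)*ntime, dim), trial-major order
    lab = np.repeat(y_trial[trials], ntime)    # each trial label repeated for every timepoint
    return e, lab

emb_train, y_train = flatten_split(train_trials)
emb_test, y_test = flatten_split(test_trials)
```

The resulting arrays can be passed to decoder.fit(emb_train, y_train) and decoder.score(emb_test, y_test) in place of the in-sample score.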
Multi-session
Pass X as a list of epoch arrays (one per session):
X = [
    np.random.randn(30, 100, 64).astype(np.float32), # session 0
    np.random.randn(25, 80, 48).astype(np.float32), # session 1
]
y_cont = [
    np.random.randn(30, 100, 16).astype(np.float32),
    np.random.randn(25, 80, 16).astype(np.float32),
]
y_disc = [
    np.zeros((30, 100), dtype=np.int64),
    np.zeros((25, 80), dtype=np.int64),
]
model = TrialCEBRA(conditional="delta", output_dimension=3, max_iterations=1000)
model.fit(X, y_disc, y_cont)
"delta" and "time_delta" are supported for multi-session; "time" raises NotImplementedError.
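Since each session may have its own ntrial, ntime, and nneuro, a quick shape check before calling fit can catch misaligned label lists early. The sketch below assumes per-timepoint labels, as in the example above, whose first two dimensions must match those of the session's X; per-trial labels of shape (ntrial, nd) would need a relaxed check. `check_sessions` is a hypothetical helper, not part of TrialCEBRA.

```python
import numpy as np

def check_sessions(X, *labels):
    """Raise ValueError unless every label list matches X in session count and (ntrial, ntime)."""
    for i, lab in enumerate(labels):
        if len(lab) != len(X):
            raise ValueError(f"label set {i}: {len(lab)} sessions, expected {len(X)}")
        for s, (x, l) in enumerate(zip(X, lab)):
            if x.ndim != 3:
                raise ValueError(f"session {s}: X must be 3-D, got ndim={x.ndim}")
            if l.shape[:2] != x.shape[:2]:
                raise ValueError(f"session {s}: label shape {l.shape} does not match X {x.shape}")
    return [x.shape for x in X]

X = [np.zeros((30, 100, 64), dtype=np.float32), np.zeros((25, 80, 48), dtype=np.float32)]
y_cont = [np.zeros((30, 100, 16), dtype=np.float32), np.zeros((25, 80, 16), dtype=np.float32)]
y_disc = [np.zeros((30, 100), dtype=np.int64), np.zeros((25, 80), dtype=np.int64)]
shapes = check_sessions(X, y_disc, y_cont)
```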
Contributing
uv sync --dev
uv run pre-commit install --hook-type pre-commit --hook-type pre-push
uv run pytest tests/ -v
Release: git tag vX.X.X && git push origin vX.X.X.