Dataloaders for training models on huge single-cell datasets

Cell Load

A PyTorch-based data loading library for single-cell perturbation data.

Features

  • Load perturbation data from H5 files (AnnData format)
  • Support for multiple cell types per dataset
  • Configurable mapping strategies for control cell selection
  • Zero-shot and few-shot learning support
  • Cell barcode tracking (optional)
  • Preprocessing utilities for quality control and data filtering

Installation

uv pip install cell-load

Quick Start

1. Create a TOML configuration file:

The TOML configuration file defines your datasets, training splits, and experimental setup. Here's the format:

# Dataset paths - maps dataset names to their directories
[datasets]
replogle = "/path/to/replogle_dataset/" # ADDS ALL h5 or h5ad files in this folder to training
jurkat = "/path/to/jurkat_dataset/"

# Training specifications
# All cell types in a dataset automatically go into training (excluding zeroshot/fewshot overrides)
[training]
replogle = "train"
jurkat = "train"

# Zeroshot specifications - entire cell types go to val or test
[zeroshot]
"replogle.jurkat" = "test"
"jurkat.rpe1" = "val"

# Fewshot specifications - explicit perturbation lists
[fewshot]

[fewshot."replogle.rpe1"]
val = ["AARS"]
test = ["AARS", "NUP107", "RPUSD4"]  # can overlap with val
# train gets all other perturbations automatically

[fewshot."jurkat.k562"]
val = ["GENE1", "GENE2"]
test = ["GENE3", "GENE4"]

TOML Configuration Format

[datasets]: Maps dataset names to their directory paths

  • Each dataset should contain H5 files (one per cell type)
  • Files should be named like cell_type.h5 or cell_type.h5ad

[training]: Specifies which datasets are used for training

  • Set to "train" to include all cell types in training (except those in zeroshot/fewshot)

[zeroshot]: Holds out entire cell types for testing

  • Format: "dataset.cell_type" = "split"
  • Split can be "val" or "test"
  • Example: "replogle.jurkat" = "test" holds out all Jurkat cells from training

[fewshot]: Holds out specific perturbations within cell types

  • Format: [fewshot."dataset.cell_type"]
  • val = ["pert1", "pert2"]: Perturbations for validation
  • test = ["pert3", "pert4"]: Perturbations for testing
  • Remaining perturbations go to training

Note that control cell mapping is only done within the same file: a perturbed cell will not be mapped to a control cell from a different H5 file, even if its covariates match.

2. Command Line Usage

The most common parameters for data loading are:

# Basic required parameters
data.kwargs.toml_config_path=/path/to/config.toml
data.kwargs.embed_key=X_hvg
data.kwargs.num_workers=24
data.kwargs.batch_col=gem_group
data.kwargs.pert_col=gene
data.kwargs.cell_type_key=cell_type
data.kwargs.control_pert=non-targeting

# Optional parameters
data.kwargs.barcode=true  # Enable cell barcode output
data.kwargs.perturbation_features_file=/path/to/gene_embeddings.pt
data.kwargs.output_space=gene
data.kwargs.basal_mapping_strategy=random
data.kwargs.n_basal_samples=1
data.kwargs.should_yield_control_cells=true
data.kwargs.val_subsample_batches=32  # STATE/Hydra knob to reduce validation batches
data.kwargs.use_consecutive_loading=true  # Faster IO, especially with output_space=all

These plug in as Hydra-configurable settings in the STATE repository. When use_consecutive_loading=true, data must be grouped by (cell_type, perturbation/condition) so that each pair is contiguous within each H5/H5AD file. For example, avoid sequences like (ct1, pert1), (ct1, pert2), (ct1, pert1). Cell Load checks this per file and raises an error if the assumption is violated.
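To check a file's ordering before enabling use_consecutive_loading, a contiguity test along the following lines may help. It is a standalone sketch, not the check Cell Load runs internally; first_contiguity_violation is a hypothetical name, and the pairs would in practice come from the obs columns of one file.

```python
def first_contiguity_violation(pairs):
    """Return the index where a (cell_type, perturbation) pair reappears
    after its block has already ended, or None if every pair is contiguous."""
    finished = set()  # pairs whose block has been closed
    previous = None
    for i, pair in enumerate(pairs):
        if pair != previous:
            if pair in finished:
                return i
            if previous is not None:
                finished.add(previous)
            previous = pair
    return None


grouped = [("ct1", "pert1"), ("ct1", "pert1"), ("ct1", "pert2"), ("ct2", "pert1")]
interleaved = [("ct1", "pert1"), ("ct1", "pert2"), ("ct1", "pert1")]

print(first_contiguity_violation(grouped))      # None
print(first_contiguity_violation(interleaved))  # 2
```

If the check reports a violation, sorting the cells by cell type and perturbation before writing the file is one way to restore the grouping.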

3. Standalone Programmatic Usage

from cell_load.data_modules import PerturbationDataModule

dm = PerturbationDataModule(
    # Required parameters
    toml_config_path="/path/to/config.toml",
    embed_key="X_hvg",
    num_workers=24,
    batch_col="gem_group",
    pert_col="gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    
    # Optional parameters
    barcode=True,  # Enable cell barcode output
    perturbation_features_file="/path/to/gene_embeddings.pt",
    output_space="gene",
    basal_mapping_strategy="random",
    n_basal_samples=1,
    should_yield_control_cells=True,
    val_subsample_fraction=0.25,
    use_consecutive_loading=True,
    batch_size=128,
)
dm.setup()

# Get training data
train_loader = dm.train_dataloader()
for batch in train_loader:
    # batch contains:
    # - pert_cell_emb: perturbed cell embeddings
    # - ctrl_cell_emb: control cell embeddings
    # - pert_emb: perturbation one-hot encodings or embeddings
    # - pert_name: perturbation names
    # - cell_type: cell types
    # - batch: batch information
    # - pert_cell_barcode: cell barcodes (if barcode=True)
    # - ctrl_cell_barcode: control cell barcodes (if barcode=True)
    pass

Preprocessing

Cell Load provides several preprocessing utilities to help with data quality control and filtering before training.

Quality Control: On-Target Knockdown Filtering

The filter_on_target_knockdown function filters perturbation data based on how effectively the target gene was knocked down. Use it to drop perturbations and cells in which the targeted gene's expression was not sufficiently reduced relative to control cells.

import anndata
from cell_load.utils.data_utils import filter_on_target_knockdown

# Load your AnnData object
adata = anndata.read_h5ad("your_data.h5ad")

# Apply quality control filtering
filtered_adata = filter_on_target_knockdown(
    adata=adata,
    perturbation_column="gene",           # Column in obs containing perturbation info
    control_label="non-targeting",        # Label for control cells
    residual_expression=0.30,             # Perturbation-level threshold (30% residual = 70% knockdown)
    cell_residual_expression=0.50,        # Cell-level threshold (50% residual = 50% knockdown)
    min_cells=30,                         # Minimum cells per perturbation after filtering
    layer=None,                           # Use adata.X (or specify a layer)
    var_gene_name="gene_name"             # Column in var containing gene names
)

print(f"Original cells: {adata.n_obs}")
print(f"Filtered cells: {filtered_adata.n_obs}")
print(f"Removed {adata.n_obs - filtered_adata.n_obs} cells due to poor knockdown")

How the Filtering Works

The filter_on_target_knockdown function performs a three-stage filtering process:

  1. Perturbation-level filtering: Keeps only perturbations where the average knockdown ≥ (1 - residual_expression)
  2. Cell-level filtering: Within those perturbations, keeps only cells where knockdown ≥ (1 - cell_residual_expression)
  3. Minimum cell count: Discards perturbations that have fewer than min_cells cells remaining after stages 1-2

Control cells are always preserved regardless of these criteria.
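The three stages can be illustrated on toy data. The sketch below is not the library's code: filter_knockdown_sketch is a hypothetical function, each cell is reduced to a (perturbation, expression of the targeted gene) pair, and residual expression is assumed to be that value divided by the mean control expression of the same gene.

```python
from statistics import mean


def filter_knockdown_sketch(cells, control_means, control_label="non-targeting",
                            residual_expression=0.30,
                            cell_residual_expression=0.50, min_cells=30):
    """cells: list of (perturbation, target_gene_expression) tuples.
    control_means: mean control expression of each targeted gene (assumed given)."""
    kept = [c for c in cells if c[0] == control_label]  # controls always kept
    by_pert = {}
    for pert, expr in cells:
        if pert != control_label:
            by_pert.setdefault(pert, []).append(expr)

    for pert, exprs in by_pert.items():
        baseline = control_means[pert]
        # Stage 1: average residual must be at or below the perturbation threshold.
        if mean(exprs) / baseline > residual_expression:
            continue
        # Stage 2: keep only cells at or below the cell-level threshold.
        passing = [e for e in exprs if e / baseline <= cell_residual_expression]
        # Stage 3: drop perturbations left with too few cells.
        if len(passing) < min_cells:
            continue
        kept.extend((pert, e) for e in passing)
    return kept


cells = [
    ("non-targeting", None), ("non-targeting", None),
    ("GENE_A", 1.0), ("GENE_A", 1.0), ("GENE_A", 6.0),  # one cell fails stage 2
    ("GENE_B", 8.0), ("GENE_B", 9.0),                   # fails stage 1
    ("GENE_C", 1.0),                                    # fails stage 3
]
control_means = {"GENE_A": 10.0, "GENE_B": 10.0, "GENE_C": 10.0}

kept = filter_knockdown_sketch(cells, control_means, min_cells=2)
print(len(kept), sorted({pert for pert, _ in kept}))
```

With min_cells=2, only the two control cells and the two well-knocked-down GENE_A cells survive.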

Parameters

  • residual_expression (default: 0.30): Perturbation-level threshold. 0.30 means 70% knockdown required.
  • cell_residual_expression (default: 0.50): Cell-level threshold. 0.50 means 50% knockdown required per cell.
  • min_cells (default: 30): Minimum number of cells per perturbation after filtering.
  • layer: Use a specific layer instead of adata.X (e.g., "counts", "log1p").
  • var_gene_name: Column in adata.var containing gene names (default: "gene_name").

Other Preprocessing Utilities

Check Individual Perturbation Effectiveness

from cell_load.utils.data_utils import is_on_target_knockdown

# Check if a specific perturbation worked
is_effective = is_on_target_knockdown(
    adata=adata,
    target_gene="GENE1",
    perturbation_column="gene",
    control_label="non-targeting",
    residual_expression=0.30
)
print(f"GENE1 knockdown effective: {is_effective}")

Data Type Detection

from cell_load.utils.data_utils import suspected_discrete_torch, suspected_log_torch

# torch_tensor_data is a placeholder for a torch.Tensor of expression values
# Check if data appears to be raw counts
is_discrete = suspected_discrete_torch(torch_tensor_data)
print(f"Data appears to be discrete counts: {is_discrete}")

# Check if data is log-transformed
is_logged = suspected_log_torch(torch_tensor_data)
print(f"Data appears to be log-transformed: {is_logged}")

Gene Name Indexing

from cell_load.utils.data_utils import set_var_index_to_col

# Set the var index to use gene names from a specific column
adata = set_var_index_to_col(adata, col="gene_name")

Preprocessing Workflow Example

Here's a typical preprocessing workflow:

import anndata
from cell_load.utils.data_utils import filter_on_target_knockdown, set_var_index_to_col

# 1. Load data
adata = anndata.read_h5ad("raw_data.h5ad")

# 2. Set up gene names as index (if needed)
adata = set_var_index_to_col(adata, col="gene_name")

# 3. Apply quality control filtering
filtered_adata = filter_on_target_knockdown(
    adata=adata,
    perturbation_column="gene",
    control_label="non-targeting",
    residual_expression=0.30,
    cell_residual_expression=0.50,
    min_cells=30
)

# 4. Save filtered data
filtered_adata.write_h5ad("filtered_data.h5ad")

# 5. Use in your TOML config
# [datasets]
# my_dataset = "/path/to/filtered_data.h5ad"

Parameter Reference

Required Parameters

  • toml_config_path: Path to the TOML configuration file defining datasets and splits
  • embed_key: Key in the H5 file's obsm section to use for cell embeddings (e.g., "X_hvg", "X_state")
  • pert_col: Column name in obs for perturbation information (default: "gene")
  • cell_type_key: Column name in obs for cell type information (default: "cell_type")
  • batch_col: Column name in obs for batch/plate information (default: "gem_group")
  • control_pert: Value in pert_col that represents control cells (default: "non-targeting")

Optional Parameters

  • barcode: If true, include cell barcodes in output (default: false)
  • perturbation_features_file: Path to .pt file containing pre-computed gene embeddings
  • output_space: Output space for model predictions ("gene" or "all", default: "gene")
  • basal_mapping_strategy: Strategy for mapping perturbed cells to controls ("batch" or "random", default: "random")
  • n_basal_samples: Number of control cells to sample per perturbed cell (default: 1)
  • should_yield_control_cells: Include control cells in output (default: true)
  • num_workers: Number of workers for data loading (default: 8)
  • batch_size: Batch size for training (default: 128)
  • val_subsample_batches: STATE/Hydra parameter to subsample validation batches. In standalone cell-load, use val_subsample_fraction.
  • val_subsample_fraction: Fraction of validation subsets to keep (e.g., 0.01 keeps ~1% of val_datasets)
  • use_consecutive_loading: Groups batches by consecutive indices for faster H5 reads (especially for output_space="all"). Requires data grouped contiguously by (cell_type, perturbation/condition) within each file, and raises an error if this assumption is violated.
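The difference between the two basal_mapping_strategy values can be pictured with a toy example. This is a sketch under assumptions, not Cell Load's implementation: map_controls_sketch is a hypothetical function, "random" is taken to mean any control of the same cell type, "batch" to mean controls of the same cell type and batch, and sampling is done with replacement.

```python
import random


def map_controls_sketch(perturbed, controls, strategy="random",
                        n_basal_samples=1, seed=0):
    """Map each perturbed cell id to a list of sampled control cell ids."""
    if strategy not in {"random", "batch"}:
        raise ValueError(f"unknown strategy: {strategy!r}")
    rng = random.Random(seed)
    mapping = {}
    for cell in perturbed:
        # Candidate controls share the cell type (assumption).
        pool = [c for c in controls if c["cell_type"] == cell["cell_type"]]
        if strategy == "batch":
            same_batch = [c for c in pool if c["batch"] == cell["batch"]]
            # Assumed fallback: use the whole cell type if the batch has no controls.
            pool = same_batch or pool
        mapping[cell["id"]] = [rng.choice(pool)["id"] for _ in range(n_basal_samples)]
    return mapping


controls = [
    {"id": "c1", "cell_type": "k562", "batch": "b1"},
    {"id": "c2", "cell_type": "k562", "batch": "b2"},
    {"id": "c3", "cell_type": "rpe1", "batch": "b1"},
]
perturbed = [
    {"id": "p1", "cell_type": "k562", "batch": "b1"},
    {"id": "p2", "cell_type": "k562", "batch": "b2"},
    {"id": "p3", "cell_type": "rpe1", "batch": "b1"},
]

batch_mapping = map_controls_sketch(perturbed, controls, strategy="batch")
random_mapping = map_controls_sketch(perturbed, controls, strategy="random",
                                     n_basal_samples=2)
print(batch_mapping)
print(random_mapping)
```

Because mapping happens within one file, both lists here stand for cells read from the same H5 file.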

Usage

When creating the data module programmatically:

from cell_load.data_modules import PerturbationDataModule

dm = PerturbationDataModule(
    toml_config_path="config.toml",
    # ... other parameters
)

Advanced Configuration

Zero-shot Learning

To set up zero-shot learning (entire cell types held out for testing):

[zeroshot]
"dataset.cell_type" = "test"

Few-shot Learning

To set up few-shot learning (specific perturbations held out):

[fewshot."dataset.cell_type"]
val = ["pert1", "pert2"]
test = ["pert3", "pert4"]

Custom Perturbation Features

To use pre-computed gene embeddings instead of one-hot encodings:

data.kwargs.perturbation_features_file=/path/to/gene_embeddings.pt

The .pt file should contain a dictionary mapping gene names to embedding vectors.
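Before saving such a dictionary, it can be worth checking that every perturbation name has an entry and that all vectors share one dimension. The sketch below uses plain Python lists so it runs without PyTorch; check_embedding_table is a hypothetical helper, and whether the control label needs its own entry is not stated here, so the helper simply reports it if missing.

```python
def check_embedding_table(embeddings, perturbations):
    """Return (missing_names, dimensions) for a name -> vector mapping."""
    missing = sorted(set(perturbations) - set(embeddings))
    dimensions = {len(vector) for vector in embeddings.values()}
    return missing, dimensions


# Toy table: three genes, each with a 4-dimensional vector.
embeddings = {
    "AARS": [0.1, 0.2, 0.3, 0.4],
    "NUP107": [0.0, 0.5, 0.1, 0.9],
    "RPUSD4": [0.7, 0.3, 0.2, 0.6],
}
perturbations = ["AARS", "NUP107", "RPUSD4", "GENE1"]

missing, dimensions = check_embedding_table(embeddings, perturbations)
print("missing:", missing)
print("consistent dimension:", len(dimensions) == 1)

# With PyTorch installed, the table could then be written with, for example:
#   torch.save({name: torch.tensor(vec) for name, vec in embeddings.items()},
#              "/path/to/gene_embeddings.pt")
```

Any names reported as missing would need an embedding added (or the perturbation removed) before training.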

Dataset Management

Getting Datasets

Currently, Cell Load expects datasets to be in H5/AnnData format and stored locally. Users need to:

  1. Obtain datasets from their original sources (e.g., published papers, repositories)
  2. Convert to AnnData format if not already in that format
  3. Apply preprocessing using the utilities described above
  4. Organize by cell type with one H5 file per cell type
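A quick way to inspect the layout from step 4 is to list the H5/H5AD files in a dataset directory and read the cell type off each file name. This is a standalone sketch using only the standard library; cell_types_in is a hypothetical helper, and it assumes the cell_type.h5 / cell_type.h5ad naming described earlier.

```python
import tempfile
from pathlib import Path


def cell_types_in(dataset_dir):
    """Map cell type (file stem) to path for each .h5/.h5ad file in dataset_dir."""
    found = {}
    for path in sorted(Path(dataset_dir).iterdir()):
        if path.is_file() and path.suffix in {".h5", ".h5ad"}:
            if path.stem in found:
                # e.g. both jurkat.h5 and jurkat.h5ad are present
                raise ValueError(f"more than one file for cell type {path.stem!r}")
            found[path.stem] = path
    return found


# Demonstration on a temporary directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("jurkat.h5", "rpe1.h5ad", "notes.txt"):
        (Path(tmp) / name).touch()
    layout = cell_types_in(tmp)
    print(sorted(layout))  # ['jurkat', 'rpe1']
```

Files with other extensions are ignored, and two files that resolve to the same cell type raise an error so the ambiguity is caught early.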
