Dataloaders for training models on huge single-cell datasets
Project description
Cell Load
A PyTorch-based data loading library for single-cell perturbation data.
Features
- Load perturbation data from H5 files (AnnData format)
- Support for multiple cell types per dataset
- Configurable mapping strategies for control cell selection
- Zero-shot and few-shot learning support
- Cell barcode tracking (optional)
- Preprocessing utilities for quality control and data filtering
Installation
uv pip install cell-load
Quick Start
1. Create a TOML configuration file:
The TOML configuration file defines your datasets, training splits, and experimental setup. Here's the format:
# Dataset paths - maps dataset names to their directories
[datasets]
replogle = "/path/to/replogle_dataset/" # ADDS ALL h5 or h5ad files in this folder to training
jurkat = "/path/to/jurkat_dataset/"
# Training specifications
# All cell types in a dataset automatically go into training (excluding zeroshot/fewshot overrides)
[training]
replogle = "train"
jurkat = "train"
# Zeroshot specifications - entire cell types go to val or test
[zeroshot]
"replogle.jurkat" = "test"
"jurkat.rpe1" = "val"
# Fewshot specifications - explicit perturbation lists
[fewshot]
[fewshot."replogle.rpe1"]
val = ["AARS"]
test = ["AARS", "NUP107", "RPUSD4"] # can overlap with val
# train gets all other perturbations automatically
[fewshot."jurkat.k562"]
val = ["GENE1", "GENE2"]
test = ["GENE3", "GENE4"]
TOML Configuration Format
[datasets]: Maps dataset names to their directory paths
- Each dataset should contain H5 files (one per cell type)
- Files should be named like
cell_type.h5orcell_type.h5ad
[training]: Specifies which datasets are used for training
- Set to
"train"to include all cell types in training (except those in zeroshot/fewshot)
[zeroshot]: Holds out entire cell types for testing
- Format:
"dataset.cell_type" = "split" - Split can be
"val"or"test" - Example:
"replogle.jurkat" = "test"holds out all Jurkat cells from training
[fewshot]: Holds out specific perturbations within cell types
- Format:
[fewshot."dataset.cell_type"] val = ["pert1", "pert2"]: Perturbations for validationtest = ["pert3", "pert4"]: Perturbations for testing- Remaining perturbations go to training
It is worth noting that control cell mapping is only done withi the same file (e.g., a perturbed cell will not get mapped to a control cell from a different h5 file, even if it has matched covariates).
2. Command Line Usage
The most common parameters for data loading are:
# Basic required parameters
data.kwargs.toml_config_path=/path/to/config.toml
data.kwargs.embed_key=X_hvg
data.kwargs.num_workers=24
data.kwargs.batch_col=gem_group
data.kwargs.pert_col=gene
data.kwargs.cell_type_key=cell_type
data.kwargs.control_pert=non-targeting
# Optional parameters
data.kwargs.barcode=true # Enable cell barcode output
data.kwargs.perturbation_features_file=/path/to/gene_embeddings.pt
data.kwargs.output_space=gene
data.kwargs.basal_mapping_strategy=random
data.kwargs.n_basal_samples=1
data.kwargs.should_yield_control_cells=true
These plug in as hydra configurable settings in the STATE repository.
3. Standalone Programmatic Usage
from cell_load.data_modules import PerturbationDataModule
dm = PerturbationDataModule(
# Required parameters
toml_config_path="/path/to/config.toml",
embed_key="X_hvg",
num_workers=24,
batch_col="gem_group",
pert_col="gene",
cell_type_key="cell_type",
control_pert="non-targeting",
# Optional parameters
barcode=True, # Enable cell barcode output
perturbation_features_file="/path/to/gene_embeddings.pt",
output_space="gene",
basal_mapping_strategy="random",
n_basal_samples=1,
should_yield_control_cells=True,
batch_size=128,
)
dm.setup()
# Get training data
train_loader = dm.train_dataloader()
for batch in train_loader:
# batch contains:
# - pert_cell_emb: perturbed cell embeddings
# - ctrl_cell_emb: control cell embeddings
# - pert_emb: perturbation one-hot encodings or embeddings
# - pert_name: perturbation names
# - cell_type: cell types
# - batch: batch information
# - pert_cell_barcode: cell barcodes (if barcode=True)
# - ctrl_cell_barcode: control cell barcodes (if barcode=True)
pass
Preprocessing
Cell Load provides several preprocessing utilities to help with data quality control and filtering before training.
Quality Control: On-Target Knockdown Filtering
The filter_on_target_knockdown function filters perturbation data based on the effectiveness of gene knockdown. This is crucial for ensuring that your perturbation experiments actually worked as intended.
import anndata
from cell_load.utils.data_utils import filter_on_target_knockdown
# Load your AnnData object
adata = anndata.read_h5ad("your_data.h5ad")
# Apply quality control filtering
filtered_adata = filter_on_target_knockdown(
adata=adata,
perturbation_column="gene", # Column in obs containing perturbation info
control_label="non-targeting", # Label for control cells
residual_expression=0.30, # Perturbation-level threshold (30% residual = 70% knockdown)
cell_residual_expression=0.50, # Cell-level threshold (50% residual = 50% knockdown)
min_cells=30, # Minimum cells per perturbation after filtering
layer=None, # Use adata.X (or specify a layer)
var_gene_name="gene_name" # Column in var containing gene names
)
print(f"Original cells: {adata.n_obs}")
print(f"Filtered cells: {filtered_adata.n_obs}")
print(f"Removed {adata.n_obs - filtered_adata.n_obs} cells due to poor knockdown")
How the Filtering Works
The filter_on_target_knockdown function performs a three-stage filtering process:
- Perturbation-level filtering: Keeps only perturbations where the average knockdown ≥ (1 -
residual_expression) - Cell-level filtering: Within those perturbations, keeps only cells where knockdown ≥ (1 -
cell_residual_expression) - Minimum cell count: Discards perturbations that have fewer than
min_cellscells remaining after stages 1-2
Control cells are always preserved regardless of these criteria.
Parameters
residual_expression(default: 0.30): Perturbation-level threshold. 0.30 means 70% knockdown required.cell_residual_expression(default: 0.50): Cell-level threshold. 0.50 means 50% knockdown required per cell.min_cells(default: 30): Minimum number of cells per perturbation after filtering.layer: Use a specific layer instead ofadata.X(e.g., "counts", "log1p").var_gene_name: Column inadata.varcontaining gene names (default: "gene_name").
Other Preprocessing Utilities
Check Individual Perturbation Effectiveness
from cell_load.utils.data_utils import is_on_target_knockdown
# Check if a specific perturbation worked
is_effective = is_on_target_knockdown(
adata=adata,
target_gene="GENE1",
perturbation_column="gene",
control_label="non-targeting",
residual_expression=0.30
)
print(f"GENE1 knockdown effective: {is_effective}")
Data Type Detection
from cell_load.utils.data_utils import suspected_discrete_torch, suspected_log_torch
# Check if data appears to be raw counts
is_discrete = suspected_discrete_torch(torch_tensor_data)
print(f"Data appears to be discrete counts: {is_discrete}")
# Check if data is log-transformed
is_logged = suspected_log_torch(torch_tensor_data)
print(f"Data appears to be log-transformed: {is_logged}")
Gene Name Indexing
from cell_load.utils.data_utils import set_var_index_to_col
# Set the var index to use gene names from a specific column
adata = set_var_index_to_col(adata, col="gene_name")
Preprocessing Workflow Example
Here's a typical preprocessing workflow:
import anndata
from cell_load.utils.data_utils import filter_on_target_knockdown, set_var_index_to_col
# 1. Load data
adata = anndata.read_h5ad("raw_data.h5ad")
# 2. Set up gene names as index (if needed)
adata = set_var_index_to_col(adata, col="gene_name")
# 3. Apply quality control filtering
filtered_adata = filter_on_target_knockdown(
adata=adata,
perturbation_column="gene",
control_label="non-targeting",
residual_expression=0.30,
cell_residual_expression=0.50,
min_cells=30
)
# 4. Save filtered data
filtered_adata.write_h5ad("filtered_data.h5ad")
# 5. Use in your TOML config
# [datasets]
# my_dataset = "/path/to/filtered_data.h5ad"
Parameter Reference
Required Parameters
toml_config_path: Path to the TOML configuration file defining datasets and splitsembed_key: Key in the H5 file'sobsmsection to use for cell embeddings (e.g., "X_hvg", "X_state")pert_col: Column name inobsfor perturbation information (default: "gene")cell_type_key: Column name inobsfor cell type information (default: "cell_type")batch_col: Column name inobsfor batch/plate information (default: "gem_group")control_pert: Value inpert_colthat represents control cells (default: "non-targeting")
Optional Parameters
barcode: Iftrue, include cell barcodes in output (default:false)perturbation_features_file: Path to .pt file containing pre-computed gene embeddingsoutput_space: Output space for model predictions ("gene" or "all", default: "gene")basal_mapping_strategy: Strategy for mapping perturbed cells to controls ("batch" or "random", default: "random")n_basal_samples: Number of control cells to sample per perturbed cell (default: 1)should_yield_control_cells: Include control cells in output (default:true)num_workers: Number of workers for data loading (default: 8)batch_size: Batch size for training (default: 128)
Usage
When creating the data module programmatically:
from cell_load.data_modules import PerturbationDataModule
dm = PerturbationDataModule(
toml_config_path="config.toml",
# ... other parameters
)
Advanced Configuration
Zero-shot Learning
To set up zero-shot learning (entire cell types held out for testing):
[zeroshot]
"dataset.cell_type" = "test"
Few-shot Learning
To set up few-shot learning (specific perturbations held out):
[fewshot."dataset.cell_type"]
val = ["pert1", "pert2"]
test = ["pert3", "pert4"]
Custom Perturbation Features
To use pre-computed gene embeddings instead of one-hot encodings:
data.kwargs.perturbation_features_file=/path/to/gene_embeddings.pt
The .pt file should contain a dictionary mapping gene names to embedding vectors.
Dataset Management
Getting Datasets
Currently, Cell Load expects datasets to be in H5/AnnData format and stored locally. Users need to:
- Obtain datasets from their original sources (e.g., published papers, repositories)
- Convert to AnnData format if not already in that format
- Apply preprocessing using the utilities described above
- Organize by cell type with one H5 file per cell type
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file cell_load-0.8.2.tar.gz.
File metadata
- Download URL: cell_load-0.8.2.tar.gz
- Upload date:
- Size: 45.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.8.17
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
5e82e9fe98d9ed6f4b07311f41ef71996ab5b664f58ca0e101213fa422343614
|
|
| MD5 |
3653513627819f22d6f9fc095b05c9d5
|
|
| BLAKE2b-256 |
0463fa71fbe2a52f4aa2597a13384ce6b71f97b6cdad3ee9dd4f0964a66612df
|
File details
Details for the file cell_load-0.8.2-py3-none-any.whl.
File metadata
- Download URL: cell_load-0.8.2-py3-none-any.whl
- Upload date:
- Size: 39.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.8.17
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
fc72bf283224200e11212097b7482713834e13c889b89bd95fa9763024eee41a
|
|
| MD5 |
cc561d32aa902b70c7dc31cd75379cd3
|
|
| BLAKE2b-256 |
b11e4b1b656043bd9c11b09fd58ff26c749c9f4a69d68f5180cd46ad028968f8
|