Utility for loading CellMap data for machine learning training, utilizing PyTorch, Xarray, TensorStore, and Pydantic.
CellMap-Data
A comprehensive PyTorch-based data loading and preprocessing library for CellMap biological imaging datasets, designed for efficient machine learning training on large-scale 2D/3D volumetric data.
Overview
CellMap-Data is a specialized data loading utility that bridges the gap between large biological imaging datasets and machine learning frameworks. It provides efficient, memory-optimized data loading for training deep learning models on cell microscopy data, with support for multi-class segmentation, spatial transformations, and advanced augmentation techniques.
Key Features
- 🔬 Biological Data Optimized: Native support for multiscale biological imaging formats (OME-NGFF/Zarr)
- ⚡ High-Performance Loading: Efficient data streaming with TensorStore backend and optimized PyTorch integration
- 🎯 Flexible Target Construction: Support for multi-class segmentation with mutually exclusive class relationships
- 🔄 Advanced Augmentations: Comprehensive spatial and value transformations for robust model training
- 📊 Smart Sampling: Weighted sampling strategies and validation set management
- 🚀 Scalable Architecture: Memory-efficient handling of datasets larger than available RAM
- 🔧 Production Ready: Thread-safe, multiprocess-compatible with extensive test coverage
Installation
pip install cellmap-data
Dependencies
CellMap-Data leverages several powerful libraries:
- PyTorch: Neural network training and tensor operations
- TensorStore: High-performance array storage and retrieval
- Xarray: Labeled multi-dimensional arrays with metadata
- Pydantic: Data validation and settings management
- Zarr: Chunked, compressed array storage
Quick Start
Basic Dataset Setup
from cellmap_data import CellMapDataset
# Define input and target array specifications
input_arrays = {
    "raw": {
        "shape": (64, 64, 64),  # Training patch size
        "scale": (8, 8, 8),     # Voxel resolution in nm
    }
}
target_arrays = {
    "segmentation": {
        "shape": (64, 64, 64),
        "scale": (8, 8, 8),
    }
}
# Create dataset
dataset = CellMapDataset(
    raw_path="/path/to/raw/data.zarr",
    target_path="/path/to/labels/data.zarr",
    classes=["mitochondria", "endoplasmic_reticulum", "nucleus"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    is_train=True
)
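The shape and scale entries together determine how much physical space a patch covers. As a minimal sketch (the helper name `patch_extent_nm` is hypothetical and not part of cellmap-data), the extent along each axis is the number of voxels multiplied by the voxel size:

```python
def patch_extent_nm(array_spec):
    """Return the physical size of a patch in nm, one value per axis."""
    # Hypothetical helper: voxels per axis times nm per voxel.
    return tuple(n * s for n, s in zip(array_spec["shape"], array_spec["scale"]))


input_arrays = {"raw": {"shape": (64, 64, 64), "scale": (8, 8, 8)}}
extent = patch_extent_nm(input_arrays["raw"])
print(extent)  # (512, 512, 512)
```

Checking the extent this way can help when pairing arrays at different resolutions, since a 128-voxel patch at 4 nm and a 64-voxel patch at 8 nm cover the same 512 nm.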
Data Loading with Augmentations
import torch
import torchvision.transforms.v2 as T
from cellmap_data import CellMapDataLoader
from cellmap_data.transforms import RandomContrast, GaussianNoise, Binarize
# Define spatial transformations
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.2}},
    "rotate": {"axes": {"z": [-30, 30]}},
    "transpose": {"axes": ["x", "y"]}
}
# Define value transformations
raw_value_transforms = T.Compose([
    T.ToDtype(torch.float, scale=True),  # Normalize to [0,1] and convert to float
    GaussianNoise(std=0.05),             # Add noise for augmentation
    RandomContrast((0.8, 1.2)),          # Vary contrast
])
target_value_transforms = T.Compose([
    Binarize(threshold=0.5),   # Convert to binary masks
    T.ToDtype(torch.float32)   # Ensure correct dtype
])
# Create dataset with transforms
dataset = CellMapDataset(
    raw_path="/path/to/raw/data.zarr",
    target_path="/path/to/labels/data.zarr",
    classes=["mitochondria", "endoplasmic_reticulum", "nucleus"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    spatial_transforms=spatial_transforms,
    raw_value_transforms=raw_value_transforms,
    target_value_transforms=target_value_transforms,
    is_train=True
)
# Configure data loader
loader = CellMapDataLoader(
    dataset,
    batch_size=4,
    num_workers=8,
    weighted_sampler=True,  # Balance classes automatically
    is_train=True
)
# Training loop (model, criterion, and optimizer are assumed to be defined elsewhere)
for batch in loader:
    inputs = batch["raw"]            # Shape: [batch, channels, z, y, x]
    targets = batch["segmentation"]  # Multi-class targets
    # Your training code here
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
Multi-Dataset Training
from cellmap_data import CellMapDataSplit
# Define datasets from CSV or dictionary
datasplit = CellMapDataSplit(
    csv_path="path/to/datasplit.csv",
    classes=["mitochondria", "er", "nucleus"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    spatial_transforms={
        "mirror": {"axes": {"x": 0.5, "y": 0.5}},
        "rotate": {"axes": {"z": [-180, 180]}},
        "transpose": {"axes": ["x", "y"]}
    }
)
# Access combined datasets
train_loader = CellMapDataLoader(
    datasplit.train_datasets_combined,
    batch_size=8,
    weighted_sampler=True
)
val_loader = CellMapDataLoader(
    datasplit.validation_datasets_combined,
    batch_size=16,
    is_train=False
)
Core Components
CellMapDataset
The foundational dataset class that handles individual image volumes:
dataset = CellMapDataset(
    raw_path="path/to/raw.zarr",
    target_path="path/to/gt.zarr",
    classes=["class1", "class2"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    is_train=True,
    pad=True,  # Pad arrays to requested size if needed
    device="cuda"
)
Key Features:
- Automatic 2D/3D handling and slicing
- Multiscale data support
- Memory-efficient random cropping
- Class balancing and weighting
- Spatial transformation pipeline
CellMapMultiDataset
Combines multiple datasets for training across different samples:
from cellmap_data import CellMapMultiDataset
multi_dataset = CellMapMultiDataset(
    classes=classes,
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    datasets=[dataset1, dataset2, dataset3]
)
# Weighted sampling across datasets
sampler = multi_dataset.get_weighted_sampler(batch_size=4)
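One common way to balance sampling is to draw rare items more often by weighting them inversely to their frequency. The sketch below illustrates that idea with the standard library only. It is not cellmap-data's implementation: the `inverse_frequency_weights` helper and the voxel counts are hypothetical.

```python
import random


def inverse_frequency_weights(counts):
    """Map each key to a weight proportional to 1 / count, normalized to sum to 1."""
    inverse = {name: 1.0 / n for name, n in counts.items() if n > 0}
    total = sum(inverse.values())
    return {name: w / total for name, w in inverse.items()}


# Hypothetical number of labeled voxels per class.
class_counts = {"mitochondria": 9000, "nucleus": 900, "er": 100}
weights = inverse_frequency_weights(class_counts)

rng = random.Random(0)  # seeded so the draw is reproducible
names = list(weights)
draws = rng.choices(names, weights=[weights[n] for n in names], k=1000)
print({n: draws.count(n) for n in names})
```

With these counts the rarest class receives most of the probability mass, so it shows up far more often in the draws than it would under uniform sampling of voxels.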
CellMapDataLoader
High-performance data loader built on PyTorch's optimized DataLoader:
loader = CellMapDataLoader(
    dataset,
    batch_size=32,
    num_workers=12,
    weighted_sampler=True,
    device="cuda",
    prefetch_factor=4,         # Preload batches for better GPU utilization
    persistent_workers=True,   # Keep workers alive between epochs
    pin_memory=True,           # Fast CPU-to-GPU transfer
    iterations_per_epoch=1000  # For large datasets
)
# Optimized GPU memory transfer
loader.to("cuda", non_blocking=True)
Optimizations (powered by PyTorch DataLoader):
- Prefetch Factor: Background data loading to maximize GPU utilization
- Pin Memory: Fast CPU-to-GPU transfers via pinned memory (auto-enabled on CUDA)
- Persistent Workers: Reduced overhead by keeping workers alive between epochs
- PyTorch's Optimized Multiprocessing: Battle-tested parallel data loading
- Smart Defaults: Automatic optimization based on hardware configuration
See DataLoader Optimization Guide for performance tuning tips.
CellMapDataSplit
Manages train/validation splits with configuration:
datasplit = CellMapDataSplit(
    dataset_dict={
        "train": [
            {"raw": "path1/raw.zarr", "gt": "path1/gt.zarr"},
            {"raw": "path2/raw.zarr", "gt": "path2/gt.zarr"}
        ],
        "validate": [
            {"raw": "path3/raw.zarr", "gt": "path3/gt.zarr"}
        ]
    },
    classes=classes,
    input_arrays=input_arrays,
    target_arrays=target_arrays
)
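When there are many samples, the dictionary above can be generated rather than typed by hand. The following is a minimal sketch that assumes every sample directory contains `raw.zarr` and `gt.zarr`; the `build_dataset_dict` helper is hypothetical, and only the `train`/`validate` and `raw`/`gt` keys come from the example above.

```python
def build_dataset_dict(sample_dirs, n_validate=1):
    """Split sample directories into the train/validate layout shown above.

    The last `n_validate` directories are held out for validation.
    """
    entries = [{"raw": f"{d}/raw.zarr", "gt": f"{d}/gt.zarr"} for d in sample_dirs]
    if not 0 < n_validate < len(entries):
        raise ValueError("n_validate must leave at least one training sample")
    return {"train": entries[:-n_validate], "validate": entries[-n_validate:]}


dataset_dict = build_dataset_dict(["path1", "path2", "path3"])
print(dataset_dict["validate"])  # [{'raw': 'path3/raw.zarr', 'gt': 'path3/gt.zarr'}]
```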
Advanced Features
Spatial Transformations
Comprehensive augmentation pipeline for robust training:
spatial_transforms = {
    "mirror": {
        "axes": {"x": 0.5, "y": 0.5, "z": 0.1}  # Probability per axis
    },
    "rotate": {
        "axes": {"z": [-45, 45], "y": [-15, 15]}  # Angle ranges
    },
    "transpose": {
        "axes": ["x", "y"]  # Axes to randomly reorder
    }
}
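To make the mirror and transpose settings concrete, here is a rough NumPy sketch of how such a configuration could be applied to a single volume. It is an illustration only, not the library's implementation: the `apply_mirror_and_transpose` function and the (z, y, x) axis order are assumptions, and rotation is left out for brevity.

```python
import numpy as np

AXIS_INDEX = {"z": 0, "y": 1, "x": 2}  # assumed (z, y, x) array layout


def apply_mirror_and_transpose(volume, config, rng):
    """Randomly flip and reorder axes of a 3D array according to `config`."""
    # Mirror: flip each listed axis independently with its own probability.
    for axis, prob in config.get("mirror", {}).get("axes", {}).items():
        if rng.random() < prob:
            volume = np.flip(volume, axis=AXIS_INDEX[axis])
    # Transpose: shuffle only the listed axes among themselves.
    names = config.get("transpose", {}).get("axes", [])
    if names:
        positions = [AXIS_INDEX[a] for a in names]
        order = list(range(volume.ndim))
        for pos, src in zip(positions, rng.permutation(positions)):
            order[pos] = int(src)
        volume = np.transpose(volume, order)
    return volume


rng = np.random.default_rng(0)
volume = np.arange(2 * 3 * 3).reshape(2, 3, 3)
config = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.1}},
    "transpose": {"axes": ["x", "y"]},
}
augmented = apply_mirror_and_transpose(volume, config, rng)
print(augmented.shape)  # (2, 3, 3)
```

The same random decisions would need to be applied to the raw and target arrays so that labels stay aligned with the image.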
Value Transformations
Built-in preprocessing and augmentation transforms:
import torch
import torchvision.transforms.v2 as T
from cellmap_data.transforms import (
    GaussianNoise, RandomContrast,
    RandomGamma, Binarize, NaNtoNum, GaussianBlur
)
# Input preprocessing
raw_transforms = T.Compose([
    T.ToDtype(torch.float, scale=True),  # Normalize to [0,1]
    GaussianNoise(std=0.1),              # Add noise
    RandomContrast((0.8, 1.2)),          # Vary contrast
    NaNtoNum({"nan": 0})                 # Handle NaN values
])
# Target preprocessing
target_transforms = T.Compose([
    Binarize(threshold=0.5),   # Convert to binary
    T.ToDtype(torch.float32)   # Ensure float32
])
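For readers who want to see what NaN replacement and binarization do to the values, the following NumPy sketch mimics the two operations on small arrays. It is a stand-in rather than the library's transforms: the function names are hypothetical, and the strict `>` comparison against the threshold is an assumption.

```python
import numpy as np


def nan_to_num(x, nan=0.0):
    """Replace NaN entries with a fixed value."""
    return np.nan_to_num(x, nan=nan)


def binarize(x, threshold=0.5):
    """Return 1.0 where x > threshold, else 0.0 (strict '>' is an assumption)."""
    return (x > threshold).astype(np.float32)


raw = np.array([0.2, np.nan, 0.9], dtype=np.float32)
labels = np.array([0.0, 0.5, 2.0], dtype=np.float32)

clean = nan_to_num(raw)      # NaN at index 1 becomes 0.0
mask = binarize(labels)      # only the value above 0.5 becomes 1.0
print(clean)
print(mask)  # [0. 0. 1.]
```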
Class Relationship Handling
Support for mutually exclusive classes and true negative inference:
# Define class relationships (keys must match the names used in classes)
class_relation_dict = {
    "mitochondria": ["cytoplasm", "nucleus"],   # Mutually exclusive
    "endoplasmic_reticulum": ["mitochondria"],  # Cannot overlap
}
dataset = CellMapDataset(
    # ... other parameters ...
    classes=["mitochondria", "endoplasmic_reticulum", "nucleus", "cytoplasm"],
    class_relation_dict=class_relation_dict,
    # True negatives automatically inferred from relationships
)
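The reasoning behind true negative inference can be sketched in a few lines: if a voxel is known to belong to one class, every class that is mutually exclusive with it can be treated as a known negative there, even if that class was never annotated. The helpers below are hypothetical and assume the relation is symmetric; the library's actual logic may differ.

```python
def symmetric_exclusions(class_relation_dict):
    """Expand the dict so that 'A excludes B' also implies 'B excludes A'."""
    exclusions = {}
    for cls, others in class_relation_dict.items():
        for other in others:
            exclusions.setdefault(cls, set()).add(other)
            exclusions.setdefault(other, set()).add(cls)
    return exclusions


def inferred_negatives(positive_classes, class_relation_dict):
    """Classes that must be absent wherever `positive_classes` are present."""
    exclusions = symmetric_exclusions(class_relation_dict)
    negatives = set()
    for cls in positive_classes:
        negatives |= exclusions.get(cls, set())
    return negatives - set(positive_classes)


class_relation_dict = {
    "mitochondria": ["cytoplasm", "nucleus"],
    "endoplasmic_reticulum": ["mitochondria"],
}
print(sorted(inferred_negatives({"mitochondria"}, class_relation_dict)))
# ['cytoplasm', 'endoplasmic_reticulum', 'nucleus']
```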
Memory-Efficient Large Dataset Handling
For datasets larger than available memory:
# Use subset sampling for large datasets
loader = CellMapDataLoader(
    large_dataset,
    batch_size=8,
    iterations_per_epoch=5000,  # Subsample each epoch
    weighted_sampler=True
)
# Refresh sampler between epochs
for epoch in range(num_epochs):
    loader.refresh()  # New random subset
    for batch in loader:
        # Training code
        ...
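Conceptually, subsampling per epoch means drawing a new random set of patch indices each time instead of visiting every possible patch. The standard-library sketch below shows that pattern under stated assumptions: the `epoch_indices` helper and the dataset size are hypothetical, and treating iterations_per_epoch as a number of batches is an assumption about its meaning.

```python
import random


def epoch_indices(dataset_size, samples_per_epoch, rng):
    """Draw a fresh random subset of sample indices without replacement."""
    k = min(samples_per_epoch, dataset_size)
    return rng.sample(range(dataset_size), k)


rng = random.Random(0)
dataset_size = 1_000_000      # hypothetical number of possible patches
samples_per_epoch = 5000 * 8  # iterations_per_epoch * batch_size (assumed meaning)

for epoch in range(2):
    # Analogous in spirit to refreshing the sampler between epochs.
    indices = epoch_indices(dataset_size, samples_per_epoch, rng)
    print(epoch, len(indices), len(set(indices)))
```

Because `range` is sampled lazily, the full index list is never materialized, which keeps memory use flat even for very large datasets.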
Writing Predictions
Generate predictions and write to disk efficiently:
from cellmap_data import CellMapDatasetWriter
writer = CellMapDatasetWriter(
    raw_path="input.zarr",
    target_path="predictions.zarr",
    classes=["class1", "class2"],
    input_arrays=input_arrays,
    target_arrays=target_arrays,
    target_bounds={"array": {"x": [0, 1000], "y": [0, 1000], "z": [0, 100]}}
)
# Write predictions tile by tile
for idx in range(len(writer)):
    inputs = writer[idx]
    predictions = model(inputs)
    writer[idx] = {"segmentation": predictions}
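Tile-by-tile writing relies on covering the target bounds with non-overlapping blocks. A minimal sketch of how tile origins could be enumerated is shown below. The `tile_origins` helper is hypothetical and does not reflect how the writer computes its indices; the bounds reuse the values from the example above.

```python
import itertools
import math


def tile_origins(bounds, tile_size):
    """Yield the origin of each non-overlapping tile covering `bounds`.

    bounds: {axis: [start, stop]} in world units; tile_size: {axis: extent}.
    The last tile along an axis may extend past `stop`.
    """
    axes = list(bounds)
    starts = []
    for axis in axes:
        lo, hi = bounds[axis]
        n_tiles = math.ceil((hi - lo) / tile_size[axis])
        starts.append([lo + i * tile_size[axis] for i in range(n_tiles)])
    for combo in itertools.product(*starts):
        yield dict(zip(axes, combo))


bounds = {"x": [0, 1000], "y": [0, 1000], "z": [0, 100]}
tile_size = {axis: 64 * 8 for axis in bounds}  # 64 voxels at 8 nm = 512 nm
origins = list(tile_origins(bounds, tile_size))
print(len(origins))  # 4
```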
Data Format Support
Input Formats
- OME-NGFF/Zarr: Primary format with multiscale support and full read/write capabilities
- Local/S3/GCS: Various storage backends via TensorStore
Multiscale Support
Automatic handling of multiscale datasets:
# Automatically selects appropriate scale level
dataset = CellMapDataset(
    raw_path="data.zarr",  # Contains s0, s1, s2, ... scale levels
    target_path="labels.zarr",
    # ... other parameters ...
)
# Multiscale input arrays can be specified
input_arrays = {
    "raw_4nm": {
        "shape": (128, 128, 128),
        "scale": (4, 4, 4),
    },
    "raw_8nm": {
        "shape": (64, 64, 64),
        "scale": (8, 8, 8),
    }
}
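Scale-level selection can be pictured as matching the requested voxel size against the sizes stored at each level. The sketch below uses one plausible policy, which is an assumption rather than the library's documented rule: pick the coarsest level that is still at least as fine as the request. The `pick_scale_level` helper and the level sizes are hypothetical.

```python
def pick_scale_level(level_scales, requested_scale):
    """Pick the coarsest level whose voxel size does not exceed the request.

    Falls back to the finest level if every level is coarser than requested.
    """
    candidates = [
        name
        for name, scale in level_scales.items()
        if all(s <= r for s, r in zip(scale, requested_scale))
    ]
    if not candidates:
        return min(level_scales, key=lambda name: level_scales[name])
    return max(candidates, key=lambda name: level_scales[name])


# Hypothetical voxel sizes (nm) for the s0, s1, s2 levels of a multiscale group.
level_scales = {"s0": (4, 4, 4), "s1": (8, 8, 8), "s2": (16, 16, 16)}
print(pick_scale_level(level_scales, (8, 8, 8)))     # s1
print(pick_scale_level(level_scales, (12, 12, 12)))  # s1
```

Reading from a level no coarser than the request means data only ever needs to be downsampled, never upsampled, to reach the target resolution.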
Performance Optimization
Memory Management
- Efficient tensor operations with minimal copying
- Automatic GPU memory management
- Streaming data loading for large volumes
Parallel Processing
- Multi-threaded data loading
- CUDA streams for GPU optimization
- Process-safe dataset pickling
Caching Strategy
- Persistent ThreadPoolExecutor for reduced overhead
- Optimized coordinate transformations
- Minimal redundant computations
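A persistent thread pool and process-safe pickling pull in opposite directions, because a `ThreadPoolExecutor` cannot itself be pickled. One common pattern, sketched here with a hypothetical `PatchReader` class that is not part of cellmap-data, is to create the executor lazily and drop it from the pickled state so each worker process builds its own:

```python
import pickle
from concurrent.futures import ThreadPoolExecutor


class PatchReader:
    """Hypothetical reader that keeps one thread pool alive per process."""

    def __init__(self, path, max_workers=4):
        self.path = path
        self.max_workers = max_workers
        self._executor = None  # created lazily, once per process

    @property
    def executor(self):
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self.max_workers)
        return self._executor

    def __getstate__(self):
        # Executors hold locks and threads, so exclude them from the pickle.
        state = self.__dict__.copy()
        state["_executor"] = None
        return state

    def read_many(self, indices):
        # Stand-in for real I/O: doubles each index on the pool.
        return list(self.executor.map(lambda i: i * 2, indices))


reader = PatchReader("data.zarr")
print(reader.read_many([1, 2, 3]))  # [2, 4, 6]
clone = pickle.loads(pickle.dumps(reader))  # what a DataLoader worker would receive
print(clone.read_many([4, 5]))  # [8, 10]
```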
Use Cases
1. Cell Segmentation Training
# Multi-class cell segmentation
classes = ["cell_boundary", "mitochondria", "nucleus", "er"]
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5}},
    "rotate": {"axes": {"z": [-180, 180]}}
}
dataset = CellMapDataset(
    raw_path="em_data.zarr",
    target_path="segmentation_labels.zarr",
    classes=classes,
    input_arrays={"em": {"shape": (128, 128, 128), "scale": (4, 4, 4)}},
    target_arrays={"labels": {"shape": (128, 128, 128), "scale": (4, 4, 4)}},
    spatial_transforms=spatial_transforms,
    is_train=True
)
2. Large-Scale Multi-Dataset Training
# Training across multiple biological samples
datasplit = CellMapDataSplit(
    csv_path="multi_sample_split.csv",
    classes=organelle_classes,
    input_arrays=input_config,
    target_arrays=target_config,
    spatial_transforms=augmentation_config
)
# Balanced sampling across datasets
train_loader = CellMapDataLoader(
    datasplit.train_datasets_combined,
    batch_size=16,
    weighted_sampler=True,
    num_workers=16
)
3. Inference and Prediction Writing
# Generate predictions on new data
writer = CellMapDatasetWriter(
    raw_path="new_sample.zarr",
    target_path="predictions.zarr",
    classes=trained_classes,
    input_arrays=inference_config,
    target_arrays=output_config,
    target_bounds=volume_bounds
)
# Process in tiles
for idx in writer.writer_indices:  # Non-overlapping tiles
    batch = writer[idx]
    with torch.no_grad():
        predictions = model(batch["input"])
    writer[idx] = {"segmentation": predictions}
Best Practices
Dataset Configuration
- Choose patch sizes that fit comfortably in GPU memory
- Enable padding for datasets smaller than patch size
Training Optimization
- Use weighted sampling for imbalanced datasets
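- Configure an appropriate number of workers; PyTorch warns when num_workers exceeds the number of available CPU cores, so treat the core count as a practical ceiling and tune from there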
- Enable CUDA streams for multi-GPU setups
Memory Optimization
- Monitor memory usage with large datasets
- Use iterations_per_epoch for very large datasets
- Refresh samplers between epochs for dataset variety
Debugging
- Start with small patch sizes and single workers
- Use force_has_data=True for testing with empty datasets
- Check dataset.verify() before training
API Reference
For complete API documentation, visit: https://janelia-cellmap.github.io/cellmap-data/
Contributing
We welcome contributions! Please see our contributing guidelines for details on:
- Code style and standards
- Testing requirements
- Documentation expectations
- Pull request process
Citation
If you use CellMap-Data in your research, please cite:
@software{cellmap_data,
title={CellMap-Data: PyTorch Data Loading for Biological Imaging},
author={Rhoades, Jeff and the CellMap Team},
url={https://github.com/janelia-cellmap/cellmap-data},
year={2024}
}
License
This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
Support
- 📖 Documentation
- 🐛 Issue Tracker
- 💬 Discussions
- 📧 Contact: rhoadesj@hhmi.org