
Scalable Data Loading for Deep Learning on Large-Scale Single-Cell Omics


scDataset



[Figure: scDataset architecture]

scDataset is a flexible and efficient PyTorch IterableDataset for large-scale single-cell omics datasets. It supports a variety of data formats (e.g., AnnData, HuggingFace Datasets, NumPy arrays) and is designed for high-throughput deep learning workflows. While optimized for single-cell data, it is general-purpose and works with any indexable data collection.

Features

  • Flexible Data Source Support: Works with AnnData, HuggingFace Datasets, NumPy arrays, PyTorch Datasets, and other indexable collections.
  • Scalable: Handles datasets with billions of samples without loading everything into memory.
  • Efficient Data Loading: Block sampling and batched fetching optimize random access for large datasets.
  • Dynamic Splitting: Split datasets into train/validation/test dynamically, without duplicating data or rewriting files.
  • Custom Hooks: Apply transformations at fetch or batch time via user-defined callbacks.

Installation

Install the latest release from PyPI:

pip install scDataset

Or install the latest development version from GitHub:

pip install git+https://github.com/Kidara/scDataset.git

Usage

Basic Usage with Sampling Strategies

scDataset v0.2.0 uses a strategy-based approach: a sampling strategy object, passed to the scDataset constructor, determines which samples are read and in what order:

from scdataset import scDataset, Streaming
from torch.utils.data import DataLoader

# Create dataset with streaming strategy
data = my_data_collection  # Any indexable object (numpy array, AnnData, etc.)
strategy = Streaming()
dataset = scDataset(data, strategy, batch_size=64)
loader = DataLoader(dataset, batch_size=None)  # scDataset handles batching internally

Note: Set batch_size=None in the DataLoader to delegate batching to scDataset.
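Because scDataset yields complete minibatches, a DataLoader with an integer batch_size (the default is 1) would wrap each already-batched item in an extra batch dimension. The pure-Python sketch below is not the library's code, and iter_minibatches is a hypothetical name; it only illustrates the kind of stream scDataset produces, where every item is already a full batch. Whether scDataset keeps or drops a final partial batch is not stated here; this sketch keeps it.

```python
from typing import Iterator, List, Sequence


def iter_minibatches(indices: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive minibatches of indices; the last one may be smaller."""
    for start in range(0, len(indices), batch_size):
        yield list(indices[start:start + batch_size])


# Each yielded item is already a whole batch, which is why the DataLoader
# should not batch again (batch_size=None).
batches = list(iter_minibatches(range(150), batch_size=64))
print([len(b) for b in batches])  # [64, 64, 22]
```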

Sampling Strategies

Sequential Sampling (Streaming)

from scdataset import Streaming

# Simple sequential access
strategy = Streaming()
dataset = scDataset(data, strategy, batch_size=64)

# Sequential with buffer-level shuffling (like Ray Dataset or WebDataset).
# The buffer holds batch_size * fetch_factor samples, both set in the scDataset constructor.
strategy = Streaming(shuffle=True)
dataset = scDataset(data, strategy, batch_size=64, fetch_factor=8)
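To make buffer-level shuffling concrete, here is a minimal pure-Python sketch, assuming each buffer is filled from batch_size * fetch_factor consecutive indices and shuffled only internally. The helper buffer_shuffled_batches is hypothetical and not part of scDataset. The consequence is that randomness stays local: a sample never moves into a different buffer.

```python
import random
from typing import List, Sequence


def buffer_shuffled_batches(indices: Sequence[int], batch_size: int,
                            fetch_factor: int, seed: int = 0) -> List[List[int]]:
    """Sequential reads with shuffling only inside each buffer of
    batch_size * fetch_factor consecutive indices (hypothetical helper)."""
    rng = random.Random(seed)
    buffer_size = batch_size * fetch_factor
    batches = []
    for start in range(0, len(indices), buffer_size):
        buffer = list(indices[start:start + buffer_size])  # one contiguous read
        rng.shuffle(buffer)                                 # shuffle within this buffer only
        for b in range(0, len(buffer), batch_size):
            batches.append(buffer[b:b + batch_size])
    return batches


batches = buffer_shuffled_batches(range(1000), batch_size=64, fetch_factor=8)
# The first 8 batches all come from the first buffer, i.e. indices 0..511.
print(max(max(b) for b in batches[:8]))  # 511
```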

# Use only a subset of indices
train_indices = [0, 2, 4, 6, 8, ...]  # Your training indices
strategy = Streaming(indices=train_indices)
dataset = scDataset(data, strategy, batch_size=64)

Block Shuffling for Locality

from scdataset import BlockShuffling

# Shuffle blocks while maintaining some data locality
strategy = BlockShuffling(block_size=8)
dataset = scDataset(data, strategy, batch_size=64)

# With subset of indices
strategy = BlockShuffling(block_size=8, indices=train_indices)
dataset = scDataset(data, strategy, batch_size=64)
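One plausible reading of block shuffling, shown as a sketch rather than scDataset's implementation: indices are cut into contiguous blocks of block_size, and only the order of the blocks is randomized. Each read therefore touches block_size neighbouring rows, which favours on-disk formats where contiguous access is cheap. Whether scDataset also shuffles samples after fetching is not modeled here, and block_shuffle is a hypothetical name.

```python
import random
from typing import List, Sequence


def block_shuffle(indices: Sequence[int], block_size: int, seed: int = 0) -> List[int]:
    """Shuffle the order of contiguous blocks while keeping each block intact."""
    rng = random.Random(seed)
    blocks = [list(indices[i:i + block_size]) for i in range(0, len(indices), block_size)]
    rng.shuffle(blocks)  # randomize block order; rows inside a block stay adjacent
    return [idx for block in blocks for idx in block]


order = block_shuffle(range(32), block_size=8)
print(order)  # four runs of 8 consecutive indices, in random block order
```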

Weighted Sampling

from scdataset import BlockWeightedSampling

# Uniform weighted sampling
strategy = BlockWeightedSampling(total_size=10000, block_size=16)
dataset = scDataset(data, strategy, batch_size=64)

# Custom weights (e.g., for imbalanced data)
sample_weights = compute_weights(data)  # Your weight computation
strategy = BlockWeightedSampling(
    weights=sample_weights, 
    total_size=5000,
    block_size=16
)
dataset = scDataset(data, strategy, batch_size=64)
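As a rough illustration of what weighted sampling with a fixed total_size means, the sketch below draws total_size indices with replacement, in proportion to per-sample weights, using the standard library's random.choices. It deliberately ignores block_size, since how scDataset groups sampled indices into blocks is not described here. weighted_sample is a hypothetical name, and passing weights=None gives uniform sampling.

```python
import random
from collections import Counter
from typing import List, Optional, Sequence


def weighted_sample(n: int, total_size: int,
                    weights: Optional[Sequence[float]] = None,
                    seed: int = 0) -> List[int]:
    """Draw total_size indices from range(n) with replacement.
    weights=None means uniform sampling (hypothetical helper)."""
    rng = random.Random(seed)
    return rng.choices(range(n), weights=weights, k=total_size)


uniform = weighted_sample(10, total_size=1000)
skewed = weighted_sample(4, total_size=10_000, weights=[0.7, 0.1, 0.1, 0.1])
counts = Counter(skewed)
print(counts.most_common(1))  # index 0 dominates, with roughly 70% of the draws
```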

Automatic Class Balancing

from scdataset import ClassBalancedSampling

# Automatically balance classes from labels
cell_types = ['T_cell', 'B_cell', 'NK_cell', ...]  # Your class labels
strategy = ClassBalancedSampling(cell_types, total_size=8000)
dataset = scDataset(data, strategy, batch_size=64)
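A common way to balance classes gives each sample a weight inversely proportional to the size of its class, so every class carries the same total weight. This is shown as an assumption, not as ClassBalancedSampling's documented internals. Weights computed this way could also serve as the sample_weights passed to BlockWeightedSampling(weights=...) in the previous section. inverse_frequency_weights is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, Hashable, List, Sequence


def inverse_frequency_weights(labels: Sequence[Hashable]) -> List[float]:
    """Weight each sample by 1 / (number of samples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


labels = ["T_cell"] * 6 + ["B_cell"] * 3 + ["NK_cell"] * 1
weights = inverse_frequency_weights(labels)

# Every class now carries (up to floating-point rounding) the same total weight.
per_class: Dict[str, float] = {}
for label, w in zip(labels, weights):
    per_class[label] = per_class.get(label, 0.0) + w
print(per_class)
```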

Multi-Modal Data with MultiIndexable

Handle multiple related data modalities that need to be indexed together:

from scdataset import MultiIndexable, Streaming

# Group multiple data modalities
multi_data = MultiIndexable(
    genes=gene_expression_data,    # Shape: (n_cells, n_genes)
    proteins=protein_data,         # Shape: (n_cells, n_proteins)  
    metadata=cell_metadata         # Shape: (n_cells, n_features)
)

# Use with any sampling strategy
strategy = Streaming()
dataset = scDataset(multi_data, strategy, batch_size=64)

for batch in dataset:
    genes = batch['genes']       # Gene expression for this batch
    proteins = batch['proteins'] # Corresponding protein data
    metadata = batch['metadata'] # Corresponding metadata
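To see what MultiIndexable needs to guarantee, here is a minimal stand-in. TinyMultiIndexable is hypothetical and is not the library class. It checks that all modalities have the same number of rows and applies one index list to every modality, so the rows stay aligned across modalities.

```python
from typing import Any, Dict, Sequence


class TinyMultiIndexable:
    """Hypothetical minimal stand-in: index several aligned modalities together."""

    def __init__(self, **modalities: Sequence[Any]) -> None:
        lengths = {name: len(values) for name, values in modalities.items()}
        if len(set(lengths.values())) != 1:
            raise ValueError(f"modalities must have equal length, got {lengths}")
        self._modalities = modalities

    def __len__(self) -> int:
        return len(next(iter(self._modalities.values())))

    def __getitem__(self, indices: Sequence[int]) -> Dict[str, list]:
        # The same indices select the same cells in every modality.
        return {name: [values[i] for i in indices]
                for name, values in self._modalities.items()}


cells = TinyMultiIndexable(
    genes=[[1, 0], [0, 2], [3, 1]],
    proteins=[[5], [6], [7]],
)
batch = cells[[2, 0]]
print(batch)  # {'genes': [[3, 1], [1, 0]], 'proteins': [[7], [5]]}
```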

Performance Optimization

Configure fetch_factor to fetch several batches' worth of data at once:

strategy = BlockShuffling(block_size=16)
dataset = scDataset(
    data, 
    strategy, 
    batch_size=64,
    fetch_factor=8  # Fetch 8*64=512 samples at once
)
loader = DataLoader(
    dataset,
    batch_size=None,
    num_workers=4,
    prefetch_factor=9  # fetch_factor + 1
)

We recommend setting prefetch_factor to fetch_factor + 1 for efficient data loading. For a discussion of these parameters, see the scDataset paper (arXiv:2506.01883).
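The arithmetic behind these settings can be shown as a small worked example. Each fetch reads batch_size * fetch_factor samples, which are then split into fetch_factor minibatches. The dataset size of 10,000 samples is hypothetical, and the sketch only computes the recommended prefetch_factor without justifying it. In PyTorch, prefetch_factor is the number of batches each worker loads in advance, so fetch_factor + 1 presumably lets a worker stay slightly more than one full fetch ahead.

```python
import math

batch_size = 64
fetch_factor = 8
n_samples = 10_000  # hypothetical dataset size

samples_per_fetch = batch_size * fetch_factor            # 512 samples read at once
fetches_per_epoch = math.ceil(n_samples / samples_per_fetch)
batches_per_epoch = math.ceil(n_samples / batch_size)    # assumes a final partial batch is kept
prefetch_factor = fetch_factor + 1                       # recommended DataLoader setting

print(samples_per_fetch, fetches_per_epoch, batches_per_epoch, prefetch_factor)
# 512 20 157 9
```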

Custom Transforms and Callbacks

Apply custom transformations at fetch or batch time using the new callback system:

Transform Overview

  • fetch_callback(collection, indices):
    Customizes how samples are fetched from the underlying data collection.
    Use this if your collection does not support batched indexing or requires special access logic.

    • Input: the data collection and an array of indices
    • Output: the fetched data
  • fetch_transform(fetched_data):
    Transforms each fetched chunk (e.g., sparse-to-dense conversion, normalization).

    • Input: the fetched data
    • Output: the transformed data
  • batch_callback(fetched_data, batch_indices):
    Selects or arranges a minibatch from the fetched/transformed data.

    • Input: the fetched/transformed data and a list of batch indices within the chunk
    • Output: the batch to yield
  • batch_transform(batch):
    Applies final processing to each batch before yielding (e.g., collation, augmentation).

    • Input: the batch
    • Output: the processed batch

from scdataset import scDataset, Streaming
import torch

def fetch_transform(chunk):
    # Applied once to each fetched chunk, e.g. sparse-to-dense conversion or normalization
    return chunk.toarray() if hasattr(chunk, 'toarray') else chunk

def batch_transform(batch):
    # Applied to each minibatch, e.g. tensor conversion or augmentation
    return torch.from_numpy(batch).float()

strategy = Streaming()
dataset = scDataset(
    data,
    strategy,
    batch_size=64,
    fetch_transform=fetch_transform,
    batch_transform=batch_transform
)
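The order in which the four hooks listed above are applied can be pictured with the pure-Python sketch below. run_hooks and the two default callbacks are hypothetical stand-ins written for illustration. They mirror the documented signatures (fetch_callback(collection, indices), fetch_transform(fetched_data), batch_callback(fetched_data, batch_indices), batch_transform(batch)) but are not scDataset's code.

```python
from typing import Any, Callable, Iterator, List, Optional, Sequence


def default_fetch_callback(collection: Sequence[Any], indices: List[int]) -> List[Any]:
    # Fallback for collections without batched indexing: fetch one item at a time.
    return [collection[i] for i in indices]


def default_batch_callback(fetched: List[Any], batch_indices: List[int]) -> List[Any]:
    # Positions are relative to the fetched chunk, not the full collection.
    return [fetched[i] for i in batch_indices]


def run_hooks(collection: Sequence[Any], fetch_indices: List[int], batch_size: int,
              fetch_callback: Callable = default_fetch_callback,
              fetch_transform: Optional[Callable] = None,
              batch_callback: Callable = default_batch_callback,
              batch_transform: Optional[Callable] = None) -> Iterator[Any]:
    fetched = fetch_callback(collection, fetch_indices)   # 1. fetch one chunk
    if fetch_transform is not None:
        fetched = fetch_transform(fetched)                # 2. transform the whole chunk
    for start in range(0, len(fetch_indices), batch_size):
        positions = list(range(start, min(start + batch_size, len(fetch_indices))))
        batch = batch_callback(fetched, positions)        # 3. select a minibatch
        if batch_transform is not None:
            batch = batch_transform(batch)                # 4. final per-batch processing
        yield batch


data = [10, 20, 30, 40, 50, 60]
out = list(run_hooks(data, [5, 0, 3, 1], batch_size=2,
                     fetch_transform=lambda chunk: [x / 10 for x in chunk],
                     batch_transform=sum))
print(out)  # [7.0, 6.0]
```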

Complete Example with Multiple Strategies

from scdataset import scDataset, BlockShuffling, Streaming
from torch.utils.data import DataLoader
import numpy as np

# Your data
data = my_data_collection
train_indices = np.arange(0, 8000)
val_indices = np.arange(8000, 10000)

# Training with block shuffling
train_strategy = BlockShuffling(block_size=32, indices=train_indices)
train_dataset = scDataset(
    data,
    train_strategy,
    batch_size=64,
    fetch_factor=8
)

train_loader = DataLoader(
    train_dataset,
    batch_size=None,
    num_workers=4,
    prefetch_factor=9
)

# Validation with streaming (deterministic)
val_strategy = Streaming(indices=val_indices)
val_dataset = scDataset(
    data,
    val_strategy,
    batch_size=64,
    fetch_factor=8
)

val_loader = DataLoader(
    val_dataset,
    batch_size=None,
    num_workers=4,
    prefetch_factor=9
)

# Training loop (num_epochs is a placeholder; choose your own)
num_epochs = 10
for epoch in range(num_epochs):
    # Training
    for batch in train_loader:
        # Training code here
        pass
    
    # Validation  
    for batch in val_loader:
        # Validation code here
        pass
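The example above splits the data into contiguous index ranges. If samples are stored in a meaningful order (for example, grouped by donor or sequencing batch), a shuffled split may be preferable. The sketch below uses a hypothetical helper, random_split_indices, built only on the standard library. It produces disjoint train/validation/test index lists that can be passed as the indices argument of any strategy, without copying data.

```python
import random
from typing import List, Sequence, Tuple


def random_split_indices(n: int, fractions: Sequence[float] = (0.8, 0.1, 0.1),
                         seed: int = 0) -> Tuple[List[int], ...]:
    """Split range(n) into disjoint, randomly assigned index lists (hypothetical helper)."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    order = list(range(n))
    random.Random(seed).shuffle(order)
    splits, start = [], 0
    for i, frac in enumerate(fractions):
        # The last split takes the remainder so no index is lost to rounding.
        end = n if i == len(fractions) - 1 else start + int(round(frac * n))
        # Sorted so that Streaming reads each split in storage order.
        splits.append(sorted(order[start:end]))
        start = end
    return tuple(splits)


train_indices, val_indices, test_indices = random_split_indices(10_000)
print(len(train_indices), len(val_indices), len(test_indices))  # 8000 1000 1000
```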

Citing

If you use scDataset in your research, please cite the following paper:

@article{scdataset2025,
  title={scDataset: Scalable Data Loading for Deep Learning on Large-Scale Single-Cell Omics},
  author={D'Ascenzo, Davide and Cultrera di Montesano, Sebastiano},
  journal={arXiv:2506.01883},
  year={2025}
}

Migration from v0.1.x to v0.2.0

scDataset v0.2.0 introduces breaking changes with a new strategy-based API. Here's how to migrate your code:

Old v0.1.x API

# v0.1.x - No longer supported
from scdataset import scDataset

dataset = scDataset(data, batch_size=64, block_size=8, fetch_factor=4)
dataset.subset(train_indices)
dataset.set_mode('train')

New v0.2.0 API

# v0.2.0 - Strategy-based approach
from scdataset import scDataset, BlockShuffling, Streaming

# Training with shuffling
train_strategy = BlockShuffling(block_size=8, indices=train_indices)
train_dataset = scDataset(data, train_strategy, batch_size=64, fetch_factor=4)

# Evaluation with streaming
val_strategy = Streaming(indices=val_indices)
val_dataset = scDataset(data, val_strategy, batch_size=64, fetch_factor=4)

Key Changes:

  • Required strategy parameter: Must provide a SamplingStrategy instance
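  • No more subset() and set_mode(): Pass split indices through the strategy's indices parameter, and choose the strategy type per split (e.g., BlockShuffling for training, Streaming for evaluation)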
  • Create separate datasets: For different splits instead of modifying a single instance
  • New import: Import specific strategies like Streaming, BlockShuffling, etc.

License

This project is licensed under the MIT License.

Contributing

Contributions are welcome! Please open issues or pull requests on GitHub.


