A deep learning framework for batch effect correction in biological data
BioBatchNet
BioBatchNet is a deep learning framework for batch effect correction in biological data, supporting both single-cell RNA-seq (scRNA-seq) and Imaging Mass Cytometry (IMC) data.
Installation
From PyPI (Recommended)
pip install biobatchnet
From Source
git clone https://github.com/Manchester-HealthAI/BioBatchNet
cd BioBatchNet
pip install -e .
Prerequisites
conda create -n biobatchnet python=3.10
conda activate biobatchnet
pip install torch numpy pandas anndata
Quick Start
Basic Usage
import BioBatchNet
from BioBatchNet import correct_batch_effects
import anndata as ad
# Load your data
adata = ad.read_h5ad('your_data.h5ad')
X = adata.X
batch_labels = adata.obs['batch'].values
# Correct batch effects; returns (bio_embeddings, batch_embeddings)
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=X,
    batch_info=batch_labels,
    data_type='imc',  # or 'scrna' for single-cell RNA-seq
    epochs=100
)
# Add embeddings to AnnData (recommended usage)
adata.obsm['X_biobatchnet'] = bio_embeddings
adata.obsm['X_batch'] = batch_embeddings
# Checkpoint storage: by default, training artefacts are written to a temporary
# directory and deleted automatically. To keep checkpoints, pass
# save_dir='path/to/output'.
Advanced Usage
Custom Loss Weights
For IMC data with specific batch characteristics:
# Define custom loss weights
loss_weights = {
    'recon_loss': 10,      # Reconstruction loss weight
    'discriminator': 0.3,  # Adversarial loss weight
    'classifier': 1,       # Batch classifier loss weight
    'mmd_loss_1': 0,       # MMD loss weight
    'kl_loss_1': 0.005,    # KL divergence weight for bio encoder
    'kl_loss_2': 0.1,      # KL divergence weight for batch encoder
    'ortho_loss': 0.01     # Orthogonality loss weight
}
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=X,
    batch_info=batch_labels,
    data_type='imc',
    latent_dim=20,
    epochs=100,
    loss_weights=loss_weights
)
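One way to read the loss_weights dictionary is as coefficients in a weighted sum of the individual loss terms. The sketch below illustrates that reading in plain Python. The function name total_loss and the per-term loss values are hypothetical, and BioBatchNet's actual training loop may combine the terms differently (for example, adversarial terms are sometimes optimized in a separate step).

```python
def total_loss(loss_terms, loss_weights):
    """Weighted sum of loss terms; a term without a weight is rejected."""
    missing = set(loss_terms) - set(loss_weights)
    if missing:
        raise KeyError(f"no weight given for: {sorted(missing)}")
    return sum(loss_weights[name] * value for name, value in loss_terms.items())


# Weights copied from the IMC example above.
loss_weights = {
    'recon_loss': 10, 'discriminator': 0.3, 'classifier': 1, 'mmd_loss_1': 0,
    'kl_loss_1': 0.005, 'kl_loss_2': 0.1, 'ortho_loss': 0.01,
}

# Made-up per-term loss values for one training step (illustration only).
loss_terms = {
    'recon_loss': 0.5, 'discriminator': 1.0, 'classifier': 2.0, 'mmd_loss_1': 4.0,
    'kl_loss_1': 10.0, 'kl_loss_2': 1.0, 'ortho_loss': 10.0,
}

combined = total_loss(loss_terms, loss_weights)
```

With a weight of 0, as for 'mmd_loss_1' here, a term drops out of the sum entirely, which is how a loss component can be switched off without changing the model.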
Custom Architecture Parameters
Fine-tune the neural network architecture:
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=X,
    batch_info=batch_labels,
    data_type='imc',
    latent_dim=20,
    epochs=100,
    bio_encoder_hidden_layers=[500, 2000, 2000],
    batch_encoder_hidden_layers=[500],
    decoder_hidden_layers=[2000, 2000, 500],
    batch_classifier_layers_power=[500, 2000, 2000],
    batch_classifier_layers_weak=[128]
)
Direct Model Access
For more control, use the models directly:
from BioBatchNet import IMCVAE, GeneVAE
# For IMC data
model = IMCVAE(
    in_sz=40,      # Number of features
    out_sz=40,     # Output dimension
    num_batch=4,   # Number of batches
    latent_sz=20,  # Latent dimension
    bio_encoder_hidden_layers=[500, 2000, 2000],
    batch_encoder_hidden_layers=[500],
    decoder_hidden_layers=[2000, 2000, 500],
    batch_classifier_layers_power=[500, 2000, 2000],
    batch_classifier_layers_weak=[128]
)
# Train the model
model.fit(data, batch_labels, epochs=100, lr=1e-3)
# Get biological embeddings (batch-corrected representations)
bio_embeddings = model.get_bio_embeddings(data)
# Get corrected data
corrected_data = model.correct_batch_effects(data)
Data Formats
Config-Driven Training Scripts
Scripts such as BioBatchNet/imc.py and BioBatchNet/scrna.py reproduce our research training pipeline. They expect the original datasets to be placed under Data/... (see the config YAMLs for paths). These scripts are not included in the pip distribution. For typical usage, prefer the Python API or the model classes above.
Input Data Requirements
- Data Matrix:
  - NumPy array or PyTorch tensor
  - Shape: (n_cells, n_features)
  - For scRNA-seq: gene expression matrix
  - For IMC: protein expression matrix
- Batch Information:
  - NumPy array or list
  - Can be string labels (e.g., ['Patient1', 'Patient2']) or numeric
  - Length must match the number of cells
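The requirements above can be checked before starting a long training run. The following is a minimal sketch in plain Python: the helper encode_batch_labels is hypothetical and not part of BioBatchNet. Since string labels are accepted directly, the integer encoding is optional; the length check is the part that catches mismatched inputs early.

```python
def encode_batch_labels(batch_labels, n_cells):
    """Map arbitrary batch labels to integer codes 0..k-1 and check the length."""
    labels = list(batch_labels)
    if len(labels) != n_cells:
        raise ValueError(f"got {len(labels)} batch labels for {n_cells} cells")
    # Sort by string form so the mapping is deterministic across runs.
    categories = sorted(set(labels), key=str)
    code_of = {label: code for code, label in enumerate(categories)}
    return [code_of[label] for label in labels], categories


# Toy data matrix: 4 cells x 3 features (values are arbitrary).
X = [
    [0.1, 2.0, 3.5],
    [0.0, 1.2, 0.7],
    [4.1, 0.3, 0.0],
    [2.2, 2.2, 1.9],
]
batch_labels = ['Patient2', 'Patient1', 'Patient2', 'Patient1']

codes, categories = encode_batch_labels(batch_labels, n_cells=len(X))
print(codes)       # [1, 0, 1, 0]
print(categories)  # ['Patient1', 'Patient2']
```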
Output
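- Bio Embeddings: Low-dimensional biological representations of shape (n_cells, latent_dim), returned by correct_batch_effects and by model.get_bio_embeddings
- Batch Embeddings: Batch-specific representations, returned by correct_batch_effects alongside the bio embeddings
- Corrected Data: Same shape as input, with batch effects removed, returned by model.correct_batch_effects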
Recommended Parameters
For IMC Data
# Small dataset (< 10 batches)
loss_weights = {
    'recon_loss': 10,
    'discriminator': 0.3,
    'classifier': 1,
    'mmd_loss_1': 0,
    'kl_loss_1': 0.005,
    'kl_loss_2': 0.1,
    'ortho_loss': 0.01
}
# Large dataset (> 30 batches)
loss_weights = {
    'recon_loss': 10,
    'discriminator': 0.1,
    'classifier': 1,
    'mmd_loss_1': 0.01,
    'kl_loss_1': 0.0,
    'kl_loss_2': 0.1,
    'ortho_loss': 0.01
}
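The two IMC presets can be wrapped in a small selector keyed on the number of batches. This is a sketch only: the helper imc_loss_weights is hypothetical, and because no preset is documented for 10 to 30 batches, it raises an error in that range rather than guessing.

```python
IMC_SMALL = {
    'recon_loss': 10, 'discriminator': 0.3, 'classifier': 1, 'mmd_loss_1': 0,
    'kl_loss_1': 0.005, 'kl_loss_2': 0.1, 'ortho_loss': 0.01,
}
IMC_LARGE = {
    'recon_loss': 10, 'discriminator': 0.1, 'classifier': 1, 'mmd_loss_1': 0.01,
    'kl_loss_1': 0.0, 'kl_loss_2': 0.1, 'ortho_loss': 0.01,
}


def imc_loss_weights(n_batches):
    """Return a copy of the documented IMC preset for the given batch count."""
    if n_batches < 10:
        return dict(IMC_SMALL)
    if n_batches > 30:
        return dict(IMC_LARGE)
    # The documentation gives no preset for 10-30 batches.
    raise ValueError("no documented preset for 10-30 batches; choose weights manually")


weights_for_4_batches = imc_loss_weights(4)
weights_for_40_batches = imc_loss_weights(40)
```

Returning a copy keeps later per-experiment tweaks from modifying the shared presets.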
For scRNA-seq Data
loss_weights = {
    'recon_loss': 10,
    'discriminator': 0.04,
    'classifier': 1,
    'kl_loss_1': 1e-7,
    'kl_loss_2': 0.01,
    'ortho_loss': 0.0002,
    'mmd_loss_1': 0,
    'kl_loss_size': 0.002
}
Example Workflow
import BioBatchNet
import anndata as ad
import numpy as np
from BioBatchNet import correct_batch_effects
# 1. Load data
adata = ad.read_h5ad('IMMUcan_batch.h5ad')
print(f"Data shape: {adata.shape}")
print(f"Batches: {adata.obs['BATCH'].unique()}")
# 2. Prepare data
X = adata.X
batch_labels = adata.obs['BATCH'].values
# Convert categorical to numpy array if needed
if hasattr(batch_labels, 'to_numpy'):
    batch_labels = batch_labels.to_numpy()
# 3. Correct batch effects; returns (bio_embeddings, batch_embeddings)
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=X,
    batch_info=batch_labels,
    data_type='imc',
    latent_dim=20,
    epochs=100
)
# 4. Store results
adata.obsm['X_biobatchnet'] = bio_embeddings
adata.obsm['X_batch'] = batch_embeddings
# 5. Optional: Train a model directly to obtain corrected data
from BioBatchNet import IMCVAE
model = IMCVAE(
    in_sz=X.shape[1],
    out_sz=X.shape[1],
    num_batch=len(np.unique(batch_labels)),
    latent_sz=20
)
model.fit(X, batch_labels, epochs=100)
adata.layers['corrected'] = model.correct_batch_effects(X)
# 6. Visualize results (using scanpy)
import scanpy as sc
sc.pp.neighbors(adata, use_rep='X_biobatchnet')
sc.tl.umap(adata)
sc.pl.umap(adata, color=['BATCH', 'celltype'])
Tips and Best Practices
- Data Preprocessing:
  - Ensure data is properly normalized before batch correction
  - For scRNA-seq: consider log-transformation
  - For IMC: consider arcsinh transformation
- Batch Size:
  - Default batch size is 256
  - Reduce if encountering memory issues
  - Increase for faster training with sufficient memory
- Number of Epochs:
  - Start with 100 epochs for initial testing
  - Increase to 200-500 for final results
  - Monitor loss convergence
- Latent Dimension:
  - Default is 20
  - Increase for complex datasets with many cell types
  - Decrease for simpler datasets
- Post-processing:
  - Output may need scaling/normalization
  - Consider clipping extreme values
  - Validate biological signals are preserved
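The two transformations mentioned under Data Preprocessing can be sketched as follows. Plain Python lists are used so the example is self-contained; on NumPy arrays the equivalents are np.arcsinh(X / cofactor) and np.log1p(X). The cofactor of 5 is an assumption (a value often used for mass cytometry data), not something BioBatchNet prescribes, and the helper names are hypothetical.

```python
import math


def arcsinh_transform(matrix, cofactor=5.0):
    """Apply asinh(x / cofactor) element-wise to a list-of-rows matrix."""
    return [[math.asinh(x / cofactor) for x in row] for row in matrix]


def log1p_transform(matrix):
    """Apply log(1 + x) element-wise; inputs are expected to be non-negative."""
    return [[math.log1p(x) for x in row] for row in matrix]


# Toy intensities and counts (illustration only): 2 cells x 3 features.
imc_intensities = [[0.0, 5.0, 50.0], [10.0, 0.0, 500.0]]
rna_counts = [[0, 1, 9], [99, 0, 3]]

imc_scaled = arcsinh_transform(imc_intensities)
rna_logged = log1p_transform(rna_counts)
```

Both transforms map zero to zero and compress large values, which keeps a few very bright markers or highly expressed genes from dominating the reconstruction loss.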
Troubleshooting
Memory Issues
# Reduce batch size
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=X,
    batch_info=batch_labels,
    batch_size=64  # Smaller batch size
)
Convergence Issues
# Adjust learning rate
model.fit(data, batch_labels, lr=1e-4) # Lower learning rate
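To judge whether a lower learning rate or more epochs is helping, it can be useful to test whether the loss is still improving. The sketch below assumes one loss value per epoch has been collected in a list (for example from training logs); how BioBatchNet exposes these values is not covered here, and the helper has_stopped_improving is hypothetical.

```python
def has_stopped_improving(loss_history, window=10, rel_tol=1e-3):
    """Return True when the mean loss over the last `window` epochs improved by
    less than `rel_tol` (relative) over the mean of the `window` epochs before."""
    if len(loss_history) < 2 * window:
        return False  # Not enough epochs to compare two windows.
    previous = sum(loss_history[-2 * window:-window]) / window
    recent = sum(loss_history[-window:]) / window
    if previous == 0:
        return True
    return (previous - recent) / abs(previous) < rel_tol


# Synthetic loss curves (illustration only).
still_decreasing = [1.0 / (epoch + 1) for epoch in range(40)]
flattened = still_decreasing + [still_decreasing[-1]] * 20

print(has_stopped_improving(still_decreasing))  # False
print(has_stopped_improving(flattened))         # True
```

Note that a rising loss also counts as "not improving" under this definition, so a True result is a prompt to inspect the curve, not proof of convergence.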
Large Output Range
# Post-process corrected data from a trained model (see Direct Model Access)
import numpy as np
corrected = model.correct_batch_effects(data)
corrected = np.clip(corrected, 0, np.percentile(corrected, 99))
Features
- Multi-modal Support: Works with both scRNA-seq and IMC data
- Easy-to-Use API: One-line batch correction function
- Flexible Architecture: Customizable neural network parameters
- Adaptive Loss Weights: Automatically adjusts based on dataset characteristics
- Comprehensive Documentation: Detailed usage examples and best practices
Citation
If you use BioBatchNet in your research, please cite:
[Citation information to be added]
Support
For issues and questions:
- GitHub Issues: https://github.com/Manchester-HealthAI/BioBatchNet/issues
- PyPI Package: https://pypi.org/project/biobatchnet/
License
MIT License