scDistill: Knowledge distillation for single-cell batch correction with covariate protection
A 2-phase knowledge distillation framework for single-cell RNA-seq batch correction.
Overview
scDistill learns batch correction by distilling knowledge from a teacher method (e.g., Harmony) into a neural network. The key innovations are:
- 2-Phase Training — Clean separation of encoder and decoder objectives
- Knowledge Distillation — Learn batch-invariant representations from established methods
- Conditional Decoder — Batch-aware decoding for proper expression reconstruction
Architecture
Phase 1 (Encoder): X → log2(CPM+1) → Encoder → Z, trained with Loss = MSE(Z, Z*)
Phase 2 (Decoder): Z → Decoder(Z, batch) → (μ, θ), trained with Loss = NB_NLL(X, μ, θ)
- X — Raw count matrix (N cells × G genes)
- Z* — Teacher's batch-corrected latent representation (from Harmony)
- Z — Student encoder's output (batch-corrected)
- μ, θ — Negative Binomial parameters (mean, dispersion)
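A rough sketch of the two phases in PyTorch. The `encoder`, `decoder`, `loader`, and optimizer names here are illustrative stand-ins, not scDistill's internals, and `nb_nll` is sketched under Mathematical Formulation below:

```python
import torch
import torch.nn.functional as F

# Phase 1: distill the teacher's batch-corrected embedding Z* into the encoder.
for x_log, x_raw, z_star, batch in loader:     # x_log = log2(CPM+1), x_raw = counts
    loss = F.mse_loss(encoder(x_log), z_star)  # L_encoder = MSE(Z, Z*)
    enc_opt.zero_grad(); loss.backward(); enc_opt.step()

# Phase 2: freeze the encoder; fit the batch-conditional decoder by NB likelihood.
encoder.requires_grad_(False)
for x_log, x_raw, z_star, batch in loader:
    with torch.no_grad():
        z = encoder(x_log)                     # fixed batch-corrected latent
    mu, theta = decoder(z, batch)              # batch-aware NB parameters
    loss = nb_nll(x_raw, mu, theta)            # NB negative log-likelihood
    dec_opt.zero_grad(); loss.backward(); dec_opt.step()
```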
Theoretical Foundation
The Batch Effect Problem
Single-cell data contains both biological signal and technical batch effects:
X_observed = f(biological_signal) + g(batch_effect) + noise
The goal is to obtain batch-corrected expression that preserves biological differences while removing batch artifacts.
Why Knowledge Distillation Works
1. Teacher Creates Batch-Invariant Target
Harmony operates in PCA space and uses soft k-means clustering to align batches:
Z* = Harmony(PCA(X), batch_labels)
The resulting Z* satisfies the batch invariance property:
P(Z* | batch = b₁) ≈ P(Z* | batch = b₂) for all batches b₁, b₂
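A crude first-moment check of this property on a fitted Z* (an illustrative diagnostic, not part of scDistill's API):

```python
import numpy as np

def batch_centroid_shift(z_star: np.ndarray, batches: np.ndarray) -> np.ndarray:
    """Distance of each batch's centroid in Z* from the global centroid.
    Values near zero are consistent with (but do not prove) batch invariance."""
    global_centroid = z_star.mean(axis=0)
    return np.array([
        np.linalg.norm(z_star[batches == b].mean(axis=0) - global_centroid)
        for b in np.unique(batches)
    ])
```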
2. Encoder Learns the Mapping
The encoder is trained to reproduce Z*:
L_encoder = ||Encoder(X) - Z*||²
After convergence, the encoder has implicitly learned to remove batch effects:
Encoder(X) ≈ Z* → Encoder removes batch information
Key insight — By minimizing distance to Z*, the encoder cannot encode batch-specific information since Z* doesn't contain it.
3. Why Reconstruction Loss Preserves Biology
The decoder is trained with a Negative Binomial likelihood:
L_decoder = -log P(X | μ, θ) where (μ, θ) = Decoder(Z, batch)
The Negative Binomial distribution is well suited to scRNA-seq count data because it:
- Captures overdispersion: Var(x) = μ + μ²/θ > μ
- Accommodates the high fraction of zeros through overdispersion, without needing an explicit zero-inflation component
- Models counts directly, with no log-transform required
Theorem (informal) — If Decoder maximizes P(X | Z, batch), then Decoder must preserve gene-level biological variation.
Proof sketch
- Z contains only biological information (batch removed by encoder)
- X contains both biological and batch information
- To maximize likelihood of X given Z, Decoder must capture biological signal in Z
- The batch embedding provides batch-specific scaling without encoding batch info into Z
- Biological variation in X must come entirely from Z
Why PCA-based Distillation is Superior
A common misconception is that passing through PCA loses information. In reality, this is intentional denoising and signal purification, not information loss.
1. Z* is the "Core", Not the "Whole" — Manifold Reconstruction by Decoder
Even though Z* is low-dimensional (e.g., 50D), the Decoder's millions of weights learn and retain the gene-gene co-expression structure needed to reconstruct ~20,000 genes from 50 coordinates.
How PCA and the Decoder complement each other:
- PCA (Z*) captures the "principal coordinates of cell states"
- Decoder applies biological rules (the manifold) based on those coordinates
- "If a cell is in this state, genes A, B, and C should be expressed at these levels"
Z* is a seed, and the Decoder grows it into full expression profiles — information not explicitly in Z* is not lost; it is reconstructed through the learned biological manifold.
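For scale, a back-of-envelope count for a minimal one-hidden-layer decoder using the defaults listed under API Reference (n_latent=50, n_hidden=128) and ~20,000 genes — illustrative only, the real architecture may differ:

```python
n_latent, n_hidden, n_genes = 50, 128, 20_000

# One hidden layer plus the output layer, counting weights and biases:
params = (n_latent * n_hidden + n_hidden) + (n_hidden * n_genes + n_genes)
print(f"{params:,}")  # 2,586,528 — millions of parameters even for a tiny decoder
```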
2. PCA is an Optimal Filter, Not a Bottleneck
The variance-importance mismatch problem
Standard PCA is dominated by highly expressed genes (high variance), while biologically important DE genes may have low expression. Pearson residuals PCA addresses this:
```python
# Pearson residuals normalize the mean-variance relationship;
# low-expression but biologically meaningful variation is preserved.
model = Distiller(adata, batch_key="batch", pca_method="pearson_residuals")
```
Capturing small perturbations
Even when condition A→B perturbations are small, properly weighted PCA (or Harmony's iterative refinement) captures these perturbations in the top PC axes.
Conclusion — PCA acts as a high-performance noise-canceling filter that strips away technical noise (Poisson noise, etc.) while extracting only the signals necessary for robust DE detection.
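For reference, a generic analytic-Pearson-residuals computation in the style of Lause et al. (2021) — a sketch, not necessarily what pca_method="pearson_residuals" runs internally:

```python
import numpy as np

def analytic_pearson_residuals(X: np.ndarray, theta: float = 100.0) -> np.ndarray:
    """Pearson residuals under a Negative Binomial null with fixed dispersion theta."""
    # Expected counts under the null: outer product of cell and gene totals.
    mu = X.sum(axis=1, keepdims=True) * X.sum(axis=0, keepdims=True) / X.sum()
    residuals = (X - mu) / np.sqrt(mu + mu ** 2 / theta)
    clip = np.sqrt(X.shape[0])              # common clipping at sqrt(n_cells)
    return np.clip(residuals, -clip, clip)
```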
3. Information Bottleneck Prevents Overfitting to Noise
scVI and similar end-to-end methods can overfit to technical noise in high-dimensional space. scDistill's PCA bottleneck forces the model to focus on the biological manifold:
scVI: X (20,000D) → Encoder → Z (50D) → Decoder → X̂
↑ Can memorize noise in X
scDistill: X → PCA → Z* (50D, denoised) → Encoder → Z → Decoder → X̂
↑ Noise already filtered out
This explains why scDistill achieves higher DE F1 scores — it avoids the "overfitting to noise" trap that reduces DE detection accuracy in end-to-end methods.
Conditional Decoder and Batch Embedding
The decoder takes both Z and the batch embedding as input:
X̂ = Decoder(Z, Embedding(batch))
This design is critical because:
- Batch information for reconstruction — The decoder needs to know which batch a cell came from to properly reconstruct X, since X contains batch effects
- Clean separation — Biological signal flows through Z, batch effects through the embedding
- Inference flexibility — At inference time, we can decode to the original batch for a faithful reconstruction, or to a reference batch for batch-corrected expression
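A minimal sketch of such a decoder in PyTorch. The class name and defaults are hypothetical; scDistill's actual module will differ:

```python
import torch
import torch.nn as nn

class ConditionalNBDecoder(nn.Module):
    """Decode (Z, batch) into Negative Binomial parameters (mu, theta)."""

    def __init__(self, n_latent=50, n_genes=20_000, n_batches=4,
                 n_batch_embedding=10, n_hidden=128):
        super().__init__()
        self.batch_emb = nn.Embedding(n_batches, n_batch_embedding)
        self.hidden = nn.Sequential(
            nn.Linear(n_latent + n_batch_embedding, n_hidden), nn.ReLU())
        self.mu_head = nn.Linear(n_hidden, n_genes)          # per-gene log-mean
        self.log_theta = nn.Parameter(torch.zeros(n_genes))  # gene-wise dispersion

    def forward(self, z, batch):
        # Biology enters through z; batch effects only through the embedding.
        h = self.hidden(torch.cat([z, self.batch_emb(batch)], dim=-1))
        mu = torch.exp(self.mu_head(h))     # NB mean (positive)
        theta = torch.exp(self.log_theta)   # NB dispersion (positive)
        return mu, theta
```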
Mathematical Formulation
Let X ∈ ℤ₊^(N×G) be the count matrix and B ∈ {1,...,K}^N the batch labels. Write E_ψ for the encoder and D_φ for the decoder (θ is reserved for the NB dispersion).
Phase 1 — Encoder Training
min_ψ Σᵢ ||E_ψ(xᵢ) - z*ᵢ||²
where z*ᵢ = Harmony(PCA(X), B)ᵢ
Phase 2 — Decoder Training (encoder frozen)
min_φ Σᵢ -log P_NB(xᵢ | μᵢ, θᵢ)
where (μᵢ, θᵢ) = D_φ(E_ψ(xᵢ), e_bᵢ)
e_b = BatchEmbedding(b) ∈ ℝ^d
P_NB(x | μ, θ) = Γ(x+θ)/(Γ(θ) x!) · (θ/(θ+μ))^θ · (μ/(θ+μ))^x
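A direct PyTorch transcription of this log-likelihood (a sketch; scDistill's internal loss may differ in parameterization or numerics):

```python
import torch

def nb_nll(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor,
           eps: float = 1e-8) -> torch.Tensor:
    """Negative Binomial negative log-likelihood for P_NB(x | mu, theta) above."""
    log_theta_mu = torch.log(theta + mu + eps)
    log_prob = (
        torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0)
        + theta * (torch.log(theta + eps) - log_theta_mu)  # (θ/(θ+μ))^θ term
        + x * (torch.log(mu + eps) - log_theta_mu)         # (μ/(θ+μ))^x term
    )
    return -log_prob.sum(dim=-1).mean()                    # sum genes, mean cells
```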
Batch-Corrected Expression
X_corrected = μ, where (μ, θ) = D_φ(E_ψ(X), e_ref)
where e_ref is the embedding of a reference batch
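With the ConditionalNBDecoder sketched earlier, this amounts to decoding every cell against one shared batch index (illustrative fragment, not scDistill's API):

```python
# z: (n_cells, n_latent) latent matrix from the frozen encoder.
ref = torch.zeros(z.shape[0], dtype=torch.long)  # batch 0 as the reference
mu_corrected, _ = decoder(z, ref)                # NB mean = corrected expression
```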
Comparison with scVI
scVI and scDistill both use VAE-like architectures with Negative Binomial decoders, but differ fundamentally in how they achieve batch correction.
| Aspect | scVI | scDistill |
|---|---|---|
| Batch correction mechanism | Adversarial/latent space regularization | Knowledge distillation from Harmony |
| Training | End-to-end joint optimization | 2-phase (encoder → decoder) |
| Batch information in Z | Explicitly removed via loss term | Never encoded (teacher target is batch-free) |
| Theoretical guarantee | Relies on optimization balance | Z* is provably batch-invariant |
| Noise handling | Can overfit to high-dimensional noise | PCA bottleneck filters noise |
Why scDistill outperforms scVI for DE analysis:
- Cleaner batch removal — Harmony's iterative algorithm provides stronger batch invariance guarantees than adversarial training, which can suffer from mode collapse or incomplete removal.
- Biological signal preservation — scVI's joint optimization must balance batch removal against reconstruction. scDistill separates these objectives, allowing the decoder to focus purely on reconstruction.
- Noise filtering — The PCA bottleneck acts as a denoising step, preventing overfitting to technical artifacts that can corrupt DE estimates.
- Stability — 2-phase training avoids the optimization difficulties of joint encoder-decoder-discriminator training.
When scVI may be preferred
- When no suitable teacher method exists for the data type
- When end-to-end differentiability is required
- For very large datasets where Harmony becomes slow
Why NOT Cycle Consistency?
An alternative formulation would be:
Wrong: ||E(D(Z*, batch)) - Z*||²
This cycle-consistency loss is problematic because:
- Underdetermined — Z* is low-dimensional (50D), X is high-dimensional (15,000+ genes)
- No anchoring — The decoder can output any X' that encodes back to Z*
- Lost signal — Differential expression information is not constrained
The reconstruction loss anchors outputs to real expression profiles.
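A toy linear example of the "no anchoring" problem: two expression profiles that differ arbitrarily yet incur identical (zero) cycle loss. Shapes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
E = rng.normal(size=(50, 2000))                # toy linear encoder: 2000 genes -> 50D
z_star = rng.normal(size=50)

x1 = np.linalg.pinv(E) @ z_star                # one valid "reconstruction"
w = rng.normal(size=2000)                      # arbitrary direction in gene space
v = w - E.T @ np.linalg.solve(E @ E.T, E @ w)  # project w into the null space of E
x2 = x1 + 10.0 * v                             # wildly different expression profile

print(np.linalg.norm(E @ x1 - z_star))         # ~0: zero cycle loss
print(np.linalg.norm(E @ x2 - z_star))         # ~0: also zero cycle loss
```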
Installation
```bash
pip install scdistill
# or
uv add scdistill
```
Quick Start
```python
import scanpy as sc
from scdistill import Distiller
from scdistill.teachers import HarmonyTeacher

# Load data
adata = sc.read_h5ad("data.h5ad")

# Initialize with Harmony teacher
teacher = HarmonyTeacher(theta=2.0)
model = Distiller(
    adata,
    batch_key="batch",
    teacher=teacher,
    n_latent=50,
    use_batch_conditioning=True,
)

# Train (2-phase)
model.train(
    n_epochs_encoder=100,
    n_epochs_decoder=100,
    lr=1e-3,
)

# Get batch-corrected representations
Z_student = model.latent.student
Z_teacher = model.latent.teacher
X_corrected = model.corrected_expression()

# Differential expression analysis
from scdistill.de import PseudobulkConfig

de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    sample_key="sample",
    how=PseudobulkConfig(method="deseq2"),
)
```
Key Features
2-Phase Training
Clean separation ensures each component has a single objective:
- Encoder learns batch-invariant representations
- Decoder learns faithful reconstruction
Teacher Flexibility
```python
from scdistill.teachers import HarmonyTeacher

teacher = HarmonyTeacher(
    theta=2.0,
    max_iter=10,
)
```
Conditional Decoder
Batch-aware decoding for proper reconstruction
```python
model = Distiller(
    adata,
    batch_key="batch",
    use_batch_conditioning=True,
    n_batch_embedding=10,
)
```
PCA Method Options
```python
# Standard log2(CPM+1) → PCA (default)
model = Distiller(adata, batch_key="batch", pca_method="standard")

# Pearson residuals for sparse count data
model = Distiller(adata, batch_key="batch", pca_method="pearson_residuals")
```
Differential Expression
Multiple DE methods are supported:
```python
from scdistill.de import PseudobulkConfig, BayesianConfig

# Pseudobulk with DESeq2 (recommended)
de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    sample_key="sample",
    how=PseudobulkConfig(method="deseq2"),
)

# Bayesian with MC Dropout
de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    how=BayesianConfig(n_samples=100),
)
```
API Reference
Distiller
Main class for batch correction.
Constructor Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| adata | AnnData | required | Expression data with counts |
| batch_key | str | required | Column in obs with batch labels |
| teacher | BaseTeacher | HarmonyTeacher() | Teacher method for distillation |
| n_latent | int | 50 | Latent dimension |
| n_hidden | int | 128 | Hidden layer size |
| n_layers | int | 2 | Number of hidden layers |
| dropout | float | 0.1 | Dropout rate |
| pca_method | str | "standard" | "standard" or "pearson_residuals" |
| use_batch_conditioning | bool | True | Enable conditional decoder |
| n_batch_embedding | int | 10 | Batch embedding dimension |
Properties & Methods
| Property/Method | Description |
|---|---|
| train(n_epochs_encoder, n_epochs_decoder, ...) | Run 2-phase training |
| latent.student | Student encoder's batch-corrected latent Z |
| latent.teacher | Teacher's Z* (Harmony-corrected PCA) |
| gene_names | Gene names from training data |
| corrected_expression(target_batch) | Get batch-corrected expression |
| differential_expression(...) | Run DE analysis |
| save(path) / load(path) | Model persistence |
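For example, a persistence round trip might look like this (the table above gives only the save(path) and load(path) signatures, so the classmethod form of load is an assumption):

```python
model.save("scdistill_model")               # persist trained weights + config
model = Distiller.load("scdistill_model")   # assumed classmethod restore
```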
Benchmark Results
On simulated data with 6 scenarios (varying batch effects and DE genes):
| Metric | scDistill | scVI | Raw |
|---|---|---|---|
| DE F1 Score | 0.92 | 0.45 | 0.78 |
| LFC Correlation | 0.98 | 0.85 | 0.95 |
| Batch Mixing (iLISI) | 0.89 | 0.92 | 0.50 |
| Bio Conservation (NMI) | 1.00 | 0.98 | 1.00 |
License
MIT License
Citation
If you use scDistill in your research, please cite:
```bibtex
@software{scdistill,
  title  = {scDistill: Knowledge Distillation for Single-Cell Batch Correction},
  author = {Yuya Sato},
  year   = {2025},
}
```