
scDistill: Knowledge distillation for single-cell batch correction with covariate protection


scDistill

A 2-phase knowledge distillation framework for single-cell RNA-seq batch correction.

Overview

scDistill learns batch correction by distilling knowledge from a teacher method (e.g., Harmony) into a neural network. The key innovations are:

  1. 2-Phase Training — Clean separation of encoder and decoder objectives
  2. Knowledge Distillation — Learn batch-invariant representations from established methods
  3. Conditional Decoder — Batch-aware decoding for proper expression reconstruction

Architecture

Phase 1 (Encoder)  X → log2(CPM+1) → Encoder → Z      Loss = MSE(Z, Z*)
Phase 2 (Decoder)  Z → Decoder(Z, batch) → (μ, θ)    Loss = NB_NLL(X, μ, θ)
  • X — Raw count matrix (N cells × G genes)
  • Z* — Teacher's batch-corrected latent representation (from Harmony)
  • Z — Student encoder's output (batch-corrected)
  • μ, θ — Negative Binomial parameters (mean, dispersion)
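A minimal PyTorch sketch of this architecture follows. The module names, layer sizes, and activations are illustrative assumptions, not scDistill's actual internals:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Log-normalized expression X → batch-corrected latent Z."""
    def __init__(self, n_genes, n_latent=50, n_hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_genes, n_hidden), nn.ReLU(),
            nn.Linear(n_hidden, n_latent),
        )

    def forward(self, x):
        return self.net(x)

class ConditionalDecoder(nn.Module):
    """(Z, batch) → Negative Binomial parameters (μ, θ)."""
    def __init__(self, n_genes, n_latent=50, n_batches=2, n_embed=10, n_hidden=128):
        super().__init__()
        self.batch_embedding = nn.Embedding(n_batches, n_embed)
        self.hidden = nn.Sequential(
            nn.Linear(n_latent + n_embed, n_hidden), nn.ReLU(),
        )
        self.mu_head = nn.Linear(n_hidden, n_genes)     # NB mean μ
        self.theta_head = nn.Linear(n_hidden, n_genes)  # NB dispersion θ

    def forward(self, z, batch):
        h = self.hidden(torch.cat([z, self.batch_embedding(batch)], dim=-1))
        # softplus keeps both NB parameters strictly positive
        return F.softplus(self.mu_head(h)), F.softplus(self.theta_head(h))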

Theoretical Foundation

The Batch Effect Problem

Single-cell data contains both biological signal and technical batch effects:

X_observed = f(biological_signal) + g(batch_effect) + noise

The goal is to obtain batch-corrected expression that preserves biological differences while removing batch artifacts.

Why Knowledge Distillation Works

1. Teacher Creates Batch-Invariant Target

Harmony operates in PCA space and uses soft k-means clustering to align batches:

Z* = Harmony(PCA(X), batch_labels)

The resulting Z* satisfies the batch invariance property:

P(Z* | batch = b₁) ≈ P(Z* | batch = b₂)  for all batches b₁, b₂
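For concreteness, here is one way such a teacher target can be computed with scanpy and harmonypy; HarmonyTeacher presumably wraps an equivalent call, and the preprocessing details below are assumptions (scanpy's log1p is a stand-in for log2(CPM+1)):

import scanpy as sc

adata = sc.read_h5ad("data.h5ad")
sc.pp.normalize_total(adata, target_sum=1e6)      # CPM
sc.pp.log1p(adata)                                # log transform
sc.pp.pca(adata, n_comps=50)                      # PCA(X)
sc.external.pp.harmony_integrate(adata, "batch")  # soft k-means batch alignment
Z_star = adata.obsm["X_pca_harmony"]              # teacher target Z* (N × 50)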

2. Encoder Learns the Mapping

The encoder is trained to reproduce Z*:

L_encoder = ||Encoder(X) - Z*||²

After convergence, the encoder implicitly learns to remove batch effects:

Encoder(X) ≈ Z*  →  Encoder removes batch information

Key insight — By minimizing distance to Z*, the encoder cannot encode batch-specific information since Z* doesn't contain it.
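A Phase 1 training sketch, reusing the illustrative Encoder from the architecture sketch above (the placeholder tensors stand in for real data):

import torch

X_norm = torch.randn(1000, 2000)  # placeholder for log-normalized counts (N × G)
Z_star = torch.randn(1000, 50)    # placeholder for the teacher target Z*
encoder = Encoder(n_genes=2000)

optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
for epoch in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(encoder(X_norm), Z_star)  # ||E(X) - Z*||²
    loss.backward()
    optimizer.step()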

3. Why Reconstruction Loss Preserves Biology

The decoder is trained with a Negative Binomial likelihood:

L_decoder = -log P(X | μ, θ)  where (μ, θ) = Decoder(Z, batch)

The Negative Binomial distribution is well suited to scRNA-seq count data because it:

  • Captures overdispersion (variance > mean)
  • Produces the high zero fractions typical of scRNA-seq without an explicit zero-inflation term
  • Models count data directly, without a log transform

Theorem (informal) — If Decoder maximizes P(X | Z, batch), then Decoder must preserve gene-level biological variation.

Proof sketch

  • Z contains only biological information (batch removed by encoder)
  • X contains both biological and batch information
  • To maximize likelihood of X given Z, Decoder must capture biological signal in Z
  • The batch embedding provides batch-specific scaling without encoding batch info into Z
  • Biological variation in X must come entirely from Z
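A Phase 2 sketch under the same illustrative setup (Encoder and ConditionalDecoder from the architecture sketch, placeholder tensors for data); the loss is the log-gamma form of the NB pmf given under Mathematical Formulation below:

import torch

def nb_nll(x, mu, theta, eps=1e-8):
    # -log P_NB(x | μ, θ), summed over genes and averaged over cells
    log_theta_mu = torch.log(theta + mu + eps)
    ll = (torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1)
          + theta * (torch.log(theta + eps) - log_theta_mu)
          + x * (torch.log(mu + eps) - log_theta_mu))
    return -ll.sum(-1).mean()

X_counts = torch.poisson(torch.rand(1000, 2000) * 5)  # placeholder raw counts
batch_idx = torch.randint(0, 2, (1000,))              # placeholder batch labels
decoder = ConditionalDecoder(n_genes=2000)

for p in encoder.parameters():
    p.requires_grad_(False)                 # Phase 1 weights stay frozen
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
for epoch in range(100):
    optimizer.zero_grad()
    mu, theta = decoder(encoder(X_norm), batch_idx)
    loss = nb_nll(X_counts, mu, theta)      # -log P(X | μ, θ)
    loss.backward()
    optimizer.step()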

Why PCA-based Distillation is Superior

A common misconception is that passing through PCA loses information. In reality, this is intentional denoising and signal purification, not information loss.

1. Z* is the "Core", Not the "Whole" — Manifold Reconstruction by Decoder

Even though Z* is low-dimensional (e.g., 50D), the Decoder's weights (millions of parameters) learn and retain the gene-gene co-expression structure needed to reconstruct 20,000 genes from 50 coordinates.

Complementation mechanism

  • PCA (Z*) captures the "principal coordinates of cell states"
  • Decoder applies biological rules (the manifold) based on those coordinates
  • "If a cell is in this state, genes A, B, and C should be expressed at these levels"

Z* is a seed, and the Decoder grows it into full expression profiles — information not explicitly in Z* is not lost; it is reconstructed through the learned biological manifold.

2. PCA is an Optimal Filter, Not a Bottleneck

The variance-importance mismatch problem

Standard PCA is dominated by highly expressed genes (high variance), while biologically important DE genes may be expressed at low levels. Pearson-residuals PCA addresses this:

# Pearson Residuals normalize the mean-variance relationship
# Low-expression but biologically meaningful variations are preserved
model = Distiller(adata, batch_key="batch", pca_method="pearson_residuals")

Capturing small perturbations

Even when the perturbation between conditions A and B is small, properly weighted PCA (or Harmony's iterative refinement) captures it in the top PC axes.

Conclusion — PCA acts as a high-performance noise-canceling filter that strips away technical noise (Poisson noise, etc.) while extracting only the signals necessary for robust DE detection.
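As a point of reference, scanpy's experimental API computes Pearson residuals directly; pca_method="pearson_residuals" plausibly performs an equivalent transformation before PCA (this is an illustration, not scDistill's code path):

import scanpy as sc

adata = sc.read_h5ad("data.h5ad")
sc.experimental.pp.normalize_pearson_residuals(adata)  # variance-stabilizing residuals
sc.pp.pca(adata, n_comps=50)                           # PCA on residuals, not log-CPM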

3. Information Bottleneck Prevents Overfitting to Noise

scVI and similar end-to-end methods can overfit to technical noise in high-dimensional space. scDistill's PCA bottleneck forces the model to focus on the biological manifold:

scVI:      X (20,000D) → Encoder → Z (50D) → Decoder → X̂
           ↑ Can memorize noise in X

scDistill: X → PCA → Z* (50D, denoised) → Encoder → Z → Decoder → X̂
           ↑ Noise already filtered out

This explains why scDistill achieves higher DE F1 scores — it avoids the "overfitting to noise" trap that reduces DE detection accuracy in end-to-end methods.

Conditional Decoder and Batch Embedding

The decoder takes both Z and a batch embedding as input:

X̂ = Decoder(Z, Embedding(batch))

This design is critical because:

  1. Batch information for reconstruction — The decoder needs to know which batch a cell came from to properly reconstruct X since X contains batch effects

  2. Clean separation — Biological signal flows through Z, batch effects through the embedding

  3. Inference flexibility — At inference time, we can decode to the original batch for faithful reconstruction, or decode to a reference batch for batch-corrected expression (see the sketch below)
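Continuing the illustrative modules and placeholder tensors from the earlier sketches, the two decoding modes look like this:

import torch

z = encoder(X_norm)                      # batch-corrected latent
mu_faithful, _ = decoder(z, batch_idx)   # each cell's own batch → faithful reconstruction
ref_batch = torch.zeros_like(batch_idx)  # assign every cell to reference batch 0
mu_corrected, _ = decoder(z, ref_batch)  # batch-corrected expression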

Mathematical Formulation

Let X ∈ ℤ₊^(N×G) be the count matrix and B ∈ {1,...,K}^N the batch labels. We write ψ for the encoder parameters and φ for the decoder parameters, reserving θ for the NB dispersion.

Phase 1 — Encoder Training

min_ψ  Σᵢ ||E_ψ(xᵢ) - z*ᵢ||²

where z*ᵢ = Harmony(PCA(X), B)ᵢ

Phase 2 — Decoder Training (encoder frozen)

min_φ  Σᵢ -log P_NB(xᵢ | μᵢ, θᵢ)

where (μᵢ, θᵢ) = D_φ(E_ψ(xᵢ), e_bᵢ)
      e_b = BatchEmbedding(b) ∈ ℝ^d

P_NB(x | μ, θ) = Γ(x+θ)/(Γ(θ)x!) · (θ/(θ+μ))^θ · (μ/(θ+μ))^x
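A quick numerical check (not part of scDistill) that this pmf is the Negative Binomial in scipy's parameterization, with n = θ and p = θ/(θ+μ):

import numpy as np
from scipy.special import gammaln
from scipy.stats import nbinom

x, mu, theta = 7, 4.0, 2.0
log_p = (gammaln(x + theta) - gammaln(theta) - gammaln(x + 1)
         + theta * np.log(theta / (theta + mu))
         + x * np.log(mu / (theta + mu)))
assert np.isclose(log_p, nbinom.logpmf(x, n=theta, p=theta / (theta + mu)))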

Batch-Corrected Expression

X_corrected = μ = D_φ(E_ψ(X), e_ref)

where e_ref is the embedding of a reference batch
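In the public API this corresponds to the get_corrected_expression method listed under API Reference below; the batch name here is a placeholder:

# reference_batch should be one of the categories in adata.obs["batch"]
X_corrected = model.get_corrected_expression(reference_batch="batch_0")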

Comparison with scVI

scVI and scDistill both use VAE-like architectures with Negative Binomial decoders, but differ fundamentally in how they achieve batch correction.

| Aspect | scVI | scDistill |
|---|---|---|
| Batch correction mechanism | Adversarial/latent-space regularization | Knowledge distillation from Harmony |
| Training | End-to-end joint optimization | 2-phase (encoder → decoder) |
| Batch information in Z | Explicitly removed via loss term | Never encoded (teacher target is batch-free) |
| Theoretical guarantee | Relies on optimization balance | Z* is provably batch-invariant |
| Noise handling | Can overfit to high-dimensional noise | PCA bottleneck filters noise |

Why scDistill outperforms scVI for DE analysis

  1. Cleaner batch removal — Harmony's iterative algorithm provides stronger batch invariance guarantees than adversarial training, which can suffer from mode collapse or incomplete removal.

  2. Biological signal preservation — scVI's joint optimization must balance batch removal against reconstruction. scDistill separates these objectives, allowing the decoder to focus purely on reconstruction.

  3. Noise filtering — The PCA bottleneck acts as a denoising step, preventing overfitting to technical artifacts that can corrupt DE estimates.

  4. Stability — 2-phase training avoids the optimization difficulties of joint encoder-decoder-discriminator training.

When scVI may be preferred

  • When no suitable teacher method exists for the data type
  • When end-to-end differentiability is required
  • For very large datasets where Harmony becomes slow

Why NOT Cycle Consistency?

An alternative formulation would be:

Wrong: ||E(D(Z*, batch)) - Z*||²

This cycle-consistency loss is problematic because:

  1. Underdetermined — Z* is low-dimensional (50D), X is high-dimensional (15,000+ genes)
  2. No anchoring — The decoder can output any X' that encodes back to Z*
  3. Lost signal — Differential expression information is not constrained

The reconstruction loss anchors outputs to real expression profiles.

Installation

pip install scdistill
# or
uv add scdistill

Quick Start

import scanpy as sc
from scdistill import Distiller
from scdistill.teachers import HarmonyTeacher

# Load data
adata = sc.read_h5ad("data.h5ad")

# Initialize with Harmony teacher
teacher = HarmonyTeacher(theta=2.0)
model = Distiller(
    adata,
    batch_key="batch",
    teacher=teacher,
    n_latent=50,
    use_batch_conditioning=True,
)

# Train (2-phase)
model.train(
    n_epochs_encoder=100,
    n_epochs_decoder=100,
    lr=1e-3,
)

# Get batch-corrected representations
Z = model.get_latent_representation()
X_corrected = model.get_corrected_expression()

# Differential expression analysis
from scdistill.de import PseudobulkConfig

de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    sample_key="sample",
    how=PseudobulkConfig(method="deseq2"),
)

Key Features

2-Phase Training

Clean separation ensures each component has a single objective:

  • Encoder learns batch-invariant representations
  • Decoder learns faithful reconstruction

Teacher Flexibility

from scdistill.teachers import HarmonyTeacher

teacher = HarmonyTeacher(
    theta=2.0,
    max_iter=10,
)

Conditional Decoder

Batch-aware decoding for proper reconstruction:

model = Distiller(
    adata,
    batch_key="batch",
    use_batch_conditioning=True,
    n_batch_embedding=10,
)

PCA Method Options

# Standard log2(CPM+1) → PCA (default)
model = Distiller(adata, batch_key="batch", pca_method="standard")

# Pearson residuals for sparse count data
model = Distiller(adata, batch_key="batch", pca_method="pearson_residuals")

Differential Expression

Multiple DE methods are supported:

from scdistill.de import PseudobulkConfig, BayesianConfig

# Pseudobulk with DESeq2 (recommended)
de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    sample_key="sample",
    how=PseudobulkConfig(method="deseq2"),
)

# Bayesian with MC Dropout
de_results = model.differential_expression(
    groupby="condition",
    group1="treatment",
    group2="control",
    how=BayesianConfig(n_samples=100),
)

API Reference

Distiller

Main class for batch correction.

Constructor Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| adata | AnnData | required | Expression data with counts |
| batch_key | str | required | Column in obs with batch labels |
| teacher | BaseTeacher | HarmonyTeacher() | Teacher method for distillation |
| n_latent | int | 50 | Latent dimension |
| n_hidden | int | 128 | Hidden layer size |
| n_layers | int | 2 | Number of hidden layers |
| dropout | float | 0.1 | Dropout rate |
| pca_method | str | "standard" | "standard" or "pearson_residuals" |
| use_batch_conditioning | bool | True | Enable conditional decoder |
| n_batch_embedding | int | 10 | Batch embedding dimension |

Methods

| Method | Description |
|---|---|
| train(n_epochs_encoder, n_epochs_decoder, ...) | Run 2-phase training |
| get_latent_representation() | Get batch-corrected latent Z |
| get_teacher_representation() | Get teacher's Z* |
| get_corrected_expression(reference_batch) | Get batch-corrected expression |
| differential_expression(...) | Run DE analysis |
| save(path) / load(path) | Model persistence |

Benchmark Results

On simulated data with 6 scenarios (varying batch effects and DE genes):

| Metric | scDistill | scVI | Raw |
|---|---|---|---|
| DE F1 Score | 0.92 | 0.45 | 0.78 |
| LFC Correlation | 0.98 | 0.85 | 0.95 |
| Batch Mixing (iLISI) | 0.89 | 0.92 | 0.50 |
| Bio Conservation (NMI) | 1.00 | 0.98 | 1.00 |

License

MIT License

Citation

If you use scDistill in your research, please cite:

@software{scdistill,
  title = {scDistill: Knowledge Distillation for Single-Cell Batch Correction},
  author = {Yuya Sato},
  year = {2025},
}
