GPU-accelerated Robust Cell Type Decomposition (RCTD) for spatial transcriptomics

rctd-py

JAX-accelerated Robust Cell Type Decomposition for spatial transcriptomics.

rctd-py is a Python reimplementation of the spacexr RCTD algorithm (Cable et al., 2022) with GPU acceleration via JAX. It deconvolves spatial transcriptomics data (Visium, Xenium, MERFISH, Slide-seq, etc.) into cell type proportions using a scRNA-seq reference.

Installation

uv pip install rctd-py

Or with standard pip:

pip install rctd-py

With CUDA support (GPU acceleration):

uv pip install "rctd-py[cuda]"

For development:

git clone https://github.com/p-gueguen/rctd-py.git
cd rctd-py
uv pip install -e ".[dev]"

Dependencies: jax, jaxlib, numpy, scipy, anndata

Quick Start

from rctd import Reference, run_rctd
import anndata

# 1. Load reference scRNA-seq data
ref_adata = anndata.read_h5ad("reference.h5ad")
reference = Reference(ref_adata, cell_type_col="cell_type")

# 2. Load spatial data
spatial = anndata.read_h5ad("spatial.h5ad")

# 3. Run RCTD
result = run_rctd(spatial, reference, mode="doublet")

The run_rctd function handles the full pipeline: gene intersection, platform effect normalization, sigma estimation, and per-pixel deconvolution.
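
The first of these steps, gene intersection, can be pictured with a minimal sketch. This is not the package's code: `intersect_genes` is a hypothetical helper, and keeping the reference gene order is an assumption made for illustration.

```python
def intersect_genes(reference_genes, spatial_genes):
    """Return genes present in both lists, in reference order (hypothetical helper)."""
    spatial_set = set(spatial_genes)
    return [gene for gene in reference_genes if gene in spatial_set]


reference_genes = ["CD3E", "MS4A1", "EPCAM", "PTPRC"]
spatial_genes = ["EPCAM", "KRT8", "CD3E", "PTPRC"]

shared = intersect_genes(reference_genes, spatial_genes)
print(shared)  # ['CD3E', 'EPCAM', 'PTPRC']
```

Only the shared genes can contribute to the fit, so a small overlap between the spatial panel and the reference is worth checking before running the pipeline.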

See the tutorial notebook or the rendered tutorial for a complete walkthrough with synthetic data.

Deconvolution Modes

RCTD supports three modes, selected via the mode parameter:

Mode | Description | Use case
full | Estimates weights for all K cell types per pixel using constrained IRWLS. | Continuous mixtures, Visium
doublet | Classifies each pixel as singlet or doublet, then estimates the top 1-2 cell type weights. Reports spot_class (singlet, doublet_certain, doublet_uncertain, reject). | Slide-seq, sparse spatial data
multi | Greedy forward selection of up to 4 cell types per pixel, adding types while the likelihood improves. | Dense spatial platforms (Xenium, MERFISH)
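
The greedy forward selection used by multi mode can be illustrated generically. The sketch below is an assumption-laden toy, not the rctd-py implementation: `greedy_forward_selection`, the `score` callable (lower is better, standing in for a negative log-likelihood) and the `min_improvement` stopping rule are all hypothetical.

```python
def greedy_forward_selection(candidates, score, max_types=4, min_improvement=0.0):
    """Greedily add the candidate that most lowers `score` until no gain remains.

    `score(selected)` is a hypothetical callable returning a value to minimise,
    for example a negative log-likelihood of one pixel given `selected` types.
    """
    selected = []
    best = score(selected)
    while len(selected) < max_types:
        remaining = [c for c in candidates if c not in selected]
        if not remaining:
            break
        trial_scores = {c: score(selected + [c]) for c in remaining}
        winner = min(trial_scores, key=trial_scores.get)
        if best - trial_scores[winner] <= min_improvement:
            break  # adding another type no longer improves the fit
        selected.append(winner)
        best = trial_scores[winner]
    return selected, best


# Toy pixel made of two types; the score is the unexplained share of signal
# plus a small penalty for every selected type that is not really present.
true_types = {"T cell": 0.6, "B cell": 0.4}


def toy_score(selected):
    unexplained = sum(w for t, w in true_types.items() if t not in selected)
    penalty = 0.05 * sum(1 for t in selected if t not in true_types)
    return unexplained + penalty


selected, final_score = greedy_forward_selection(
    ["T cell", "B cell", "Macrophage", "Fibroblast"], toy_score
)
print(selected)  # ['T cell', 'B cell']
```

The loop stops after two types here because adding a third only raises the toy score, which mirrors the "adding types while the likelihood improves" rule in the table.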

Benchmarks

End-to-end (Xenium, 58k pixels, doublet mode)

Full pipeline on a 58k-pixel Xenium dataset (380 genes, 45 cell types):

Backend | Sigma estimation | Deconvolution | Total
R spacexr (8 CPU cores) | ~49 min | ~2 min | ~51 min
rctd-py, JAX GPU (Blackwell B200) | ~3 min | ~36 s | ~3.5 min
rctd-py, JAX GPU (L40S) | ~3 min | ~55 min* | ~58 min

*The L40S is a rendering/inference GPU (GDDR6, 864 GB/s) rather than an HPC card. Its memory bandwidth is ~9× lower than the B200 (HBM3e, ~8 TB/s), making the memory-bound doublet IRWLS loop much slower. On HPC-class GPUs (H100, A100, B200) the deconvolution step completes in under 1 minute.
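
The bandwidth ratio quoted in the footnote follows directly from the two figures given there; a quick check, using only those numbers:

```python
b200_bandwidth_gbs = 8000.0  # ~8 TB/s (HBM3e), as quoted above
l40s_bandwidth_gbs = 864.0   # 864 GB/s (GDDR6), as quoted above

ratio = b200_bandwidth_gbs / l40s_bandwidth_gbs
print(f"B200 / L40S bandwidth ratio: {ratio:.1f}x")  # B200 / L40S bandwidth ratio: 9.3x
```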

Sigma estimation uses a Poisson-Lognormal model with cubic spline interpolation. After optimisation (cached matrix inverse, precomputed spline coefficients, vmapped JAX evaluation), sigma estimation time drops from ~66 min to ~3 min, a ~23× speedup, and results are numerically identical.

IRWLS solver only

Solver throughput measured on the spacexr vignette dataset (71 pixels, 313 genes, 19 cell types), scaled to larger pixel counts:

Backend | Pixels/sec | Speedup vs R
R spacexr (single-core) | ~62 | 1x
JAX CPU (16 threads) | ~374 | 6x
JAX GPU (L40S) | ~3,900 | 63x
JAX GPU (Blackwell B200) | ~4,450 | 72x

On the L40S, GPU throughput saturates at ~3,900 pixels/sec from about 7k pixels upward. At small pixel counts, JAX compilation overhead dominates.
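
The speedup column is each backend's throughput divided by the single-core R baseline. The short check below reproduces it from the rounded pixels/sec values in the table:

```python
r_single_core = 62  # pixels/sec for R spacexr (single-core), the baseline

throughput = {
    "JAX CPU (16 threads)": 374,
    "JAX GPU (L40S)": 3900,
    "JAX GPU (Blackwell B200)": 4450,
}

speedups = {name: round(rate / r_single_core) for name, rate in throughput.items()}
for name, factor in speedups.items():
    print(f"{name}: {factor}x")
# JAX CPU (16 threads): 6x
# JAX GPU (L40S): 63x
# JAX GPU (Blackwell B200): 72x
```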

Validation

Validated against R spacexr on a Xenium dataset (45 cell types, 380 genes, ~58k filtered pixels):

Metric | Value
Dominant type agreement | 99.7%
Median per-pixel weight correlation | 1.0000
Mean per-pixel weight correlation | 0.9998
Pixels with correlation > 0.8 | 99.98%

Both implementations use identical parameters: UMI_min=20, doublet mode, constrain=FALSE for full-mode weight estimation. See the rendered tutorial for a walkthrough on synthetic data.
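
Metrics of this kind can be computed from two weight matrices of shape N x K. The following is a sketch under stated assumptions, not the script used for the validation above: `compare_weights` is a hypothetical helper, the correlation is assumed to be Pearson across cell types within each pixel, and the data are synthetic.

```python
import numpy as np


def compare_weights(w_a, w_b):
    """Compare two (N x K) weight matrices pixel by pixel (hypothetical helper)."""
    w_a = np.asarray(w_a, dtype=float)
    w_b = np.asarray(w_b, dtype=float)
    # Fraction of pixels where both methods pick the same highest-weight type.
    dominant_agreement = float(np.mean(w_a.argmax(axis=1) == w_b.argmax(axis=1)))
    # Pearson correlation across cell types, computed separately for each pixel.
    a_c = w_a - w_a.mean(axis=1, keepdims=True)
    b_c = w_b - w_b.mean(axis=1, keepdims=True)
    denom = np.sqrt((a_c**2).sum(axis=1) * (b_c**2).sum(axis=1))
    corr = (a_c * b_c).sum(axis=1) / denom
    return {
        "dominant_agreement": dominant_agreement,
        "median_corr": float(np.median(corr)),
        "mean_corr": float(np.mean(corr)),
        "frac_corr_above_0.8": float(np.mean(corr > 0.8)),
    }


rng = np.random.default_rng(0)
w_ref = rng.dirichlet(np.ones(5), size=200)               # stand-in for one method
w_new = w_ref + rng.normal(scale=1e-3, size=w_ref.shape)  # small perturbation
metrics = compare_weights(w_ref, w_new)
print(metrics)
```

A pixel whose weights are all equal would give a zero denominator; real comparisons may need to mask such rows first.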

API Overview

run_rctd(spatial, reference, mode, config, batch_size)

End-to-end pipeline. Takes an AnnData spatial object and a Reference, returns a typed result (FullResult, DoubletResult, or MultiResult).

Reference(adata, cell_type_col, cell_min, n_max_cells, min_UMI)

Constructs cell type mean expression profiles from a scRNA-seq AnnData. Filters cell types below cell_min cells, caps per-type cells at n_max_cells, and removes cells below min_UMI.
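
The three filters can be sketched on plain arrays. This is an illustration only: `filter_reference_cells` is a hypothetical function, and the order of the steps, the `>=` comparison for the UMI cutoff, and random subsampling for the cap are assumptions rather than documented behaviour of `Reference`.

```python
import numpy as np


def filter_reference_cells(cell_types, umi_counts, cell_min, n_max_cells, min_umi, seed=0):
    """Return sorted indices of reference cells kept after three filters."""
    cell_types = np.asarray(cell_types)
    umi_counts = np.asarray(umi_counts)
    rng = np.random.default_rng(seed)
    # 1. Remove cells with too few UMIs.
    keep = np.flatnonzero(umi_counts >= min_umi)
    kept = []
    for cell_type in np.unique(cell_types[keep]):
        idx = keep[cell_types[keep] == cell_type]
        # 2. Drop cell types left with fewer than cell_min cells.
        if idx.size < cell_min:
            continue
        # 3. Randomly subsample types that exceed the per-type cap.
        if idx.size > n_max_cells:
            idx = rng.choice(idx, size=n_max_cells, replace=False)
        kept.append(idx)
    return np.sort(np.concatenate(kept)) if kept else np.array([], dtype=int)


cell_types = ["T"] * 6 + ["B"] * 3 + ["NK"] * 1
umi_counts = [500, 400, 300, 50, 800, 900, 200, 250, 300, 1000]

kept = filter_reference_cells(cell_types, umi_counts, cell_min=2, n_max_cells=4, min_umi=100)
kept_types = [cell_types[i] for i in kept]
print(kept_types)  # ['T', 'T', 'T', 'T', 'B', 'B', 'B']
```

In this toy input one T cell is removed for low UMI, the remaining five T cells are capped at four, and the single NK cell is dropped because its type falls below `cell_min`.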

RCTD(spatial, reference, config)

Stateful class for step-by-step control. Call fit_platform_effects() to normalize, then use run_full_mode, run_doublet_mode, or run_multi_mode directly.

RCTDConfig

Named tuple with all algorithm parameters. Key fields:

Parameter | Default | Description
UMI_min | 100 | Minimum UMI count per spatial pixel
UMI_min_sigma | 300 | Minimum UMI for sigma estimation
N_fit | 1000 | Number of pixels for sigma fitting
MAX_MULTI_TYPES | 4 | Maximum cell types in multi mode
CONFIDENCE_THRESHOLD | 5.0 | Singlet confidence threshold (doublet mode)
DOUBLET_THRESHOLD | 20.0 | Doublet certainty threshold
max_iter | 50 | IRWLS maximum iterations
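
Because the configuration is a named tuple, it is immutable and overrides produce a new object. The stand-in below shows the pattern with only the fields and defaults listed in the table; `ConfigSketch` is a hypothetical class for illustration and does not import or reproduce the real `RCTDConfig`.

```python
from typing import NamedTuple


class ConfigSketch(NamedTuple):
    """Illustrative stand-in for RCTDConfig, limited to the fields listed above."""

    UMI_min: int = 100
    UMI_min_sigma: int = 300
    N_fit: int = 1000
    MAX_MULTI_TYPES: int = 4
    CONFIDENCE_THRESHOLD: float = 5.0
    DOUBLET_THRESHOLD: float = 20.0
    max_iter: int = 50


default = ConfigSketch()
# Named tuples are immutable; _replace returns a modified copy.
low_umi = default._replace(UMI_min=20)
print(low_umi.UMI_min, default.UMI_min)  # 20 100
```

If `RCTDConfig` behaves like a standard named tuple, the same `_replace` call or keyword construction should apply before passing it as the `config` argument, though that has not been verified here.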

Result Types

  • FullResult: weights (N x K), cell_type_names, converged
  • DoubletResult: weights (N x K), weights_doublet (N x 2), spot_class, first_type, second_type, first_class, second_class, min_score, singlet_score, cell_type_names
  • MultiResult: weights (N x K), sub_weights, cell_type_indices, n_types, conf_list, min_score, cell_type_names
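
A common first step with any of these results is to turn the N x K weights into per-pixel proportions and a dominant type label. The sketch below assumes that `weights` is array-like with columns ordered like `cell_type_names`; `dominant_types` is a hypothetical helper and the arrays are made-up stand-ins for a real result.

```python
import numpy as np


def dominant_types(weights, cell_type_names):
    """Normalise rows to proportions and name the highest-weight type per pixel."""
    weights = np.asarray(weights, dtype=float)
    row_sums = weights.sum(axis=1, keepdims=True)
    # Guard against all-zero rows to avoid division by zero.
    proportions = weights / np.where(row_sums > 0, row_sums, 1.0)
    names = np.asarray(cell_type_names)[proportions.argmax(axis=1)]
    return proportions, names.tolist()


toy_weights = [[0.2, 0.6, 0.2], [1.5, 0.5, 0.0]]  # stand-in for result.weights
toy_names = ["B", "T", "NK"]                       # stand-in for result.cell_type_names

proportions, names = dominant_types(toy_weights, toy_names)
print(names)  # ['T', 'B']
```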

GPU Usage

JAX automatically detects available GPUs. To enable CUDA:

pip install "rctd-py[cuda]"

This installs jax[cuda12]. Verify GPU detection:

import jax
print(jax.devices())  # [CudaDevice(id=0)]

Use the batch_size parameter in run_rctd to control GPU memory usage. The default (10,000 pixels per batch) works well for GPUs with 24+ GB VRAM.
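
The effect of `batch_size` can be pictured as splitting the pixels into consecutive index ranges that are processed one at a time. This is a generic sketch, not the package's batching code; `iter_batches` is a hypothetical helper.

```python
def iter_batches(n_pixels, batch_size=10_000):
    """Yield (start, stop) index pairs covering n_pixels in order."""
    for start in range(0, n_pixels, batch_size):
        yield start, min(start + batch_size, n_pixels)


batches = list(iter_batches(58_000))
print(len(batches), batches[-1])  # 6 (50000, 58000)
```

Lowering `batch_size` increases the number of batches and reduces the amount of data resident on the GPU at once, which is the usual lever when a card has less memory.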

Project Structure

src/rctd/
  __init__.py        # Public API exports
  _types.py          # RCTDConfig, FullResult, DoubletResult, MultiResult
  _reference.py      # Reference class (profile computation, DE gene selection)
  _rctd.py           # RCTD class and run_rctd pipeline
  _normalize.py      # Platform effect estimation (fit_bulk)
  _sigma.py          # Sigma (overdispersion) estimation
  _likelihood.py     # Poisson-Lognormal model, Q-matrix interpolation
  _irwls.py          # Batched IRWLS solver (JAX jit + vmap)
  _simplex.py        # Simplex projection for constrained optimization
  _full.py           # Full mode deconvolution
  _doublet.py        # Doublet mode deconvolution
  _multi.py          # Multi mode deconvolution

Contributing

Contributions are welcome! Please open an issue to discuss proposed changes or report bugs.

Citation

If you use rctd-py, please cite the original spacexr RCTD paper:

@article{cable2022robust,
  title={Robust decomposition of cell type mixtures in spatial transcriptomics},
  author={Cable, Dylan M and Murray, Evan and Zou, Luli S and Goeva, Aleksandrina and Macosko, Evan Z and Chen, Fei and Irizarry, Rafael A},
  journal={Nature Biotechnology},
  volume={40},
  pages={517--526},
  year={2022},
  doi={10.1038/s41587-021-00830-w}
}

License

This project is licensed under the GNU General Public License v3.0.
