
GPU-accelerated Robust Cell Type Decomposition (RCTD) for spatial transcriptomics


rctd-py


JAX-accelerated Robust Cell Type Decomposition for spatial transcriptomics.

A Python reimplementation of the spacexr RCTD algorithm (Cable et al., 2022) with GPU acceleration via JAX. Deconvolves spatial transcriptomics data (Visium, Xenium, MERFISH, Slide-seq, etc.) into cell type proportions using a scRNA-seq reference.

Installation

uv pip install rctd-py

Or with standard pip:

pip install rctd-py

With CUDA support (GPU acceleration):

uv pip install "rctd-py[cuda]"

For development:

git clone https://github.com/p-gueguen/rctd-py.git
cd rctd-py
uv pip install -e ".[dev]"

Dependencies: jax, jaxlib, numpy, scipy, anndata

Quick Start

from rctd import Reference, run_rctd
import anndata

# 1. Load reference scRNA-seq data
ref_adata = anndata.read_h5ad("reference.h5ad")
reference = Reference(ref_adata, cell_type_col="cell_type")

# 2. Load spatial data
spatial = anndata.read_h5ad("spatial.h5ad")

# 3. Run RCTD
result = run_rctd(spatial, reference, mode="doublet")

The run_rctd function handles the full pipeline: gene intersection, platform effect normalization, sigma estimation, and per-pixel deconvolution.
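As a rough illustration of the first pipeline step, gene intersection amounts to restricting both matrices to the genes they share and putting them in the same column order. The sketch below uses plain NumPy arrays and made-up gene lists; it is not the package's internal code.

```python
import numpy as np

# Hypothetical gene lists; real data would come from adata.var_names.
ref_genes = np.array(["CD3E", "MS4A1", "LYZ", "EPCAM"])
spatial_genes = np.array(["LYZ", "ACTB", "CD3E", "EPCAM"])

# Toy matrices: rows are cell types (reference) or pixels (spatial).
ref_profiles = np.arange(8, dtype=float).reshape(2, 4)
spatial_counts = np.arange(12).reshape(3, 4)

# return_indices gives the positions of the shared genes in each input,
# so both matrices end up with identical column order.
shared, ref_idx, sp_idx = np.intersect1d(
    ref_genes, spatial_genes, return_indices=True
)
ref_shared = ref_profiles[:, ref_idx]
spatial_shared = spatial_counts[:, sp_idx]

print(shared, ref_shared.shape, spatial_shared.shape)
```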

See the tutorial notebook or the rendered tutorial for a complete walkthrough with synthetic data.

Deconvolution Modes

RCTD supports three modes, selected via the mode parameter:

| Mode | Description | Use case |
| --- | --- | --- |
| full | Estimates weights for all K cell types per pixel using constrained IRWLS. | Continuous mixtures, Visium |
| doublet | Classifies each pixel as singlet or doublet, then estimates the top 1-2 cell type weights. Reports spot_class (singlet, doublet_certain, doublet_uncertain, reject). | Slide-seq, sparse spatial data |
| multi | Greedy forward selection of up to 4 cell types per pixel, adding types while the likelihood improves. | Dense spatial platforms (Xenium, MERFISH) |
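The greedy forward selection used by multi mode can be pictured with a generic sketch. Here a least-squares residual stands in for the likelihood, and the function name, stopping tolerance, and toy data are all assumptions made for illustration; the package's own scoring is not shown in this README.

```python
import numpy as np

def greedy_forward_selection(profiles, y, max_types=4, min_gain=1e-6):
    """Pick columns of `profiles` one at a time while the fit keeps improving.

    The residual sum of squares stands in for the likelihood (an assumption).
    """
    selected = []
    best_score = float(np.sum(y ** 2))  # score of the empty model
    while len(selected) < max_types:
        candidates = [k for k in range(profiles.shape[1]) if k not in selected]
        if not candidates:
            break
        scores = {}
        for k in candidates:
            cols = profiles[:, selected + [k]]
            coef, *_ = np.linalg.lstsq(cols, y, rcond=None)
            scores[k] = float(np.sum((y - cols @ coef) ** 2))
        k_best = min(scores, key=scores.get)
        if best_score - scores[k_best] <= min_gain:
            break  # no remaining candidate improves the fit enough
        selected.append(k_best)
        best_score = scores[k_best]
    return selected

# Toy pixel: a 70/30 mixture of types 0 and 2 out of 5.
rng = np.random.default_rng(0)
profiles = rng.random((50, 5))  # genes x cell types
y = 0.7 * profiles[:, 0] + 0.3 * profiles[:, 2]
chosen = greedy_forward_selection(profiles, y)
print(chosen)
```

Selection stops as soon as adding another type no longer improves the score, which is why the result can contain fewer than `max_types` entries.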

Benchmarks

End-to-end (Xenium, 58k pixels, doublet mode)

Full pipeline on a 58k-pixel Xenium dataset (380 genes, 45 cell types):

| Backend | Sigma estimation | Deconvolution | Total |
| --- | --- | --- | --- |
| R spacexr (8 CPU cores) | ~49 min | ~2 min | ~51 min |
| rctd-py (JAX GPU, Blackwell B200) | ~3 min | ~36 s | ~3.5 min |
| rctd-py (JAX GPU, L40S) | ~3 min | ~55 min* | ~58 min |

*The L40S is a rendering/inference GPU (GDDR6, 864 GB/s) rather than an HPC card. Its memory bandwidth is ~9× lower than the B200 (HBM3e, ~8 TB/s), making the memory-bound doublet IRWLS loop much slower. On HPC-class GPUs (H100, A100, B200) the deconvolution step completes in under 1 minute.

Sigma estimation uses a Poisson-Lognormal model with cubic spline interpolation. After optimisation (cached matrix inverse, precomputed spline coefficients, vmapped JAX evaluation), sigma estimation time drops from ~66 min to ~3 min, a ~23× speedup, with numerically identical results.

IRWLS solver only

Solver throughput measured on the spacexr vignette dataset (71 pixels, 313 genes, 19 cell types), scaled to larger pixel counts:

| Backend | Pixels/sec | Speedup vs R |
| --- | --- | --- |
| R spacexr (single-core) | ~62 | 1x |
| JAX CPU (16 threads) | ~374 | 6x |
| JAX GPU (L40S) | ~3,900 | 63x |
| JAX GPU (Blackwell B200) | ~4,450 | 72x |

GPU throughput saturates at ~3,900 pixels/sec on L40S at 7k+ pixels. JAX compilation overhead dominates at small pixel counts.
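The speedup column is simply each backend's throughput divided by the single-core R throughput. A small check, using only the figures quoted in the table above:

```python
# Throughput figures from the table above (pixels per second).
throughput = {
    "R spacexr (single-core)": 62,
    "JAX CPU (16 threads)": 374,
    "JAX GPU (L40S)": 3900,
    "JAX GPU (Blackwell B200)": 4450,
}

baseline = throughput["R spacexr (single-core)"]
speedup = {name: round(rate / baseline) for name, rate in throughput.items()}

for name, factor in speedup.items():
    print(f"{name}: {factor}x")
```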

Validation

Validated against R spacexr on a Xenium dataset (45 cell types, 380 genes, ~58k filtered pixels):

| Metric | Value |
| --- | --- |
| Dominant type agreement | 99.7% |
| Median per-pixel weight correlation | 1.0000 |
| Mean per-pixel weight correlation | 0.9998 |
| Pixels with correlation > 0.8 | 99.98% |

Both implementations use identical parameters: UMI_min=20, doublet mode, constrain=FALSE for full-mode weight estimation. See the rendered tutorial for a walkthrough on synthetic data.
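For readers who want to run a similar comparison on their own data, the metrics above could be computed along these lines. This is a minimal sketch: it assumes both weight matrices are NumPy arrays of shape (pixels, cell types) with matching row and column order, and that "correlation" means the per-pixel Pearson correlation, which the README does not state explicitly. The function name `compare_weights` is hypothetical.

```python
import numpy as np

def compare_weights(w_a, w_b):
    """Compare two (pixels x cell types) weight matrices."""
    # Fraction of pixels where both matrices have the same top cell type.
    agreement = float(np.mean(w_a.argmax(axis=1) == w_b.argmax(axis=1)))

    # Pearson correlation computed row by row (one value per pixel).
    a = w_a - w_a.mean(axis=1, keepdims=True)
    b = w_b - w_b.mean(axis=1, keepdims=True)
    denom = np.sqrt((a ** 2).sum(axis=1) * (b ** 2).sum(axis=1))
    corr = (a * b).sum(axis=1) / denom
    return {
        "dominant_agreement": agreement,
        "median_corr": float(np.median(corr)),
        "mean_corr": float(np.mean(corr)),
        "frac_corr_gt_0.8": float(np.mean(corr > 0.8)),
    }

# Toy check: a matrix compared with a slightly perturbed copy of itself.
rng = np.random.default_rng(0)
w_ref = rng.dirichlet(np.ones(5), size=100)
w_new = w_ref + rng.normal(scale=1e-4, size=w_ref.shape)
metrics = compare_weights(w_ref, w_new)
print(metrics)
```

Rows with constant weights would give a zero denominator, so real data may need those pixels masked out first.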

API Overview

run_rctd(spatial, reference, mode, config, batch_size)

End-to-end pipeline. Takes an AnnData spatial object and a Reference, returns a typed result (FullResult, DoubletResult, or MultiResult).

Reference(adata, cell_type_col, cell_min, n_max_cells, min_UMI)

Constructs cell type mean expression profiles from a scRNA-seq AnnData. Filters cell types below cell_min cells, caps per-type cells at n_max_cells, and removes cells below min_UMI.
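The three filters can be pictured with a toy NumPy version. The order of the steps, the random subsampling used for the cap, and the per-cell normalisation before averaging are assumptions made for this sketch, and `build_profiles` is a hypothetical helper rather than part of the package.

```python
import numpy as np

def build_profiles(counts, labels, *, cell_min, n_max_cells, min_UMI, seed=0):
    """Toy version of the three filters described above (step order is assumed)."""
    rng = np.random.default_rng(seed)
    counts = np.asarray(counts, dtype=float)
    labels = np.asarray(labels)

    # 1. Remove cells whose total UMI count is below min_UMI.
    keep = counts.sum(axis=1) >= min_UMI
    counts, labels = counts[keep], labels[keep]

    profiles = {}
    for cell_type in np.unique(labels):
        idx = np.flatnonzero(labels == cell_type)
        # 2. Skip cell types with fewer than cell_min cells.
        if idx.size < cell_min:
            continue
        # 3. Cap the number of cells per type by random subsampling.
        if idx.size > n_max_cells:
            idx = rng.choice(idx, size=n_max_cells, replace=False)
        # Mean of per-cell expression fractions (normalisation is assumed).
        sub = counts[idx]
        profiles[str(cell_type)] = (sub / sub.sum(axis=1, keepdims=True)).mean(axis=0)
    return profiles

# Toy data: 6 cells x 3 genes. The last cell has too few UMIs,
# and type "C" has only one cell.
counts = [[50, 30, 20], [40, 40, 20], [10, 80, 10],
          [20, 70, 10], [5, 5, 90], [1, 0, 1]]
labels = ["A", "A", "B", "B", "C", "A"]
profiles = build_profiles(counts, labels, cell_min=2, n_max_cells=2, min_UMI=10)
print(sorted(profiles))
```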

RCTD(spatial, reference, config)

Stateful class for step-by-step control. Call fit_platform_effects() to normalize, then use run_full_mode, run_doublet_mode, or run_multi_mode directly.

RCTDConfig

Named tuple with all algorithm parameters. Key fields:

| Parameter | Default | Description |
| --- | --- | --- |
| UMI_min | 100 | Minimum UMI count per spatial pixel |
| UMI_min_sigma | 300 | Minimum UMI for sigma estimation |
| N_fit | 1000 | Number of pixels for sigma fitting |
| MAX_MULTI_TYPES | 4 | Maximum cell types in multi mode |
| CONFIDENCE_THRESHOLD | 5.0 | Singlet confidence threshold (doublet mode) |
| DOUBLET_THRESHOLD | 20.0 | Doublet certainty threshold |
| max_iter | 50 | IRWLS maximum iterations |
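Because the configuration is a named tuple, it is immutable, and changing a field means building a modified copy. The sketch below uses a hypothetical stand-in class, `ExampleConfig`, that mirrors only the defaults listed in the table above; the real RCTDConfig may have more fields.

```python
from typing import NamedTuple

class ExampleConfig(NamedTuple):
    """Hypothetical stand-in mirroring the defaults in the table above."""
    UMI_min: int = 100
    UMI_min_sigma: int = 300
    N_fit: int = 1000
    MAX_MULTI_TYPES: int = 4
    CONFIDENCE_THRESHOLD: float = 5.0
    DOUBLET_THRESHOLD: float = 20.0
    max_iter: int = 50

default = ExampleConfig()

# Named tuples are immutable; _replace returns a modified copy.
low_umi = default._replace(UMI_min=20)
print(default.UMI_min, low_umi.UMI_min)
```

With the real class, the equivalent would presumably be to construct an RCTDConfig with the overridden field and pass it as the config argument of run_rctd; check the package for the exact constructor.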

Result Types

  • FullResult: weights (N x K), cell_type_names, converged
  • DoubletResult: weights (N x K), weights_doublet (N x 2), spot_class, first_type, second_type, first_class, second_class, min_score, singlet_score, cell_type_names
  • MultiResult: weights (N x K), sub_weights, cell_type_indices, n_types, conf_list, min_score, cell_type_names
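A typical first step after deconvolution is to turn the weight matrix into per-pixel labels. The sketch below works on stand-in arrays shaped like the fields listed above; the toy values, and the assumption that spot_class holds strings, are illustrative only.

```python
from collections import Counter

import numpy as np

# Hypothetical stand-ins for fields of a DoubletResult.
cell_type_names = ["T cell", "B cell", "Myeloid"]
weights = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.60, 0.30],
    [0.20, 0.20, 0.60],
])
spot_class = np.array(["singlet", "doublet_certain", "singlet"])

# Normalise each row to sum to 1 so the weights read as proportions.
proportions = weights / weights.sum(axis=1, keepdims=True)

# Dominant cell type per pixel.
dominant = [cell_type_names[k] for k in proportions.argmax(axis=1)]

# Tally of spot classes across pixels.
class_counts = Counter(spot_class.tolist())
print(dominant, dict(class_counts))
```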

GPU Usage

JAX automatically detects available GPUs. To enable CUDA:

pip install "rctd-py[cuda]"

This installs jax[cuda12]. Verify GPU detection:

import jax
print(jax.devices())  # e.g. [CudaDevice(id=0)] when a CUDA GPU is visible

Use the batch_size parameter in run_rctd to control GPU memory usage. The default (10,000 pixels per batch) works well for GPUs with 24+ GB VRAM.
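To get a feel for what batch_size controls, the sketch below splits a pixel range into consecutive chunks. How the package batches internally is not shown in this README, so `iter_batches` is a hypothetical helper that only illustrates the idea.

```python
def iter_batches(n_pixels, batch_size=10_000):
    """Yield (start, stop) index pairs covering range(n_pixels)."""
    for start in range(0, n_pixels, batch_size):
        yield start, min(start + batch_size, n_pixels)

# A 58k-pixel dataset at the default batch size.
batches = list(iter_batches(58_000, batch_size=10_000))
print(len(batches), batches[-1])
```

Halving batch_size doubles the number of batches while roughly halving the per-batch memory footprint, which is the usual trade-off on GPUs with less VRAM.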

Project Structure

src/rctd/
  __init__.py        # Public API exports
  _types.py          # RCTDConfig, FullResult, DoubletResult, MultiResult
  _reference.py      # Reference class (profile computation, DE gene selection)
  _rctd.py           # RCTD class and run_rctd pipeline
  _normalize.py      # Platform effect estimation (fit_bulk)
  _sigma.py          # Sigma (overdispersion) estimation
  _likelihood.py     # Poisson-Lognormal model, Q-matrix interpolation
  _irwls.py          # Batched IRWLS solver (JAX jit + vmap)
  _simplex.py        # Simplex projection for constrained optimization
  _full.py           # Full mode deconvolution
  _doublet.py        # Doublet mode deconvolution
  _multi.py          # Multi mode deconvolution
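For context on the simplex projection mentioned above, here is the standard sort-based Euclidean projection onto the probability simplex. Whether _simplex.py uses this exact algorithm is not stated in the README; the function below is a generic NumPy sketch, not the package's code.

```python
import numpy as np

def project_to_simplex(v):
    """Euclidean projection of a 1-D vector onto {w : w >= 0, sum(w) = 1}."""
    v = np.asarray(v, dtype=float)
    u = np.sort(v)[::-1]                        # sort descending
    css = np.cumsum(u) - 1.0                    # cumulative sums minus the target total
    ks = np.arange(1, v.size + 1)
    rho = np.nonzero(u - css / ks > 0)[0][-1]   # last index satisfying the condition
    theta = css[rho] / (rho + 1)                # shift applied to every entry
    return np.maximum(v - theta, 0.0)

w = project_to_simplex([0.8, 0.6, -0.2])
print(w, w.sum())
```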

Contributing

Contributions are welcome! Please open an issue to discuss proposed changes or report bugs.

Citation

If you use rctd-py, please cite the original spacexr RCTD paper:

@article{cable2022robust,
  title={Robust decomposition of cell type mixtures in spatial transcriptomics},
  author={Cable, Dylan M and Murray, Evan and Zou, Luli S and Goeva, Aleksandrina and Macosko, Evan Z and Chen, Fei and Irizarry, Rafael A},
  journal={Nature Biotechnology},
  volume={40},
  pages={517--526},
  year={2022},
  doi={10.1038/s41587-021-00830-w}
}

License

This project is licensed under the GNU General Public License v3.0.
