A single-cell foundation model focused on spatial cell-cell colocalization and subcellular molecular co-occurrence
This is the official SpatialFormer codebase. SpatialFormer is the first single-cell spatial foundation model that learns universal representations of subcellular molecular and cellular spatial proximity through multi-task learning.
Overview
Spatial transcriptomics quantifies gene expression within its spatial context, making significant advances in biomedical research possible. Understanding the spatial expression of genes and how multicellular systems are organised is vital for diagnosing diseases and studying biological processes. However, existing models often struggle to effectively integrate gene expression data with cellular spatial information. In this study, we introduce SpatialFormer: a hybrid framework that combines convolutional networks and transformers in order to learn single-cell multi-scale information within a niche context. This includes expression data and the subcellular spatial distribution of genes. Pre-trained on 700 million cell pairs from 17 million spatially resolved single cells across 71 Xenium slides, SpatialFormer merges gene spatial expression profiles with cell niche information via a pairwise training strategy. Our findings demonstrate that SpatialFormer can distil biological signals across various tasks, including single-cell batch correction, cell-type annotation, co-localisation detection and the identification of gene pairs critical to immune cell-cell interactions involved in the regulation of lung fibrosis. These advancements enhance our understanding of cellular dynamics and open up new avenues for applications in biomedical research.
Updates
[2025-12-27]
🚀 Data Scale-Up
- Transcripts: 3.3B → 4.5B
- Cells: 13M → 17M
- Slides: 61 → 71
- Gene vocabulary: 1,922 → 6,036
🧠 Model & Training
- Added a new edge-based dataloader that anchors on a preselected index, with:
  - distance-aware sampling
  - hard negative pairs
  - easy negative pairs (cache-pairs)
  - faiss-based nearest-neighbor search (cache-faiss)
  - index-based storage for positive/negative pairs, which saves a large amount of memory
- Upgraded to GraphSAGE v2, supporting 6,036 spatial embeddings
- Integrated FlashAttention v2 for efficient long-sequence processing
🧠 Prediction
- Aligned all prediction code with the sp.tl.embed_data function and updated sp.tl.embed_data to process variable-length inputs
🧠 Embedding extraction
- Embeddings can be extracted more efficiently with a larger batch size and a representative sequence length.
Tutorials
For instructions on using SpatialFormer, please refer to our tutorials (Jupyter notebooks, with some provided as .py files):
The zero-shot tutorials
- Dataset Integration
- Gene-gene colocalization perturbation discovery
- Gene-gene colocalization attention analysis
- Gene-gene colocalization perturbation analysis 1
- Gene-gene colocalization perturbation analysis 2
- Gene-gene colocalization perturbation analysis 3
- Cell-cell colocalization analysis
- Cell-cell colocalization prediction
The fine-tuning tutorials
System Requirements
Hardware requirements
We provide GPU and CPU versions for users with different hardware. However, if a large number of cells needs to be processed, a GPU is required to obtain results efficiently. Both AMD and NVIDIA GPUs are supported.
Software requirements
OS requirements
This package is supported on macOS and Linux. It has been tested on the following systems:
- macOS: Sequoia (15.3.1)
- Linux: Ubuntu 16.04; SLES 15.5
Python environment requirements
Create the spatialformer environment with Anaconda (Python >= 3.10 required):
conda create -n spatialformer python=3.10
Then activate the spatialformer environment:
source activate spatialformer
Installation
Step 1: Install PyTorch
PyTorch must be installed before spatialformer to ensure compatibility with your operating system and GPU.
Linux (AMD GPU — ROCm 6.0)
pip install torch==2.3.1+rocm6.0 torchvision==0.18.1+rocm6.0 torchaudio==2.3.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
Linux (NVIDIA GPU — CUDA 12.1)
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
macOS
pip install torch torchvision torchaudio
Note: On Mac, only CPUs are currently supported.
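After installing PyTorch, a quick check can confirm which device is visible. The sketch below is an illustration and not part of SpatialFormer; `pick_device` is a hypothetical helper name. It relies on the fact that ROCm builds of PyTorch expose AMD GPUs through the same `torch.cuda` interface used for NVIDIA GPUs, and it falls back to the CPU when PyTorch or a GPU is unavailable.

```python
def pick_device() -> str:
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed yet, so only the CPU is usable.
        return "cpu"
    # On ROCm builds, torch.cuda.is_available() also reports AMD GPUs.
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```

On macOS this is expected to print `cpu`, in line with the note above.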
Step 2: Install spatialformer
Make sure cmake is installed; if it is not, install it first:
conda install cmake
Then install spatialformer:
pip install spatialformer
Step 3 (Optional): Install FlashAttention
FlashAttention is optional. It accelerates training and inference while maintaining accuracy.
Before installing it, make sure the CUDA compiler (nvcc) is available on your system. nvcc can be installed via
conda install -c "nvidia/label/cuda-12.4.0" cuda-toolkit
#check the installation of nvcc
nvcc --version
Once the compiler is available, install FlashAttention.
To get started with the Triton backend on AMD GPUs, follow the steps below. The FlashAttention-2 ROCm CK backend currently supports (reference):
- MI200x, MI250x, MI300x, and MI355x GPUs.
- Data types fp16 and bf16
- Head dimensions up to 256 for both the forward and backward passes.
pip install triton==3.2.0
Then install FlashAttention (2.x) from GitHub:
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
git checkout 35e5f00
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install
pip install einops
Finally, test whether the installation works:
pytest tests/test_flash_attn.py
Or run a quick check:
python -c "
import torch
from flash_attn import flash_attn_func
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
out = flash_attn_func(q, k, v)
print(f'✅ Flash Attention on {torch.cuda.get_device_name(0)}: {out.shape}')
"
Alternatively, if you are using an NVIDIA GPU (e.g., A100), run the following command to install FlashAttention (2.x):
pip install flash-attn --no-build-isolation
If that fails, try the pre-built wheel:
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install ./flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
Our code uses FlashAttention (2.x), which is completely rewritten and 2x faster than FlashAttention (1.x).
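Because this step is optional, the `use_flash_attn` flag passed to `sp.tl.embed_data` should reflect whether the package is actually installed. A minimal sketch, assuming the package is importable under the name `flash_attn` (as in the check above); `flash_attn_installed` is a hypothetical helper, not a SpatialFormer function:

```python
import importlib.util


def flash_attn_installed() -> bool:
    """True when the flash_attn package can be found in this environment."""
    return importlib.util.find_spec("flash_attn") is not None


# Pass this value as use_flash_attn when calling sp.tl.embed_data.
use_flash_attn = flash_attn_installed()
print(f"use_flash_attn = {use_flash_attn}")
```

This only confirms that the package is present; it does not verify that the installed build matches your GPU.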
Pretraining data
The model can handle input from individual cells and doublets. It was originally pretrained on a large-scale dataset of pairwise doublets containing both positive and negative pairs. Specifically, the positive pairs consist of all cells located within the niche of a given query cell. In contrast, the negative pairs can include any cells that are far away from the query cell.
The processed individual cell dataset can be retrieved from the Hugging Face dataset repository at SpatialCC-17M. The pairwise data can be generated by following the instructions provided in /data_preprocess/.
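To make the positive/negative definition concrete, here is a small dependency-free sketch. It is an illustration only and not the procedure in /data_preprocess/: the function name `build_pairs`, the niche radius, and the minimum distance for negatives are all assumptions chosen for the toy example. Pairs are stored as index tuples so the underlying cell data is never duplicated.

```python
import math
import random


def build_pairs(coords, niche_radius, negative_min_distance,
                negatives_per_cell=1, seed=0):
    """Return (positive_pairs, negative_pairs) as lists of (query, other) indices."""
    rng = random.Random(seed)
    positives, negatives = [], []
    for i, query in enumerate(coords):
        distant = []
        for j, other in enumerate(coords):
            if i == j:
                continue
            d = math.dist(query, other)
            if d <= niche_radius:
                positives.append((i, j))   # cell j lies in the niche of cell i
            elif d >= negative_min_distance:
                distant.append(j)          # far enough to serve as a negative
        k = min(negatives_per_cell, len(distant))
        negatives.extend((i, j) for j in rng.sample(distant, k))
    return positives, negatives


# Toy coordinates (arbitrary units): three cells close together, one far away.
coords = [(0.0, 0.0), (5.0, 0.0), (0.0, 5.0), (200.0, 200.0)]
positives, negatives = build_pairs(
    coords, niche_radius=30.0, negative_min_distance=100.0
)
print(positives)  # six ordered pairs among cells 0, 1 and 2
print(negatives)  # one distant partner per query cell
```

Cells that fall between the two thresholds are used for neither set, which keeps ambiguous pairs out of the toy training data.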
You can download the dataset in Python as follows:
from datasets import load_dataset
spatialcc = load_dataset("TerminatorJ/xenium_5k_pandavid_dataset_v2", cache_dir = "your_cache_dir")
Get the Embeddings
SpatialFormer provides a simple function to extract embeddings. The sp.tl.embed_data() function works directly on an AnnData object, and the generated embeddings are stored in obsm under the key "X_SpaF".
SpatialFormer supports two methods for generating embeddings: 1) single input mode and 2) pairwise input mode. Below is an example of generating the AnnData embeddings:
The checkpoints can be downloaded according to different use cases as below:
| Input type | Tissue types | Size (number of slides) | Links |
|---|---|---|---|
| Paired | lung | 1 | ckp_pair_lung_1 |
| Paired | 13 types | 61 | ckp_pair_13tissues_61 |
| Paired | 13 types | 71 | ckp_pair_13tissues_71 |
| Paired | lung | 25 | ckp_pair_lung_25 |
| Single | 13 types | 62 | ckp_single_13tissues_62 |
| Single | 13 types | 71 | ckp_single_13tissues_71 |
The LoRA fine-tuned checkpoints can be downloaded as below:
| Input type | Tissue types | Size (number of slides) | Cell Number | Links |
|---|---|---|---|---|
| Paired | lung | 1 | 10k | ckp_pair_lung_LoRA_10K |
| Paired | breast | 1 | 10k | ckp_pair_breast_LoRA_10K |
| Paired | colon | 1 | 10k | ckp_pair_colon_LoRA_10K |
| Paired | lung | 1 | 100k | ckp_pair_lung_LoRA_100K |
| Paired | breast | 1 | 100k | ckp_pair_breast_LoRA_100K |
| Paired | colon | 1 | 100k | ckp_pair_colon_LoRA_100K |
SpatialFormer mainly focuses on zero-shot learning for single-cell spatial omics data. Extracting embeddings is therefore the most common step in downstream tasks. We support several input formats for extracting cell embeddings: the input can be an ".h5ad" file or a Hugging Face dataset.
The simplest route is to load an ".h5ad" file and extract the embeddings with the code below:
We also provide a Google Colab notebook for hands-on practice.
Loading the anndata
A simple example AnnData file can be downloaded here
import scanpy as sc
adata = sc.read_h5ad("./downstream/cell_cell_communication/data/covid_subsampled.h5ad")
Make sure the "gene_name" column is present in adata.var.
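The column requirement can be checked before running the model. The helper below is a hypothetical convenience function, not part of SpatialFormer; it only assumes that the object exposes `.var.columns`, as an AnnData object does. A stand-in object is used so the example runs without loading real data.

```python
from types import SimpleNamespace


def require_gene_name(adata, column="gene_name"):
    """Raise a KeyError if adata.var lacks the column holding gene names."""
    if column not in adata.var.columns:
        raise KeyError(
            f"adata.var has no '{column}' column; "
            f"available: {list(adata.var.columns)}"
        )


# Stand-in with the same .var.columns attribute as an AnnData object.
toy = SimpleNamespace(var=SimpleNamespace(columns=["gene_name", "feature_types"]))
require_gene_name(toy)  # passes silently
```

With real data, call `require_gene_name(adata)`. If your gene symbols are stored in the index instead, `adata.var["gene_name"] = adata.var_names` is one way to copy them into the expected column.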
Single Input Mode
import spatialformer as sp
method = "cls"
tissue = "Lung"
condition = "Disease"
assay = "Xenium"
model_ckp_path = "./ckp_single_13tissues_71.ckpt" # "ckp_single_13tissues_71" is recommended
use_flash_attn = True # Set to True if FlashAttention is installed, otherwise False.
batch_size = 16
embed_adata = sp.tl.embed_data(
    adata = adata,
    tissue = tissue,
    condition = condition,
    assay = assay,
    method = method,
    model_ckp_path = model_ckp_path,
    batch_size = batch_size,
    mode = "single",
    use_flash_attn = use_flash_attn,
    num_workers = 32
)
Pairwise Input Mode
import spatialformer as sp
method = "cls"
tissue = "Lung"
condition = "Disease"
assay = "Xenium"
model_ckp_path = "./ckp_pair_13tissues_71.ckpt" #"ckp_pair_13tissues_71" is recommended
batch_size = 16
embed_adata = sp.tl.embed_data(
    adata = adata,
    tissue = tissue,
    condition = condition,
    assay = assay,
    method = method,
    model_ckp_path = model_ckp_path,
    batch_size = batch_size,
    mode = "pair",
    left_cell = ["20532-0-1-0-1", "222101-0-0-1"],
    right_cell = ["483188-0-0-1", "513429-0-0-1"],
    num_workers = 16
)
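In pairwise mode, the two ID lists are matched by position, so a mismatch in length or an unknown ID is worth catching early. The sketch below is a hypothetical pre-flight check, not a SpatialFormer function. It assumes the cell IDs live in `adata.obs_names`, which is not stated in this README.

```python
def check_pairs(left_cell, right_cell, known_ids):
    """Validate that two cell ID lists describe index-aligned pairs."""
    if len(left_cell) != len(right_cell):
        raise ValueError(
            f"left_cell has {len(left_cell)} IDs but right_cell has {len(right_cell)}"
        )
    known = set(known_ids)
    missing = [cid for cid in list(left_cell) + list(right_cell) if cid not in known]
    if missing:
        raise ValueError(f"cell IDs not found in the data: {missing}")
    # Pair k is (left_cell[k], right_cell[k]).
    return list(zip(left_cell, right_cell))


left_cell = ["20532-0-1-0-1", "222101-0-0-1"]
right_cell = ["483188-0-0-1", "513429-0-0-1"]
# With real data, known_ids would be adata.obs_names (an assumption).
known_ids = left_cell + right_cell
pairs = check_pairs(left_cell, right_cell, known_ids)
print(pairs)
```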
| Arguments | dtype | Description |
|---|---|---|
| adata | object | An AnnData object that stores expression information by CellXGene. |
| tissue | string | The type of tissue (e.g., Breast/Lung). |
| condition | string | Metadata for the sample condition (e.g., Disease/Healthy). |
| assay | string | The assay used to generate the data (e.g., MERFISH, Xenium). |
| method | string | Embedding extraction method. "cls": Use CLS token embedding as cell representation; "gene": Use the mean of gene token embeddings. |
| mode | string | The method of the embed function, which can be either "single" or "pair." The single mode collates only individual cells as input for the model. In "pair" mode, data is prepared for pairwise input. If using "pair," both left_cell and right_cell must be provided. Each cell ID in left_cell corresponds to the cell ID at the same index in right_cell. |
| model_ckp_path | string | The path to the SpatialFormer model checkpoint. |
| batch_size | integer | The batch size for the data loader. |
| threshold | float | The threshold for filtering whether two genes are paired, which helps in identifying confidently paired genes at subcellular resolution. This option is applicable only in "single" input mode and is not functional in "pair" mode. |
| left_cell | array_like | A list of cell IDs representing the query cells. |
| right_cell | array_like | A list of cell IDs representing the key cells. |
| num_workers | integer | The number of CPU cores to load the data. This value should match the number of workers specified in the data loader. |
| resume_before_5k | bool | Indicates whether to resume from a checkpoint trained on the small panel. Set to True to use the small-panel checkpoint; set to False to use the checkpoint trained with the 5k Xenium panel. |
| max_len | integer | Maximum length of each sequence considered. Default is None, meaning all genes are used. For large numbers of pairwise sequences, we strongly recommend setting this to 500 per sequence to significantly improve runtime performance if FlashAttention is not installed. |
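Once the embeddings are stored in obsm under "X_SpaF", they can feed a standard scanpy workflow. The sketch below is one possible downstream use, not a step prescribed by SpatialFormer. It assumes that the object returned by `sp.tl.embed_data` is an AnnData object, that scanpy is installed, and that the Leiden dependency is available. `cluster_on_spaf` is a hypothetical helper name.

```python
def cluster_on_spaf(embed_adata, key="X_SpaF", resolution=1.0):
    """Build a neighbor graph on the stored embeddings, then run UMAP and Leiden."""
    if key not in embed_adata.obsm:
        raise KeyError(f"'{key}' not found in obsm; run sp.tl.embed_data first")

    import scanpy as sc  # imported lazily so the check above works without scanpy

    sc.pp.neighbors(embed_adata, use_rep=key)   # graph built from the embeddings
    sc.tl.umap(embed_adata)                     # 2D layout stored in obsm["X_umap"]
    sc.tl.leiden(embed_adata, resolution=resolution)  # clusters in obs["leiden"]
    return embed_adata
```

A typical call would be `embed_adata = cluster_on_spaf(embed_adata)`, followed by `sc.pl.umap(embed_adata, color="leiden")` to inspect the result.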
If the input data is a Hugging Face dataset, use the dedicated Hugging Face dataloader, which is intended for the inference step only:
import json
import os

import torch
from datasets import load_from_disk, concatenate_datasets, load_dataset
from tqdm import tqdm

# manual_train_fm and create_single_data_loaders are defined in the SpatialFormer
# codebase; import them from your checkout before running this snippet.

def load_model(model_ckp_path, device):
    # Resolve a file inside the SpatialFormer project directory.
    get_file_path = lambda path, filename: os.path.join("/scratch/project_465001820/Spatialformer", path, filename)
    config_path = get_file_path("config", "_config_train_large_pair.json")
    with open(config_path, 'r') as json_file:
        config = json.load(json_file)
    model = manual_train_fm(config = config)
    # Load the checkpoint weights onto the requested device.
    ckp = torch.load(model_ckp_path, map_location=torch.device(device))
    params = ckp["state_dict"]
    model.load_state_dict(params)
    model.eval()
    model.to(device)
    return model

model_ckp_path = "/scratch/project_465001027/Spatialformer/output/checkpoints/step=0104000-train_total_loss=-2.3064-val_total_loss=0.0000.ckpt"
model = load_model(model_ckp_path, "cuda")
batch_size = 16 # adjust to your memory budget
dataloader = create_single_data_loaders(lung_dataset, #define your own dataset here
                                        cls_token = 1,
                                        padding_idx = 0,
                                        sep_token = 1949,
                                        batch_size=batch_size,
                                        context_length=500,
                                        special_token_num = 4,
                                        split_num = 1,
                                        num_workers = 64,
                                        mode="eval")
all_embeds = []
counter = 0 # number of cells processed so far
with torch.no_grad():
    for i, batch in tqdm(enumerate(dataloader)):
        counter += batch_size
        tissues = batch["Tissues"]
        conditions = batch["Conditions"]
        anns = batch["Annotations"]
        attn_mask = batch["attention_mask"]
        embeddings, _ = model.get_embeddings(batch, [-1], True, False) #normal prob
        # Keep the token at position 0 (CLS) as the cell embedding.
        embeddings = embeddings[0][:,0,:].detach().cpu().numpy()
        all_embeds.append(embeddings)
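The loop above keeps the token at position 0 as the cell embedding, which corresponds to the "cls" method in the arguments table. The "gene" method instead averages gene token embeddings. The dependency-free sketch below contrasts the two on a toy cell. The token layout (a single leading special token) and the mask convention (1 for real tokens, 0 for padding) are assumptions for illustration and may differ from the layout SpatialFormer uses internally.

```python
def cls_pool(token_embeddings):
    """Use the first token (CLS) as the cell representation."""
    return list(token_embeddings[0])


def gene_mean_pool(token_embeddings, attention_mask, num_leading_special=1):
    """Average unmasked gene token embeddings, skipping leading special tokens."""
    kept = [
        vec
        for vec, keep in zip(
            token_embeddings[num_leading_special:],
            attention_mask[num_leading_special:],
        )
        if keep
    ]
    if not kept:
        raise ValueError("no gene tokens left after masking")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# One toy cell: CLS token, two gene tokens, one padding token; embedding size 2.
tokens = [[1.0, 1.0], [2.0, 4.0], [4.0, 8.0], [0.0, 0.0]]
mask = [1, 1, 1, 0]
print(cls_pool(tokens))              # [1.0, 1.0]
print(gene_mean_pool(tokens, mask))  # [3.0, 6.0]
```

In practice the same operations would be applied to the batched tensors with array slicing; plain lists are used here only so the example runs anywhere.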
Training the model
The model can be pretrained further with the commands below, which use script/train.py:
For the configuration parameters, refer to the table
Pretrain the single-input model
python ./script/train.py --config /scratch/project_465001820/Spatialformer/config/_config_train_large_single.json
Pretrain the pairwise (doublet) input model
python ./script/train.py --config /scratch/project_465001820/Spatialformer/config/_config_train_large_pair.json
Fine-tune the model
For each slide, accurate prediction of molecular features relies largely on cell-cell colocalization. We use LoRA to fine-tune the SpatialFormer model on a single slide.
We also provide a Google Colab notebook for hands-on practice.
python cell_cell_communication_zero_shot_multi_platform.py \
    --radius 30 \
    --fine_tune_mode lora \
    --rank 64 \
    --lora_alpha 128 \
    --cell_by_gene_path /scratch/project_465001820/Spatialformer_main_practice/data/MERFISH_Lung/HumanLungCancerPatient1_cell_by_gene.csv \
    --cell_meta_path /scratch/project_465001820/Spatialformer_main_practice/data/MERFISH_Lung/HumanLungCancerPatient1_cell_metadata.csv \
    --sample_name MERFISH_Lung \
    --zero_shot_cell_size 500 \
    --tissue Lung \
    --condition Disease \
    --config_path /scratch/project_465001820/Spatialformer/spatialformer/config/_config_fine_tune_probe.json \
    --batch_size 32 \
    --max_cells 10000
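To give a feel for what `--rank 64 --lora_alpha 128` means, the sketch below works through the arithmetic of standard LoRA, where a frozen weight matrix W is adapted as W + (lora_alpha / rank) * B @ A. Whether the fine-tuning script applies exactly this scaling, and which layers it adapts, is an assumption. The layer size used here is hypothetical and not taken from SpatialFormer.

```python
def lora_scaling(rank: int, lora_alpha: float) -> float:
    """Scaling factor applied to the low-rank update B @ A in standard LoRA."""
    return lora_alpha / rank


def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in A (rank x d_in) plus B (d_out x rank) for one adapted matrix."""
    return rank * (d_in + d_out)


rank, lora_alpha = 64, 128   # values from the command above
d_in = d_out = 512           # hypothetical layer size, for illustration only
print(lora_scaling(rank, lora_alpha))            # 2.0
print(lora_trainable_params(d_in, d_out, rank))  # 65536
print(d_in * d_out)                              # 262144 weights in the frozen matrix
```

For this hypothetical layer, the adapter trains a quarter as many parameters as the full matrix holds, and the saving grows as the layer gets larger or the rank gets smaller.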
Reproducibility of the work
All code for reproducing the results of the manuscript is in the ./downstream directory. To reproduce the MERFISH and Xenium colocalization prediction, see colocalization prediction.
Cite our work
Wang J, Huang Y, Winther O. SpatialFormer: Universal Spatial Representation Learning from Subcellular Molecular to Multicellular Landscapes[J]. bioRxiv, 2025: 2025.01.18.633701.