PyTorch version of MUSE
MUSE-PyTorch
This repository is a PyTorch Lightning reimplementation of the official MUSE (multi-modality structured embedding for spatially resolved transcriptomics analysis), introduced in:
Bao, F., Deng, Y., Wan, S. et al. Integrative spatial analysis of cell morphologies and transcriptional states with MUSE. Nat Biotechnol (2022). https://doi.org/10.1038/s41587-022-01251-z
Requirements
- numpy==1.22.3
- online_triplet_loss==0.0.6
- pandas==1.4.2
- PhenoGraph==1.5.7
- pytorch_lightning==1.6.3
- scipy==1.8.0
- torch==1.11.0
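Because the requirements above are pinned to exact versions, it can be useful to check an existing environment against them before installing. The sketch below uses only the standard library's `importlib.metadata`; the helper name `version_mismatches` is hypothetical and not part of muse_pytorch. It compares version strings exactly, so local build tags (for example a CUDA suffix on torch) are reported as mismatches.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

# Pinned versions copied from the Requirements list.
PINNED: Dict[str, str] = {
    "numpy": "1.22.3",
    "online_triplet_loss": "0.0.6",
    "pandas": "1.4.2",
    "PhenoGraph": "1.5.7",
    "pytorch_lightning": "1.6.3",
    "scipy": "1.8.0",
    "torch": "1.11.0",
}


def version_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for every pin that is not met.

    A value of None means the distribution is not installed at all.
    `lookup` defaults to importlib.metadata.version and can be swapped out.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = lookup(name)
        except PackageNotFoundError:
            installed = None
        # Exact string comparison: "1.11.0+cu113" != "1.11.0".
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in version_mismatches(PINNED).items():
        print(f"{name}: pinned {PINNED[name]}, found {installed or 'not installed'}")
```

An empty result means every pinned package is installed at exactly the listed version.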
Installation
To install the MUSE-PyTorch package, run:
pip install muse_pytorch
Usage
import muse_pytorch as muse
The library exposes the same fit_predict method as the original implementation:
z, x_hat, y_hat, latent_x, latent_y = muse.fit_predict(
    trans_features,
    morph_features,
    trans_labels,
    morph_labels,
    init_epochs=3,
    refine_epochs=3,
    cluster_epochs=6,
    cluster_update_epoch=2,
    joint_latent_dim=50,
    batch_size=512)
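The call above trains in three phases whose lengths are set by init_epochs, refine_epochs and cluster_epochs, with cluster_update_epoch controlling how often the single-modality labels are refreshed. This README does not spell out the exact schedule, so the sketch below is an illustration under two stated assumptions: the phases run back to back (initialization, then refinement, then clustering), and labels are refreshed whenever the clustering-phase epoch counter is a multiple of cluster_update_epoch. The helper `training_schedule` is hypothetical and is not the library's code.

```python
from typing import Dict, List


def training_schedule(init_epochs: int, refine_epochs: int, cluster_epochs: int,
                      cluster_update_epoch: int) -> Dict[str, List[int]]:
    """Split global 0-based epoch indices into the three phases and list the
    epochs at which single-modality cluster labels would be refreshed.

    ASSUMPTION: phases are sequential and a refresh happens when the
    clustering-phase counter is a multiple of cluster_update_epoch.
    """
    if cluster_update_epoch <= 0:
        raise ValueError("cluster_update_epoch must be positive")
    init = list(range(init_epochs))
    refine = list(range(init_epochs, init_epochs + refine_epochs))
    start = init_epochs + refine_epochs
    cluster = list(range(start, start + cluster_epochs))
    updates = [start + e for e in range(cluster_epochs) if e % cluster_update_epoch == 0]
    return {"init": init, "refine": refine, "cluster": cluster, "label_updates": updates}


# Same epoch settings as the fit_predict call above: 12 epochs in total.
schedule = training_schedule(init_epochs=3, refine_epochs=3,
                             cluster_epochs=6, cluster_update_epoch=2)
print(schedule["label_updates"])  # [6, 8, 10]
```

Under these assumptions, a larger cluster_update_epoch means fewer label refreshes per clustering phase.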
The method accepts the same parameters as the original implementation, plus a few additional ones.
Parameters:
trans_features: input for the transcriptional modality; matrix of n * p, where n = number of cells and p = number of genes.
morph_features: input for the morphological modality; matrix of n * q, where n = number of cells and q = morphological feature dimension.
trans_labels: initial reference cluster labels for the transcriptional modality.
morph_labels: initial reference cluster labels for the morphological modality.
latent_dim: size of the latent dimension for each single-modality representation.
joint_latent_dim: size of the latent dimension of the joint representation.
lambda_reg: weight of the regularisation term in the loss function.
lambda_sup: weight of the self-supervised term in the loss function.
lr: learning rate for the optimizer.
init_epochs: number of epochs for the initialization phase.
refine_epochs: number of epochs for the refinement phase.
cluster_epochs: number of epochs for the clustering phase.
cluster_update_epoch: interval, in epochs, at which the single-modality cluster labels are updated.
batch_size: batch size for the dataloaders.
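Both feature matrices must describe the same n cells, so a quick shape check before training catches misaligned inputs early. The sketch below assumes NumPy-compatible inputs and labels given as one cluster label per cell (a 1-D array of length n); `check_muse_inputs` is a hypothetical helper, not part of muse_pytorch, and the data are synthetic.

```python
import numpy as np


def check_muse_inputs(trans_features, morph_features, trans_labels, morph_labels):
    """Hypothetical pre-flight check for the shapes in the Parameters list.

    trans_features: (n, p), morph_features: (n, q),
    trans_labels / morph_labels: assumed to hold one label per cell, shape (n,).
    Returns (n, p, q) or raises ValueError.
    """
    x = np.asarray(trans_features)
    y = np.asarray(morph_features)
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError("feature inputs must be 2-D matrices (cells x features)")
    n = x.shape[0]
    if y.shape[0] != n:
        raise ValueError(f"trans_features has {n} cells but morph_features has {y.shape[0]}")
    for name, labels in (("trans_labels", trans_labels), ("morph_labels", morph_labels)):
        labels = np.asarray(labels)
        if labels.shape != (n,):
            raise ValueError(f"{name} must have shape ({n},), got {labels.shape}")
    return n, x.shape[1], y.shape[1]


rng = np.random.default_rng(0)
trans = rng.random((100, 2000))           # 100 cells x 2000 genes (synthetic)
morph = rng.random((100, 64))             # 100 cells x 64 morphological features (synthetic)
labels_t = rng.integers(0, 5, size=100)   # one initial label per cell
labels_m = rng.integers(0, 5, size=100)
print(check_muse_inputs(trans, morph, labels_t, labels_m))  # (100, 2000, 64)
```

The check only verifies shapes; it cannot tell whether row i in both matrices really refers to the same cell, so row order must be aligned upstream.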
Outputs:
z: joint latent representation learned by MUSE.
x_hat: reconstructed feature matrix corresponding to the input trans_features.
y_hat: reconstructed feature matrix corresponding to the input morph_features.
latent_x: modality-specific latent representation corresponding to trans_features.
latent_y: modality-specific latent representation corresponding to morph_features.
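Since x_hat and y_hat are reconstructions of the two inputs, comparing each one with its input is a simple way to see how well each modality was reconstructed. The sketch below is a hypothetical helper (`reconstruction_mse`, not part of muse_pytorch) that assumes NumPy arrays; if the outputs are PyTorch tensors, convert them first with `.detach().cpu().numpy()`.

```python
import numpy as np


def reconstruction_mse(original, reconstructed):
    """Return (per-feature MSE, overall MSE) between an input matrix and its
    reconstruction, e.g. trans_features vs x_hat or morph_features vs y_hat."""
    a = np.asarray(original, dtype=float)
    b = np.asarray(reconstructed, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    per_feature = ((a - b) ** 2).mean(axis=0)  # average over cells
    return per_feature, float(per_feature.mean())


# Toy example with 2 cells x 2 features (synthetic numbers).
x = np.array([[1.0, 2.0], [3.0, 4.0]])
x_hat = np.array([[1.0, 2.5], [2.0, 4.0]])
per_feature, overall = reconstruction_mse(x, x_hat)
print(overall)  # 0.3125
```

The per-feature values help spot individual genes or morphological features that the model reconstructs poorly.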
Beyond this, training can be customized further by importing the PyTorch Lightning Module and DataModule directly:
from muse_pytorch import MUSE, MUSEDataModule
For a complete description of the method, please refer to the original repository.