PyTorch version of MUSE
Project description
MUSE-PyTorch
This repo is a PyTorch Lightning implementation of the official MUSE (multi-modality structured embedding) for spatially resolved transcriptomics analysis:
Bao, F., Deng, Y., Wan, S. et al. Integrative spatial analysis of cell morphologies and transcriptional states with MUSE. Nat Biotechnol (2022). https://doi.org/10.1038/s41587-022-01251-z
Requirements
- numpy==1.22.3
- online_triplet_loss==0.0.6
- pandas==1.4.2
- PhenoGraph==1.5.7
- pytorch_lightning==1.6.3
- scipy==1.8.0
- torch==1.11.0
Installation
To install the MUSE PyTorch package, run
pip install muse_pytorch
Usage
import muse_pytorch as muse
The library exposes the same fit_predict method as the original implementation.
z, x_hat, y_hat, latent_x, latent_y = muse.fit_predict(trans_features,
                                                        morph_features,
                                                        trans_labels,
                                                        morph_labels,
                                                        init_epochs=3,
                                                        refine_epochs=3,
                                                        cluster_epochs=6,
                                                        cluster_update_epoch=2,
                                                        joint_latent_dim=50,
                                                        batch_size=512)
The method accepts the same parameters as the original implementation, plus a few additional ones.
Parameters:
trans_features: input for the transcriptional modality; n * p matrix, where n = number of cells and p = number of genes.
morph_features: input for the morphological modality; n * q matrix, where n = number of cells and q = the feature dimension.
trans_labels: initial reference cluster labels for the transcriptional modality (see the sketch after the output list).
morph_labels: initial reference cluster labels for the morphological modality.
latent_dim: size of the latent dimension for each single modality.
joint_latent_dim: size of the latent dimension of the joint representation.
lambda_reg: weight of the regularisation term in the loss function.
lambda_sup: weight of the self-supervised term in the loss function.
lr: learning rate for the optimizer.
init_epochs: epochs for the initialization phase.
refine_epochs: epochs for the refining phase.
cluster_epochs: epochs for the clustering phase.
cluster_update_epoch: interval (in epochs) after which the single-modality clusters are updated.
batch_size: batch size for the dataloaders.
Outputs:
z: joint latent representation learned by MUSE.
x_hat: reconstructed feature matrix corresponding to the input trans_features.
y_hat: reconstructed feature matrix corresponding to the input morph_features.
latent_x: modality-specific latent representation corresponding to trans_features.
latent_y: modality-specific latent representation corresponding to morph_features.
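Following the original MUSE workflow, the initial reference labels can be obtained by clustering each modality separately, for instance with PhenoGraph (already listed in the requirements). The sketch below is only illustrative: it uses random arrays in place of real data, and the exact preprocessing applied before clustering may differ from the original pipeline.

import numpy as np
import phenograph
import muse_pytorch as muse

# Toy feature matrices: n cells, p genes, q morphological features (random placeholders)
n, p, q = 1000, 200, 64
trans_features = np.random.rand(n, p)
morph_features = np.random.rand(n, q)

# phenograph.cluster returns (community labels, graph, modularity); only the labels are needed here
trans_labels, _, _ = phenograph.cluster(trans_features)
morph_labels, _, _ = phenograph.cluster(morph_features)

z, x_hat, y_hat, latent_x, latent_y = muse.fit_predict(
    trans_features, morph_features, trans_labels, morph_labels,
    init_epochs=3, refine_epochs=3, cluster_epochs=6,
    cluster_update_epoch=2, joint_latent_dim=50, batch_size=512)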
On top of this, it is also possible to further customize the training by importing the PyTorch Lightning Module and DataModule directly:
from muse_pytorch import MUSE, MUSEDataModule
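A minimal sketch of such a custom training loop is shown below. The constructor arguments are assumptions that mirror the fit_predict parameters (the actual signatures of MUSE and MUSEDataModule are documented in the original repo); only the pytorch_lightning Trainer calls are standard API.

import pytorch_lightning as pl
from muse_pytorch import MUSE, MUSEDataModule

# NOTE: the argument names below are assumptions mirroring fit_predict; check the repo for the real signatures
dm = MUSEDataModule(trans_features, morph_features, trans_labels, morph_labels, batch_size=512)
model = MUSE(joint_latent_dim=50, lr=1e-4)

# Standard PyTorch Lightning training loop; loggers, callbacks, etc. can be added as usual
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)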
For a complete description of the project, please head to the original repo.
Download files
Source Distribution
Built Distribution
File details
Details for the file muse-pytorch-0.0.3.tar.gz.
File metadata
- Download URL: muse-pytorch-0.0.3.tar.gz
- Upload date:
- Size: 9.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 784075b54f6bb4d06d90ac9642be126a5e4296259271e757673d97f479795890
MD5 | d7e57b357138721afbc739955f9156f4
BLAKE2b-256 | 31fb4bb1d624b9b153f19de4b98d7d2fe9da3829ce2bc8d17167e9110748761c
File details
Details for the file muse_pytorch-0.0.3-py3-none-any.whl.
File metadata
- Download URL: muse_pytorch-0.0.3-py3-none-any.whl
- Upload date:
- Size: 9.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | a167ee7edf3edbdccf7fb1e5f0f405fff9d3656d98311b64130eda665ecc3000
MD5 | 3e6e3f09aa29f49abe3d4cd3896159ec
BLAKE2b-256 | 0c9f030af4add43756073c81acb763f9ab0209bd8076fff58a716b8823b7ae04