Skip to main content

Generative additive model for single cell perturbation

Project description

DensityFlow

A deep additive model for learning perturbation semantics.

Installation

  1. Create a virtual environment
conda create -n densityflow python=3.10 scipy numpy pandas scikit-learn && conda activate densityflow
  1. Install PyTorch following the official instruction.
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
  1. Install DensityFlow
pip3 install DensityFlow

Example

Dataset used in this example is obtained from scPerturb

import re
import scanpy as sc


from DensityFlow import DensityFlow
from DensityFlow.perturb import LabelMatrix
from sklearn.model_selection import train_test_split
from eval_metrics import mmd_eval, r2_score_eval, pearson_eval


# perturbation information
pert_col = 'perturbation'
control_label = 'control'
loss_func = 'multinomial'

# load single cell data
adata = sc.read_h5ad('PapalexiSatija2021_eccite_RNA.h5ad')
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# normalize perturbation labels
adata.obs[pert_col] = [re.sub(r'g\d+$', '', s) for s in adata.obs[pert_col]]


# split single cell data into training and test subsets
cells_pert = adata[adata.obs[pert_col]!=control_label].obs_names
cells_train, cells_test = train_test_split(cells_pert, test_size= adata.shape[0] // 8)
cells_train = cells_train.tolist() + adata[adata.obs[pert_col]==control_label].obs_names.tolist()
adata_train = adata[cells_train].copy()
adata_test = adata[cells_test].copy()


# prepare data for training
xs = adata_train.X

lb = LabelMatrix()
us = lb.fit_transform(adata_train.obs[pert_col],control_label)
ln = lb.labels_

# training model
model = DensityFlow(input_size = xs.shape[1],
                    cell_factor_size=us.shape[1],
                    loss_func = loss_func,
                    seed=42,
                    use_cuda=True)

model.fit(xs, us=us, num_epochs=200, batch_size=1000, use_jax=True)


# save model
# DensityFlow.save_model(model, f'densityflow_{loss_func}_model.pt')

# load pre-trained model
# model = DensityFlow.load_model(f'densityflow_{loss_func}_model.pt')


# evaluation
def predict_pert_effect(ad,pert):
    ad = ad.copy()
    xs_pert = ad.X.toarray()
    zs_basal = model.get_basal_embedding(xs_pert, show_progress=False)

    ind = int(np.where(ln==pert)[0])
    us_pert = np.ones([xs_pert.shape[0],1])
    dzs = model.get_cell_shift(ad.X.toarray(), perturb_idx=ind, perturb_us=us_pert, show_progress=False)
    
    counts = model.get_counts(zs_basal+dzs, library_sizes=ad.X.sum(1), show_progress=False)
    return counts.copy()


results = []
pert_sets = adata_test.obs[pert_col].unique().tolist()
i = 0
for pert in pert_sets:
    i += 1
    print(f'{i}/{len(pert_sets)}')
    
    if pert==control_label:
        continue
    
    ad_test = adata_test[adata_test.obs[pert_col]==pert].copy()
    xs_test = ad_test.X.toarray()
    
    ind = np.random.choice(np.arange(adata_control.shape[0]), size=ad_test.shape[0], replace=True)
    ad_ctrl = adata_control[ind].copy()
    ad_ctrl.obs_names_make_unique()
    xs_basal = ad_ctrl.X.toarray()
    
    xs_test_pred = predict_pert_effect(ad_test, pert)
    
    xs_test_pred = xs_test_pred.astype(float)
    xs_test = xs_test.astype(float)
    xs_basal = xs_basal.astype(float)
    
    mmd_value=mmd_eval(xs_test_pred, xs_test)
    r2 = r2_score_eval(xs_test_pred, xs_test)
    pr = pearson_eval(xs_test_pred-xs_basal,xs_test-xs_basal)
    print(f'mmd:{mmd_value}; r2:{r2}; pearson:{pr}')
    results.append({'mmd':mmd_value,'r2':r2,'pearson':pr})


df = pd.DataFrame(results)
df.mean(0)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

densityflow-1.0.27.tar.gz (41.1 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

densityflow-1.0.27-py3-none-any.whl (44.7 kB view details)

Uploaded Python 3

File details

Details for the file densityflow-1.0.27.tar.gz.

File metadata

  • Download URL: densityflow-1.0.27.tar.gz
  • Upload date:
  • Size: 41.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for densityflow-1.0.27.tar.gz
Algorithm Hash digest
SHA256 eccd8e1d2cb893fe3a3d6ae92772d301e1b3c64d6eb29f53320bde204b43695c
MD5 45254af77603d01aae2df8c9614793ce
BLAKE2b-256 bb8e738a008b4400addc56f2f9d10e351dbccbcacc399381f41f008502888d97

See more details on using hashes here.

File details

Details for the file densityflow-1.0.27-py3-none-any.whl.

File metadata

  • Download URL: densityflow-1.0.27-py3-none-any.whl
  • Upload date:
  • Size: 44.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for densityflow-1.0.27-py3-none-any.whl
Algorithm Hash digest
SHA256 d83e511e2baaca97b97b81b790e585cff14f17e33eae97524e50fb08b00b5d1f
MD5 23bd6ba13bcb5fb24c03b7a974270d13
BLAKE2b-256 a102c1162369bd842eee9fbc11e3a20dd29fdcbe4d556a42b099aac07b61b1b5

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page