Skip to main content

Learning multi-omics perturbation language

Project description

DensityFlow

A deep additive model for learning perturbation semantics.

Installation

  1. Create a virtual environment
conda create -n densityflow python=3.10 scipy numpy pandas scikit-learn && conda activate densityflow
  1. Install PyTorch following the official instruction.
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
  1. Install DensityFlow
pip3 install DensityFlow

Example

Dataset used in this example is obtained from scPerturb

import re
import scanpy as sc


from DensityFlow import DensityFlow
from DensityFlow.perturb import LabelMatrix
from sklearn.model_selection import train_test_split
from eval_metrics import mmd_eval, r2_score_eval, pearson_eval


# perturbation information
pert_col = 'perturbation'
control_label = 'control'
loss_func = 'multinomial'

# load single cell data
adata = sc.read_h5ad('PapalexiSatija2021_eccite_RNA.h5ad')
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# normalize perturbation labels
adata.obs[pert_col] = [re.sub(r'g\d+$', '', s) for s in adata.obs[pert_col]]


# split single cell data into training and test subsets
cells_pert = adata[adata.obs[pert_col]!=control_label].obs_names
cells_train, cells_test = train_test_split(cells_pert, test_size= adata.shape[0] // 8)
cells_train = cells_train.tolist() + adata[adata.obs[pert_col]==control_label].obs_names.tolist()
adata_train = adata[cells_train].copy()
adata_test = adata[cells_test].copy()


# prepare data for training
xs = adata_train.X

lb = LabelMatrix()
us = lb.fit_transform(adata_train.obs[pert_col],control_label)
ln = lb.labels_

# training model
model = DensityFlow(input_size = xs.shape[1],
                    cell_factor_size=us.shape[1],
                    loss_func = loss_func,
                    seed=42,
                    use_cuda=True)

model.fit(xs, us=us, num_epochs=200, batch_size=1000, use_jax=True)


# save model
# DensityFlow.save_model(model, f'densityflow_{loss_func}_model.pt')

# load pre-trained model
# model = DensityFlow.load_model(f'densityflow_{loss_func}_model.pt')


# evaluation
def predict_pert_effect(ad,pert):
    ad = ad.copy()
    xs_pert = ad.X.toarray()
    zs_basal = model.get_basal_embedding(xs_pert, show_progress=False)

    ind = int(np.where(ln==pert)[0])
    us_pert = np.ones([xs_pert.shape[0],1])
    dzs = model.get_cell_shift(ad.X.toarray(), perturb_idx=ind, perturb_us=us_pert, show_progress=False)
    
    counts = model.get_counts(zs_basal+dzs, library_sizes=ad.X.sum(1), show_progress=False)
    return counts.copy()


results = []
pert_sets = adata_test.obs[pert_col].unique().tolist()
i = 0
for pert in pert_sets:
    i += 1
    print(f'{i}/{len(pert_sets)}')
    
    if pert==control_label:
        continue
    
    ad_test = adata_test[adata_test.obs[pert_col]==pert].copy()
    xs_test = ad_test.X.toarray()
    
    ind = np.random.choice(np.arange(adata_control.shape[0]), size=ad_test.shape[0], replace=True)
    ad_ctrl = adata_control[ind].copy()
    ad_ctrl.obs_names_make_unique()
    xs_basal = ad_ctrl.X.toarray()
    
    xs_test_pred = predict_pert_effect(ad_test, pert)
    
    xs_test_pred = xs_test_pred.astype(float)
    xs_test = xs_test.astype(float)
    xs_basal = xs_basal.astype(float)
    
    mmd_value=mmd_eval(xs_test_pred, xs_test)
    r2 = r2_score_eval(xs_test_pred, xs_test)
    pr = pearson_eval(xs_test_pred-xs_basal,xs_test-xs_basal)
    print(f'mmd:{mmd_value}; r2:{r2}; pearson:{pr}')
    results.append({'mmd':mmd_value,'r2':r2,'pearson':pr})


df = pd.DataFrame(results)
df.mean(0)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

densityflowmo-0.1.3.tar.gz (30.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

densityflowmo-0.1.3-py3-none-any.whl (33.8 kB view details)

Uploaded Python 3

File details

Details for the file densityflowmo-0.1.3.tar.gz.

File metadata

  • Download URL: densityflowmo-0.1.3.tar.gz
  • Upload date:
  • Size: 30.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for densityflowmo-0.1.3.tar.gz
Algorithm Hash digest
SHA256 9467da8f6db2d4491702005148841a5d5b712288a09222fdbb435e8425832ea3
MD5 17a11cbf10962ba08c84f4184becb532
BLAKE2b-256 279520a83bed362544d6972c918db079d148daa9434bea8c5b3a996fc23bb8bc

See more details on using hashes here.

File details

Details for the file densityflowmo-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: densityflowmo-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 33.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for densityflowmo-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 b39fcbe9abb45acfa4215112de2758ec22d66e6bc042f83c7bafa0d6b2dd56a2
MD5 fddc579a94c7707c205a334bf5ca270f
BLAKE2b-256 14959ebadc266e71cf4a76cc5adb819b9f90041f4a4f983aab9ce8940633b3dd

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page