Generative additive model for single cell perturbation v2

Project description

DensityFlow

A deep additive model for learning perturbation semantics.

Installation

  1. Create a virtual environment.
conda create -n densityflow python=3.10 scipy numpy pandas scikit-learn && conda activate densityflow
  2. Install PyTorch following the official instructions.
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
  3. Install DensityFlow.
pip3 install DensityFlow
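
After the steps above, a quick check can confirm that the packages used in the example are importable. The following is a minimal sketch, not part of DensityFlow: `check_packages` is a hypothetical helper, and the module names listed are assumed from the imports in the example below.

```python
from importlib.util import find_spec


def check_packages(names):
    """Return a dict mapping each top-level module name to True if it can be located."""
    return {name: find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Import names assumed from the example; adjust them if your install differs.
    status = check_packages(
        ["numpy", "pandas", "sklearn", "torch", "scanpy", "DensityFlow"]
    )
    for name, ok in status.items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
```

Any module reported as MISSING needs to be installed in the active environment before running the example.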

Example

The dataset used in this example is from scPerturb.

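import re
import numpy as np
import pandas as pd
import scanpy as sc

from DensityFlow import DensityFlow
from DensityFlow.perturb import LabelMatrix
from sklearn.model_selection import train_test_split
# eval_metrics is not imported from DensityFlow; it is assumed to be a local helper module
from eval_metrics import mmd_eval, r2_score_eval, pearson_eval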


# perturbation information
pert_col = 'perturbation'
control_label = 'control'
loss_func = 'multinomial'

# load single cell data
adata = sc.read_h5ad('PapalexiSatija2021_eccite_RNA.h5ad')
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# normalize perturbation labels
adata.obs[pert_col] = [re.sub(r'g\d+$', '', s) for s in adata.obs[pert_col]]


# split perturbed cells into training and test subsets; all control cells go to training
adata_control = adata[adata.obs[pert_col]==control_label].copy()
cells_pert = adata[adata.obs[pert_col]!=control_label].obs_names
cells_train, cells_test = train_test_split(cells_pert, test_size=adata.shape[0] // 8)
cells_train = cells_train.tolist() + adata_control.obs_names.tolist()
adata_train = adata[cells_train].copy()
adata_test = adata[cells_test].copy()


# prepare data for training
xs = adata_train.X

lb = LabelMatrix()
us = lb.fit_transform(adata_train.obs[pert_col],control_label)
ln = lb.labels_

# training model
model = DensityFlow(input_size = xs.shape[1],
                    cell_factor_size=us.shape[1],
                    loss_func = loss_func,
                    seed=42,
                    use_cuda=True)

model.fit(xs, us=us, num_epochs=200, batch_size=1000, use_jax=True)


# save model
# DensityFlow.save_model(model, f'densityflow_{loss_func}_model.pt')

# load pre-trained model
# model = DensityFlow.load_model(f'densityflow_{loss_func}_model.pt')


# evaluation
def predict_pert_effect(ad, pert):
    # predict expression of the cells in `ad` under perturbation `pert`
    ad = ad.copy()
    xs_pert = ad.X.toarray()
    zs_basal = model.get_basal_embedding(xs_pert, show_progress=False)

    # index of the perturbation among the fitted labels
    ind = int(np.where(np.asarray(ln) == pert)[0][0])
    us_pert = np.ones([xs_pert.shape[0], 1])
    dzs = model.get_cell_shift(xs_pert, perturb_idx=ind, perturb_us=us_pert, show_progress=False)

    # decode the basal embedding plus the perturbation shift
    counts = model.get_counts(zs_basal + dzs, library_sizes=ad.X.sum(1), show_progress=False)
    return counts.copy()


results = []
pert_sets = adata_test.obs[pert_col].unique().tolist()
for i, pert in enumerate(pert_sets, start=1):
    print(f'{i}/{len(pert_sets)}')
    
    if pert==control_label:
        continue
    
    ad_test = adata_test[adata_test.obs[pert_col]==pert].copy()
    xs_test = ad_test.X.toarray()
    
    ind = np.random.choice(np.arange(adata_control.shape[0]), size=ad_test.shape[0], replace=True)
    ad_ctrl = adata_control[ind].copy()
    ad_ctrl.obs_names_make_unique()
    xs_basal = ad_ctrl.X.toarray()
    
    xs_test_pred = predict_pert_effect(ad_test, pert)
    
    xs_test_pred = xs_test_pred.astype(float)
    xs_test = xs_test.astype(float)
    xs_basal = xs_basal.astype(float)
    
    mmd_value=mmd_eval(xs_test_pred, xs_test)
    r2 = r2_score_eval(xs_test_pred, xs_test)
    pr = pearson_eval(xs_test_pred-xs_basal,xs_test-xs_basal)
    print(f'mmd:{mmd_value}; r2:{r2}; pearson:{pr}')
    results.append({'mmd':mmd_value,'r2':r2,'pearson':pr})


# average each metric across perturbations
df = pd.DataFrame(results)
print(df.mean(0))
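
The example imports `mmd_eval`, `r2_score_eval` and `pearson_eval` from an `eval_metrics` module that is not shown here. The sketch below is one plausible stand-in, not the author's implementation. It assumes each function takes two cells-by-genes arrays, that R² and Pearson correlation are computed on per-gene mean profiles, and that MMD uses an RBF kernel. The `gamma` parameter and its default are assumptions as well.

```python
import numpy as np


def _gene_means(xs):
    # Collapse a cells-by-genes array to one mean value per gene.
    return np.asarray(xs, dtype=float).mean(axis=0)


def r2_score_eval(xs_pred, xs_true):
    """R^2 of the predicted per-gene means against the observed per-gene means."""
    mu_pred, mu_true = _gene_means(xs_pred), _gene_means(xs_true)
    ss_res = np.sum((mu_true - mu_pred) ** 2)
    ss_tot = np.sum((mu_true - mu_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def pearson_eval(xs_pred, xs_true):
    """Pearson correlation between predicted and observed per-gene means."""
    mu_pred, mu_true = _gene_means(xs_pred), _gene_means(xs_true)
    return float(np.corrcoef(mu_pred, mu_true)[0, 1])


def mmd_eval(xs_pred, xs_true, gamma=None):
    """Biased estimate of squared MMD between two cell populations (RBF kernel)."""
    x = np.asarray(xs_pred, dtype=float)
    y = np.asarray(xs_true, dtype=float)
    if gamma is None:
        gamma = 1.0 / x.shape[1]  # assumed default: 1 / number of genes

    def rbf(a, b):
        # Pairwise squared Euclidean distances, clipped at zero for stability.
        sq = (a ** 2).sum(axis=1)[:, None] + (b ** 2).sum(axis=1)[None, :] - 2.0 * a @ b.T
        return np.exp(-gamma * np.maximum(sq, 0.0))

    return float(rbf(x, x).mean() + rbf(y, y).mean() - 2.0 * rbf(x, y).mean())
```

Saved as `eval_metrics.py` next to the script, this would satisfy the import in the example. Scores from it are not comparable with results computed using a different metric definition.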

Download files

Source Distribution

densityflow2-1.1.3.tar.gz (30.8 kB view details)

Uploaded Source

Built Distribution

densityflow2-1.1.3-py3-none-any.whl (34.1 kB view details)

Uploaded Python 3

File details

Details for the file densityflow2-1.1.3.tar.gz.

File metadata

  • Download URL: densityflow2-1.1.3.tar.gz
  • Upload date:
  • Size: 30.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for densityflow2-1.1.3.tar.gz
Algorithm Hash digest
SHA256 c6e1c012ebe4a784e2e5f0d79a1d75dced06af5c513f2c4ef5733b43581cc782
MD5 29365bdd7cb02b870262705132a10b10
BLAKE2b-256 ccb9acaf149acd4553259e4f466ea4f931f618aec0cbb9b191ba906c4ee2bc7b

File details

Details for the file densityflow2-1.1.3-py3-none-any.whl.

File metadata

  • Download URL: densityflow2-1.1.3-py3-none-any.whl
  • Upload date:
  • Size: 34.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for densityflow2-1.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 499a08662a35cfc4b361c99df9b893a37aa67487d09ff401077c082fe36245dd
MD5 5b356e7d2ac49303f9c9adcbad8f6281
BLAKE2b-256 4d1cd7c7951b32073ba3a56cc6ff8785fdeabeb6c04a736fd5c359d401691bf7
