Skip to main content

CellVGAE uses the connectivity between cells (such as k-nearest neighbour graphs) with gene expression values as node features to learn high-quality cell representations in a lower-dimensional space

Project description

CellVGAE

An unsupervised scRNA-seq analysis workflow with graph attention networks

CellVGAE uses the connectivity between cells (such as k-nearest neighbour graphs or KNN) with gene expression values as node features to learn high-quality cell representations in a lower-dimensional space, with applications in downstream analyses like (density-based) clustering, visualisation, gene set enrichment analysis and others. CellVGAE leverages both the variational graph autoencoder and graph attention networks to offer a powerful and more interpretable machine learning approach. It is implemented in PyTorch using the PyTorch Geometric library.

Requirements

Installing CellVGAE with pip will attempt to install PyTorch and PyTorch Geometric, however it is recommended that the appropriate GPU/CPU versions are installed manually beforehand. For Linux:

  1. Install PyTorch GPU:

    conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia

    or PyTorch CPU:

    conda install pytorch torchvision torchaudio cpuonly -c pytorch

  2. Install PyTorch Geometric:

    conda install pyg -c pyg -c conda-forge

  3. (Optional) Install Faiss CPU:

    conda install -c pytorch faiss-cpu

    Faiss is only required if using the option --graph_type "KNN Faiss" . It is a soft dependency as it is not available for some platforms (currently Apple M1). Attempting to use CellVGAE with Faiss without installing it will result in an exception.

    A GPU version of Faiss for CUDA 11.1 is not yet available.

  4. Install CellVGAE with pip:

    pip install cellvgae --pre

  5. (Optional) For the attention graph visualisations of Figure 6, igraph is required:

    pip install python-igraph

If using the R preprocessing code, we recommend installing the following:

Seurat 3, scran, SingleCellExperiment. scRNAseq, BiocSingular, igraph, dplyr and textshape.

Example use

Using the example files in this repo (.h5ad file is the same as downloaded by Scanpy 1.8.1):

python -m cellvgae --input_gene_expression_path "example_data/paul15_myeloid_scanpy.h5ad" --graph_file_path "example_data/paul15_Faiss_KNN_K3_KHVG2500.txt" --graph_convolution "GAT" --num_hidden_layers 2 --hidden_dims 128 128 --num_heads 3 3 3 3 --dropout 0.4 0.4 0.4 0.4 --latent_dim 50 --epochs 50 --model_save_path "model_saved_out"

Other examples are available in examples/cellvgae_example_scripts.txt

(also consult the help section below)

Usage

Invoke the training script with python -m cellvgae with the arguments detailed below:

usage: train [-h] [--input_gene_expression_path INPUT_GENE_EXPRESSION_PATH] [--hvg HVG] [--khvg KHVG] [--graph_type {KNN Scanpy,KNN Faiss,PKNN}] [--k K] [--graph_n_pcs GRAPH_N_PCS]
             [--graph_metric {euclidean,manhattan,cosine}] [--graph_distance_cutoff_num_stds GRAPH_DISTANCE_CUTOFF_NUM_STDS] [--save_graph] [--raw_counts] [--faiss_gpu]
             [--hvg_file_path HVG_FILE_PATH] [--khvg_file_path KHVG_FILE_PATH] [--graph_file_path GRAPH_FILE_PATH] [--graph_convolution {GAT,GATv2,GCN}] [--num_hidden_layers {2,3}]
             [--num_heads [NUM_HEADS [NUM_HEADS ...]]] [--hidden_dims [HIDDEN_DIMS [HIDDEN_DIMS ...]]] [--dropout [DROPOUT [DROPOUT ...]]] [--latent_dim LATENT_DIM] [--loss {kl,mmd}] [--lr LR]
             [--epochs EPOCHS] [--val_split VAL_SPLIT] [--test_split TEST_SPLIT] [--transpose_input] [--use_linear_decoder] [--decoder_nn_dim1 DECODER_NN_DIM1] [--name NAME] --model_save_path MODEL_SAVE_PATH [--umap] [--hdbscan]

Train CellVGAE.

optional arguments:
  -h, --help            show this help message and exit
  --input_gene_expression_path INPUT_GENE_EXPRESSION_PATH
                        Input gene expression file path.
  --hvg HVG             Number of HVGs.
  --khvg KHVG           Number of KHVGs.
  --graph_type {KNN Scanpy,KNN Faiss,PKNN}
                        Type of graph.
  --k K                 K for KNN or Pearson (PKNN) graph.
  --graph_n_pcs GRAPH_N_PCS
                        Use this many Principal Components for the KNN (only Scanpy).
  --graph_metric {euclidean,manhattan,cosine}
  --graph_distance_cutoff_num_stds GRAPH_DISTANCE_CUTOFF_NUM_STDS
                        Number of standard deviations to add to the mean of distances/correlation values. Can be negative.
  --save_graph          Save the generated graph to the output path specified by --model_save_path.
  --raw_counts          Enable preprocessing recipe for raw counts.
  --faiss_gpu           Use Faiss on the GPU (only for KNN Faiss).
  --hvg_file_path HVG_FILE_PATH
                        HVG file if not using command line options to generate it.
  --khvg_file_path KHVG_FILE_PATH
                        KHVG file if not using command line options to generate it. Can be the same file as --hvg_file_path if HVG = KHVG.
  --graph_file_path GRAPH_FILE_PATH
                        Graph specified as an edge list (one edge per line, nodes separated by whitespace, not comma), if not using command line options to generate it.
  --graph_convolution {GAT,GATv2,GCN}
  --num_hidden_layers {2,3}
                        Number of hidden layers (must be 2 or 3).
  --num_heads [NUM_HEADS [NUM_HEADS ...]]
                        Number of attention heads for each layer. Input is a list that must match the total number of layers = num_hidden_layers + 2 in length.
  --hidden_dims [HIDDEN_DIMS [HIDDEN_DIMS ...]]
                        Output dimension for each hidden layer. Input is a list that matches --num_hidden_layers in length.
  --dropout [DROPOUT [DROPOUT ...]]
                        Dropout for each layer. Input is a list that must match the total number of layers = num_hidden_layers + 2 in length.
  --latent_dim LATENT_DIM
                        Latent dimension (output dimension for node embeddings).
  --loss {kl,mmd}       Loss function (KL or MMD).
  --lr LR               Learning rate for Adam.
  --epochs EPOCHS       Number of training epochs.
  --val_split VAL_SPLIT
                        Validation split e.g. 0.1.
  --test_split TEST_SPLIT
                        Test split e.g. 0.1.
  --transpose_input     Specify if inputs should be transposed.
  --use_linear_decoder  Turn on a neural network decoder, similar to traditional VAEs.
  --decoder_nn_dim1 DECODER_NN_DIM1
                        First hidden dimenson for the neural network decoder, if specified using --use_linear_decoder.
  --name NAME           Name used for the written output files.
  --model_save_path MODEL_SAVE_PATH
                        Path to save PyTorch model and output files. Will create the entire path if necessary.
  --umap                Compute and save the 2D UMAP embeddings of the output node features.
  --hdbscan             Compute and save different HDBSCAN clusterings.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

cellvgae-0.0.1b4.tar.gz (102.2 MB view details)

Uploaded Source

Built Distribution

cellvgae-0.0.1b4-py3-none-any.whl (18.8 kB view details)

Uploaded Python 3

File details

Details for the file cellvgae-0.0.1b4.tar.gz.

File metadata

  • Download URL: cellvgae-0.0.1b4.tar.gz
  • Upload date:
  • Size: 102.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.8.12

File hashes

Hashes for cellvgae-0.0.1b4.tar.gz
Algorithm Hash digest
SHA256 e32c8aae3200a3a1a1c20dfdf0a5c044fa521c5d131841750bc2516f2e5b5293
MD5 3328bdb02fee4715ecd4567819d57b87
BLAKE2b-256 52999940cff452bb187aae7c89f57a823112b81547308a68558add2b6d0a6f73

See more details on using hashes here.

File details

Details for the file cellvgae-0.0.1b4-py3-none-any.whl.

File metadata

  • Download URL: cellvgae-0.0.1b4-py3-none-any.whl
  • Upload date:
  • Size: 18.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.8.12

File hashes

Hashes for cellvgae-0.0.1b4-py3-none-any.whl
Algorithm Hash digest
SHA256 aacbfa6637ad8eef527a9d466aceec4ea0ac59a889b613ad2eca83ecabd17379
MD5 36046fde6f0ce3d19067ab62f072ae01
BLAKE2b-256 9d2225e1f954f24ccdc9f6ddb538ba6f758c79e2a17e50abfed92340278c4532

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page