A Python implementation of Gene-SGAN for multi-view semi-supervised clustering
Gene-SGAN
Gene-SGAN is a multi-view semi-supervised clustering method for disentangling disease heterogeneity. By jointly considering brain phenotypic and genetic data, Gene-SGAN identifies disease subtypes with associated phenotypic and genetic signatures. Using healthy control (HC) populations as a reference distribution, the model clusters participants based on disease-related phenotypic variations with genetic associations, thus avoiding confounders from disease-unrelated factors.
License
Copyright (c) 2016 University of Pennsylvania. All rights reserved. See https://www.cbica.upenn.edu/sbia/software/license.html
Installation
We highly recommend installing Anaconda3 on your machine. After installing Anaconda3, Gene-SGAN can be installed with the following procedure.
We recommend using a Conda virtual environment:
$ conda create --name genesgan python=3.8
Activate the virtual environment:
$ conda activate genesgan
Install GeneSGAN from PyPI:
$ pip install GeneSGAN
Input structure
The main functions of GeneSGAN take three pandas DataFrames as data inputs: imaging_data, gene_data, and covariate (optional). Columns named 'participant_id' and 'diagnosis' must exist in imaging_data and covariate. Conventions for the group label (diagnosis): -1 represents healthy control (HC) and 1 represents patient (PT). Categorical variables, such as sex, should be encoded as numbers, for example 0 for female and 1 for male.
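As an illustration of these conventions, the sketch below recodes text labels with pandas before they are passed to GeneSGAN. It assumes a hypothetical raw table whose diagnosis column holds 'HC'/'PT' and whose sex column holds 'F'/'M'; encode_conventions is an illustrative helper, not part of the package.

```python
import pandas as pd

# Hypothetical raw covariate table with text labels (not part of GeneSGAN).
raw = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-3"],
    "diagnosis": ["HC", "PT", "HC"],
    "age": [57.3, 43.5, 53.8],
    "sex": ["F", "M", "M"],
})

def encode_conventions(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy with diagnosis coded -1 (HC) / 1 (PT) and sex coded 0 (F) / 1 (M)."""
    out = df.copy()
    out["diagnosis"] = out["diagnosis"].map({"HC": -1, "PT": 1})
    out["sex"] = out["sex"].map({"F": 0, "M": 1})
    # Labels missing from the mappings become NaN; fail loudly instead of passing them on.
    if out[["diagnosis", "sex"]].isna().any().any():
        raise ValueError("unrecognised diagnosis or sex label")
    return out

covariate = encode_conventions(raw)
print(covariate)
```

Adapt the two mapping dictionaries to whatever labels your raw data actually uses.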
gene_data must provide genetic features for all PT participants in imaging_data, but not for HC participants, so gene_data should not have a 'diagnosis' column. The current package only accepts SNP data as genetic features, and each SNP variant in gene_data needs to be recoded as 0, 1, or 2, indicating the number of minor alleles.
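If genotypes are stored as allele strings rather than counts, a minimal recoding sketch looks like this. It assumes each genotype is a two-allele string such as 'AG' (optionally written 'A/G' or 'A|G') and that the minor allele of each SNP is already known; minor_allele_count is a hypothetical helper.

```python
def minor_allele_count(genotype: str, minor_allele: str) -> int:
    """Count copies of `minor_allele` in a two-allele genotype string such as 'AG'."""
    # Drop common separators, e.g. 'A/G' or 'A|G'.
    alleles = genotype.replace("/", "").replace("|", "")
    if len(alleles) != 2:
        raise ValueError(f"expected two alleles, got {genotype!r}")
    return alleles.count(minor_allele)

# Hypothetical genotypes for one SNP whose minor allele is 'G'.
genotypes = ["AA", "AG", "G/G", "A|G"]
codes = [minor_allele_count(g, "G") for g in genotypes]
print(codes)  # [0, 1, 2, 1]
```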
Example for imaging_data:
participant_id diagnosis ROI1 ROI2 ...
subject-1 -1 325.4 603.4
subject-2 1 260.5 580.3
subject-3 -1 326.5 623.4
subject-4 1 301.7 590.5
subject-5 1 293.1 595.1
subject-6 1 287.8 608.9
Example for gene_data:
participant_id SNP1 SNP2 ...
subject-2 1 0
subject-4 0 0
subject-5 2 0
subject-6 0 2
Example for covariate:
participant_id diagnosis age sex ...
subject-1 -1 57.3 0
subject-2 1 43.5 1
subject-3 -1 53.8 1
subject-4 1 56.0 0
subject-5 1 60.0 1
subject-6 1 62.5 0
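Before training, it can help to check the inputs against the conventions above. The sketch below rebuilds the three example tables and runs a hypothetical check_inputs helper (not part of GeneSGAN) that collects every violation in a list instead of stopping at the first one.

```python
import pandas as pd

def check_inputs(imaging_data, gene_data, covariate=None):
    """Return a list of violated input conventions (empty if none are found)."""
    problems = []
    for name, df in [("imaging_data", imaging_data), ("covariate", covariate)]:
        if df is None:  # covariate is optional
            continue
        for col in ("participant_id", "diagnosis"):
            if col not in df.columns:
                problems.append(f"{name} is missing column '{col}'")
        if "diagnosis" in df.columns and not df["diagnosis"].isin([-1, 1]).all():
            problems.append(f"{name}.diagnosis must only contain -1 (HC) or 1 (PT)")
    if "diagnosis" in gene_data.columns:
        problems.append("gene_data must not have a 'diagnosis' column")
    if "participant_id" not in gene_data.columns:
        problems.append("gene_data is missing column 'participant_id'")
        return problems
    snps = gene_data.drop(columns=["participant_id", "diagnosis"], errors="ignore")
    if not snps.isin([0, 1, 2]).all().all():
        problems.append("SNPs must be coded as 0, 1 or 2 minor alleles")
    if {"participant_id", "diagnosis"} <= set(imaging_data.columns):
        dx = imaging_data.set_index("participant_id")["diagnosis"]
        gene_ids = set(gene_data["participant_id"])
        missing_pt = set(dx[dx == 1].index) - gene_ids
        hc_in_gene = set(dx[dx == -1].index) & gene_ids
        if missing_pt:
            problems.append(f"gene_data lacks PT participants: {sorted(missing_pt)}")
        if hc_in_gene:
            problems.append(f"gene_data should not include HC participants: {sorted(hc_in_gene)}")
    return problems

ids = [f"subject-{i}" for i in range(1, 7)]
imaging = pd.DataFrame({"participant_id": ids, "diagnosis": [-1, 1, -1, 1, 1, 1],
                        "ROI1": [325.4, 260.5, 326.5, 301.7, 293.1, 287.8],
                        "ROI2": [603.4, 580.3, 623.4, 590.5, 595.1, 608.9]})
gene = pd.DataFrame({"participant_id": ["subject-2", "subject-4", "subject-5", "subject-6"],
                     "SNP1": [1, 0, 2, 0], "SNP2": [0, 0, 0, 2]})
cov = pd.DataFrame({"participant_id": ids, "diagnosis": [-1, 1, -1, 1, 1, 1],
                    "age": [57.3, 43.5, 53.8, 56.0, 60.0, 62.5], "sex": [0, 1, 1, 0, 1, 0]})
print(check_inputs(imaging, gene, cov))  # []
```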
Example
We offer a toy dataset, the ground truth, and sample code in the folder GeneSGAN/datasets. Training one fold takes around 25 minutes on a MacBook Pro with a 1.4 GHz Intel Core i5 and can yield clustering with around 95% accuracy. A larger number of folds can improve clustering performance, so 20 or more folds are recommended in real data applications. Multiple folds can be run in parallel on HPC clusters.
import pandas as pd
from GeneSGAN.Gene_SGAN_clustering import cross_validated_clustering, clustering_result
imaging_data = pd.read_csv('toy_data_imaging.csv')
gene_data = pd.read_csv('toy_data_gene.csv')
covariate = pd.read_csv('covariate.csv')
output_dir = "PATH_OUTPUT_DIR"
ncluster = 3
start_saving_epoch = 20000
max_epoch = 30000
## three parameters for stopping threshold
WD = 0.12
AQ = 30
cluster_loss = 0.01
## three hyper-parameters to be tuned
genelr = 0.0002
lam = 9
mu = 5
When using the package, genelr, WD, AQ, cluster_loss, and batch_size need to be chosen empirically:
genelr: genelr (i.e., the learning rate of the gene step) is the most important hyper-parameter of the model. It should be set to several different values, and the value leading to the highest mean N-Asso-SNPs should be used. (Recommended value: 0.0001-0.0004)
WD: Wasserstein Distance measures the distance between generated PT data along each direction and real PT data. (Recommended value: 0.11-0.14)
AQ: Alteration Quantity measures the number of participants who change cluster labels during the last three training epochs. A low AQ implies convergence of training. (Recommended value: 1/20 of the PT sample size)
cluster_loss: Cluster loss measures how well the clustering function reconstructs the sampled Z variable. (Recommended value: 0.01-0.015)
batch_size: Size of the batch for each training epoch (default: 25). It should be reset to 1/8 of the PT sample size.
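The rules of thumb above can be turned into starting values with simple arithmetic. suggested_settings and best_genelr are hypothetical helpers, and the mean N-Asso-SNPs values passed to best_genelr are placeholders standing in for metrics you would collect from your own runs.

```python
def suggested_settings(n_pt: int) -> dict:
    """Rule-of-thumb AQ and batch_size from the number of PT participants (rounded)."""
    return {
        "AQ": max(1, round(n_pt / 20)),         # about 1/20 of the PT sample size
        "batch_size": max(1, round(n_pt / 8)),  # about 1/8 of the PT sample size
    }

def best_genelr(mean_n_asso_snps: dict) -> float:
    """Pick the genelr whose runs gave the highest mean N-Asso-SNPs."""
    return max(mean_n_asso_snps, key=mean_n_asso_snps.get)

print(suggested_settings(400))  # {'AQ': 20, 'batch_size': 50}
# Placeholder metric values, one per genelr tried in separate cross-validation runs.
print(best_genelr({0.0001: 12.4, 0.0002: 15.1, 0.0004: 13.0}))  # 0.0002
```

These are only starting points; the text above still recommends choosing WD, AQ and cluster_loss empirically.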
Two other parameters, lam and mu, have default values but may need to be changed in some cases:
lam: coefficient controlling the relative importance of cluster_loss in the training objective function. (default: 9)
mu: coefficient controlling the relative importance of change_loss in the training objective function. (default: 5)
fold_number = 50 # number of folds the hold-out cv runs
data_fraction = 0.8 # fraction of data used in each fold
cross_validated_clustering(imaging_data, gene_data, ncluster, fold_number, data_fraction, start_saving_epoch, max_epoch,\
output_dir, WD, AQ, cluster_loss, genelr = genelr, lam = lam, mu = mu, covariate=covariate)
cross_validated_clustering performs clustering with hold-out cross-validation and is the main function for clustering. Since the CV process may take a long time on a normal desktop computer, the function supports early stopping and later resumption: users can set stop_fold as the early stopping point and set start_fold according to the previous stopping point. By setting stop_fold to start_fold+1, users can run multiple folds in parallel, which significantly reduces the training time.
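For parallel runs, one way to plan the jobs is to split the folds into (start_fold, stop_fold) pairs and give each pair to a separate job (for example, one HPC array task per pair), passing them as the start_fold and stop_fold arguments of cross_validated_clustering. This sketch assumes folds are numbered from 0 and that stop_fold is exclusive, consistent with stop_fold = start_fold+1 running a single fold; check the function's signature for the exact indexing. fold_ranges is a hypothetical helper.

```python
def fold_ranges(fold_number: int, folds_per_job: int = 1):
    """Split folds 0..fold_number-1 into (start_fold, stop_fold) pairs, stop exclusive."""
    return [
        (start, min(start + folds_per_job, fold_number))
        for start in range(0, fold_number, folds_per_job)
    ]

# One fold per job: stop_fold = start_fold + 1.
print(fold_ranges(5))       # [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
# Larger chunks when fewer, longer-running jobs are preferred.
print(fold_ranges(50, 20))  # [(0, 20), (20, 40), (40, 50)]
```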
cross_validated_clustering also automatically saves a CSV file with the clustering results. Two metrics are provided for hyper-parameter selection: the mean ARI (i.e., the agreement of clusters among all folds) and the mean N-Asso-SNPs (i.e., the number of SNPs associated with the derived subtypes in the test sets).
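The mean ARI summarizes how consistently the folds assign participants to subtypes. The sketch below uses the standard adjusted Rand index formula and averages it over all pairs of folds; it assumes each fold's labels are aligned on the same participants and is not necessarily how GeneSGAN computes its reported value. Both function names are hypothetical.

```python
from collections import Counter
from itertools import combinations
from math import comb

def adjusted_rand_index(labels_a, labels_b):
    """Adjusted Rand index between two labelings of the same participants."""
    n = len(labels_a)
    if n != len(labels_b) or n < 2:
        raise ValueError("need two labelings of equal length (at least 2)")
    pairs = sum(comb(c, 2) for c in Counter(zip(labels_a, labels_b)).values())
    rows = sum(comb(c, 2) for c in Counter(labels_a).values())
    cols = sum(comb(c, 2) for c in Counter(labels_b).values())
    expected = rows * cols / comb(n, 2)
    max_index = (rows + cols) / 2
    if max_index == expected:  # degenerate case, e.g. one cluster in both labelings
        return 1.0
    return (pairs - expected) / (max_index - expected)

def mean_pairwise_ari(fold_labels):
    """Average ARI over all pairs of folds."""
    scores = [adjusted_rand_index(a, b) for a, b in combinations(fold_labels, 2)]
    return sum(scores) / len(scores)

# Three hypothetical folds labeling the same six PT participants;
# the cluster ids are permuted, but the partitions are identical.
folds = [[0, 0, 1, 1, 2, 2], [2, 2, 0, 0, 1, 1], [0, 0, 1, 1, 2, 2]]
print(mean_pairwise_ari(folds))  # 1.0
```

ARI is invariant to relabeling of clusters, which is why permuted cluster ids across folds still score 1.0.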
model_dirs = ['PATH_TO_CHECKPOINT1','PATH_TO_CHECKPOINT2',...] #list of paths to previously saved checkpoints (with name 'converged_model_foldk' after cv process)
cluster_label, cluster_probabilities, _, _, _, _ = clustering_result(model_dirs, ncluster, imaging_data, gene_data, covariate = covariate)
clustering_result clusters patient data using previously saved models. imaging_data and covariate (optional) should be pandas DataFrames with the same format as introduced before. Only the PT participants (inside or outside the training set) for whom the user wants to derive subtype memberships need to be provided, with diagnosis set to 1. gene_data is not required when applying the trained models. The function returns cluster labels of the PT data in the same order as the PT rows in the provided DataFrame.
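When applying trained models, the PT rows can be selected and the returned labels matched back to participant IDs as sketched below. pt_only and attach_labels are hypothetical helpers; the sketch assumes cluster_label is a one-dimensional sequence with one entry per PT row, in the row order described above.

```python
import pandas as pd

def pt_only(df: pd.DataFrame) -> pd.DataFrame:
    """Keep PT rows (diagnosis == 1), as required when applying trained models."""
    return df[df["diagnosis"] == 1].reset_index(drop=True)

def attach_labels(pt_df: pd.DataFrame, cluster_label) -> pd.DataFrame:
    """Pair each PT participant_id with its label; labels follow the row order of pt_df."""
    if len(cluster_label) != len(pt_df):
        raise ValueError("expected one label per PT participant")
    return pd.DataFrame({"participant_id": pt_df["participant_id"].to_numpy(),
                         "cluster": list(cluster_label)})

imaging_data = pd.DataFrame({
    "participant_id": ["subject-1", "subject-2", "subject-4"],
    "diagnosis": [-1, 1, 1],
    "ROI1": [325.4, 260.5, 301.7],
})
pt_imaging = pt_only(imaging_data)
# pt_imaging would be passed to clustering_result; the labels below are
# placeholders standing in for its first return value (cluster_label).
print(attach_labels(pt_imaging, [2, 0]))
```

The same pt_only filter would be applied to covariate so that both DataFrames describe the same PT participants.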
Citation
If you use this package for research, please cite the following paper:
@misc{yang2023genesgan,
doi = {10.48550/ARXIV.2301.10772},
url = {https://arxiv.org/abs/2301.10772},
author = {Yang, Zhijian and Wen, Junhao and Abdulkadir, Ahmed and Cui, Yuhan and Erus, Guray and Mamourian, Elizabeth and Melhem, Randa and Srinivasan, Dhivya and Govindarajan, Sindhuja T. and Chen, Jiong and Habes, Mohamad and Masters, Colin L. and Maruff, Paul and Fripp, Jurgen and Ferrucci, Luigi and Albert, Marilyn S. and Johnson, Sterling C. and Morris, John C. and LaMontagne, Pamela and Marcus, Daniel S. and Benzinger, Tammie L. S. and Wolk, David A. and Shen, Li and Bao, Jingxuan and Resnick, Susan M. and Shou, Haochang and Nasrallah, Ilya M. and Davatzikos, Christos},
title = {Gene-SGAN: a method for discovering disease subtypes with imaging and genetic signatures via multi-view weakly-supervised deep clustering},
publisher = {arXiv},
year = {2023},
copyright = {arXiv.org perpetual, non-exclusive license}
}
Source Distribution
File details
Details for the file GeneSGAN-0.0.2.tar.gz
File metadata
- Download URL: GeneSGAN-0.0.2.tar.gz
- Upload date:
- Size: 18.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.6.1 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 26740c79465f8aef898568a63a5da3ed50951346a44ddf1a5aea3df7ff40346c
MD5 | 6ff2e5a3871a65453731a83d10593589
BLAKE2b-256 | 6641d826f2eccc9a3fce625a6932a40222befef55fe044bffba44d58a08c0074