TxGNN
TxGNN: Repurposing therapeutics for neglected diseases using geometric deep learning
This repository hosts the official implementation of TxGNN, a method that predicts drug efficacy for diseases with limited molecular underpinnings and few treatments by applying geometric learning to a multi-scale disease knowledge graph.
Installation
Create a virtual environment using virtualenv or conda, then run pip install TxGNN.
Core API Interface
Using the API, you can (1) reproduce the results in our paper, (2) train TxGNN on your own drug repurposing dataset in a few lines of code, and (3) generate graph explanations.
from TxGNN import TxData, TxGNN, TxEval
# Download/load knowledge graph dataset
TxData = TxData(data_folder_path = './data')
TxData.prepare_split(split = 'complex_disease', seed = 42)
TxGNN = TxGNN(data = TxData,
              weight_bias_track = False,
              proj_name = 'TxGNN',
              exp_name = 'TxGNN'
              )
# Initialize a new model
TxGNN.model_initialize(n_hid = 100,
                       n_inp = 100,
                       n_out = 100,
                       proto = True,
                       proto_num = 3,
                       attention = False,
                       sim_measure = 'all_nodes_profile',
                       bert_measure = 'disease_name',
                       agg_measure = 'rarity',
                       num_walks = 200,
                       walk_mode = 'bit',
                       path_length = 2)
Instead of initializing a new model, you can also load a saved model:
TxGNN.load_pretrained('./model_ckpt')
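A common pattern is to reuse a checkpoint when one exists and fall back to a fresh model otherwise. The sketch below is a minimal illustration and not part of the TxGNN API: load_or_initialize is a hypothetical helper, and it assumes that the path given to save_model exists on disk after saving. It only calls load_pretrained and model_initialize, the two methods shown above.

```python
from pathlib import Path


def load_or_initialize(model, ckpt_path, **init_kwargs):
    """Hypothetical helper (not part of TxGNN).

    `model` is expected to expose `load_pretrained(path)` and
    `model_initialize(**kwargs)`, as the TxGNN wrapper above does.
    Returns a short string saying which branch was taken.
    """
    # Assumption: a saved checkpoint shows up as an existing path on disk.
    if Path(ckpt_path).exists():
        model.load_pretrained(ckpt_path)
        return "loaded"
    # No checkpoint found: build a new model from the given hyperparameters.
    model.model_initialize(**init_kwargs)
    return "initialized"
```

For example, load_or_initialize(TxGNN, './model_ckpt', n_hid = 100, n_inp = 100, n_out = 100) would load the checkpoint if the path exists and initialize a new model with those sizes otherwise.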
To do pre-training using link prediction for all edge types, you can type:
TxGNN.pretrain(n_epoch = 2,
               learning_rate = 1e-3,
               batch_size = 1024,
               train_print_per_n = 20)
Then, to fine-tune on drug-disease relations with metric learning, you can type:
# finetune_result_path must be defined beforehand (the path where the fine-tuning result is saved)
TxGNN.finetune(n_epoch = 500,
               learning_rate = 5e-4,
               train_print_per_n = 5,
               valid_per_n = 20,
               save_name = finetune_result_path)
To save the trained model, you can type:
TxGNN.save_model('./model_ckpt')
To evaluate the model on the entire test set using disease-centric evaluation, you can type:
result = TxEval.eval_disease_centric(disease_idxs = 'test_set',
                                     show_plot = False,
                                     verbose = True,
                                     save_result = True,
                                     return_raw = False,
                                     save_name = 'SAVE_PATH')
If you want to look at specific diseases, you can also pass their indices:
result = TxEval.eval_disease_centric(disease_idxs = [9907.0, 12787.0],
                                     relation = 'indication',
                                     save_result = False)
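One way to read "disease-centric" is: fix a single disease, score every candidate drug against it, and check how highly the known treatments rank. The sketch below illustrates that idea with recall at k. It is an illustration under stated assumptions, not how TxEval is implemented: recall_at_k is a hypothetical function, and the scores and labels are made up for the example.

```python
import math


def recall_at_k(scores, is_treatment, k):
    """Fraction of one disease's known treatments ranked among the top k drugs.

    scores[i] is a (made-up) predicted score for drug i against the disease;
    is_treatment[i] is True when drug i is a known treatment.
    """
    n_pos = sum(is_treatment)
    if n_pos == 0:
        # Recall is undefined when the disease has no known treatment.
        return math.nan
    # Drug indices sorted from highest to lowest predicted score.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = sum(1 for i in ranked[:k] if is_treatment[i])
    return hits / n_pos


# Toy example: five candidate drugs scored against one disease.
scores = [0.9, 0.1, 0.8, 0.3, 0.7]
is_treatment = [True, False, False, True, False]
print(recall_at_k(scores, is_treatment, k=2))  # 0.5
```

Here one of the two known treatments lands in the top two, so recall at 2 is 0.5. Averaging such per-disease values over a set of diseases gives a summary that weights each disease equally.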
After training a satisfactory link prediction model, you can also train a graph XAI model:
TxGNN.train_graphmask(relation = 'indication',
                      learning_rate = 3e-4,
                      allowance = 0.005,
                      epochs_per_layer = 3,
                      penalty_scaling = 1,
                      valid_per_n = 20)
You can retrieve the graph XAI gates (which indicate whether or not an edge is important) and save them to a pkl file at SAVED_PATH/'graphmask_output_RELATION.pkl':
gates = TxGNN.retrieve_save_gates('SAVED_PATH')
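The saved file can be read back later with the standard pickle module. The sketch below is a minimal example, assuming the file name follows the graphmask_output_RELATION.pkl pattern described above with RELATION replaced by the relation name. load_gates is a hypothetical helper, not part of TxGNN, and the structure of the stored object is not documented here, so inspect its type after loading.

```python
import pickle
from pathlib import Path


def load_gates(saved_path, relation):
    """Hypothetical helper: read back the gates pickle for one relation.

    Assumes the file is named 'graphmask_output_<relation>.pkl' inside
    `saved_path`, following the pattern described in the text.
    """
    gate_file = Path(saved_path) / f"graphmask_output_{relation}.pkl"
    with open(gate_file, "rb") as f:
        # Only unpickle files you created yourself; pickle can run arbitrary code.
        return pickle.load(f)
```

For example, gates = load_gates('SAVED_PATH', 'indication') followed by print(type(gates)) shows what kind of object was stored.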
You can also save and load the graphmask model:
TxGNN.save_graphmask_model('./graphmask_model_ckpt')
TxGNN.load_pretrained_graphmask('./graphmask_model_ckpt')
Cite Us