TensorFlow 2.x Bayesian Neural Network for Survival Analysis
Uncertainty Estimation in Deep Bayesian Survival Models
*UPDATE 11/16/23: A pip package is now available; install it with "pip install bnnsurv". It was tested with TensorFlow 2.13 and TensorFlow Probability 0.21. See the test file for usage examples.
This repository is the official TensorFlow implementation of Uncertainty Estimation in Deep Bayesian Survival Models, BHI 2023.
The proposed method is implemented based on TensorFlow Probability.
The full paper is available on IEEE Xplore: https://ieeexplore.ieee.org/document/10313466
In this work, we introduce Bayesian inference techniques for survival analysis with neural networks that rely on Cox's proportional hazards assumption, and we discuss a new, flexible and effective architecture for this setting. We implement three architectures: a fully deterministic neural network that acts as a baseline, a Bayesian model using variational inference, and a Bayesian model using Monte Carlo dropout.
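Networks built on Cox's proportional hazards assumption output a risk score f(x) per patient and are commonly trained by minimizing the negative log partial likelihood. The following is a minimal, framework-free sketch of that standard loss, assuming Breslow-style risk sets (subject j is at risk for event i when t_j >= t_i) and averaging over observed events. The function name is hypothetical, and this is not the repository's TensorFlow implementation. It costs O(n²) per call, which is fine for illustration but not for large batches.

```python
import math
from typing import Sequence


def cox_neg_log_partial_likelihood(
    risk_scores: Sequence[float],
    times: Sequence[float],
    events: Sequence[int],
) -> float:
    """Average negative log partial likelihood of a Cox model (hypothetical helper).

    risk_scores: network outputs f(x_i), i.e. log-hazard ratios.
    times: observed times (event or censoring).
    events: 1 if the event was observed, 0 if censored.
    """
    n = len(risk_scores)
    if not (n == len(times) == len(events)):
        raise ValueError("inputs must have equal length")
    total = 0.0
    n_events = 0
    for i in range(n):
        if not events[i]:
            continue  # censored subjects only contribute through risk sets
        # Risk set: everyone still under observation at time t_i (Breslow ties).
        risk_set = [risk_scores[j] for j in range(n) if times[j] >= times[i]]
        # Numerically stable log-sum-exp over the risk set.
        m = max(risk_set)
        log_denom = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total += risk_scores[i] - log_denom
        n_events += 1
    if n_events == 0:
        return 0.0  # no observed events: nothing to fit
    return -total / n_events
```

Ranking early events as higher risk lowers the loss, which is what training pushes the network towards.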
Experiments show that the Bayesian models improve predictive performance over state-of-the-art (SOTA) neural networks on a small dataset (WHAS500, 500 samples) and achieve comparable performance on two larger ones (SEER and SUPPORT, with 4024 and 8873 samples, respectively).
License
To view the license for this work, visit https://github.com/thecml/UE-BNNSurv/blob/main/LICENSE
Requirements
To run the models, please refer to Requirements.txt.
Install auton-survival manually from Git:
pip install git+https://github.com/autonlab/auton-survival.git
The code was tested in a virtual environment with Python 3.8, TensorFlow 2.11 and TensorFlow Probability 0.19.
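Because the repository code (Python 3.8, TensorFlow 2.11, TensorFlow Probability 0.19) and the pip package (TensorFlow 2.13, TensorFlow Probability 0.21) were tested against different versions, it can help to print what is installed before running anything. A small sketch using only the standard library's importlib.metadata; the helper name and the TESTED mapping are illustrative, not part of the project.

```python
import sys
from importlib import metadata

# Illustrative mapping of distribution names to the versions quoted above for
# the repository code. Variants such as tensorflow-cpu would not be detected.
TESTED = {"tensorflow": "2.11", "tensorflow_probability": "0.19"}


def report_versions(tested=TESTED):
    """Return {package: (installed_version_or_None, tested_version)}."""
    report = {}
    for pkg, want in tested.items():
        try:
            have = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            have = None  # not installed in this environment
        report[pkg] = (have, want)
    return report


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(tested with 3.8)")
    for pkg, (have, want) in report_versions().items():
        status = "missing" if have is None else have
        print(f"{pkg}: installed {status}, tested {want}")
```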
Training
- Make the output directories: mkdir results and mkdir models.
- Please refer to paths.py to set the appropriate paths. By default, results are saved in results and models in models.
- Network configurations using the best hyperparameters are found in configs/*.
- Run the training code:
# SOTA models
python train_sota_models.py
# BNN Models
python train_bnn_models.py
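The directory step can also be done from Python, which is convenient when scripting runs. This is a minimal sketch assuming the default directory names results and models; the helper name is hypothetical and it does not replace the settings in paths.py.

```python
import tempfile
from pathlib import Path


def make_output_dirs(root="."):
    """Create results/ and models/ under root, like `mkdir results models`.

    exist_ok=True makes the call safe to repeat, unlike a plain mkdir,
    which fails when the directory already exists.
    """
    root = Path(root)
    results_dir = root / "results"
    models_dir = root / "models"
    for d in (results_dir, models_dir):
        d.mkdir(parents=True, exist_ok=True)
    return results_dir, models_dir


if __name__ == "__main__":
    # Demo in a throwaway directory so the example leaves nothing behind.
    with tempfile.TemporaryDirectory() as tmp:
        results, models = make_output_dirs(tmp)
        print(results.is_dir(), models.is_dir())
```

In the project root, make_output_dirs(".") would create the two default folders.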
Evaluation
- After model training, view the results in the results folder.
Visualization
- Run the visualization notebooks:
jupyter notebook plot_survival_curves.ipynb
jupyter notebook plot_survival_time.ipynb
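Under the Cox assumption, a survival curve follows from a patient's risk score f(x) and a baseline cumulative hazard H0(t) via S(t | x) = exp(-H0(t) · exp(f(x))). The sketch below estimates H0 with the standard Breslow estimator and averages curves over several stochastic risk predictions for one patient (for instance, Monte Carlo dropout passes). It is framework-free, all function names are hypothetical, and it is not necessarily how plot_survival_curves.ipynb computes its curves; keeping the baseline fixed across samples is a simplifying assumption.

```python
import math
from typing import List, Sequence


def breslow_baseline_cumhazard(
    risk_scores: Sequence[float],
    times: Sequence[float],
    events: Sequence[int],
    grid: Sequence[float],
) -> List[float]:
    """Breslow estimate of the baseline cumulative hazard H0(t) on `grid`.

    Each observed event at time t_i adds 1 / sum_{j: t_j >= t_i} exp(f(x_j)),
    so tied events share the same risk-set denominator.
    """
    exp_r = [math.exp(r) for r in risk_scores]
    jumps = []
    for t_i, e_i in zip(times, events):
        if e_i:
            denom = sum(w for w, t_j in zip(exp_r, times) if t_j >= t_i)
            jumps.append((t_i, 1.0 / denom))
    return [sum(h for t_i, h in jumps if t_i <= t) for t in grid]


def cox_survival_curve(risk: float, h0: Sequence[float]) -> List[float]:
    """S(t | x) = exp(-H0(t) * exp(f(x))) under proportional hazards."""
    return [math.exp(-h * math.exp(risk)) for h in h0]


def mean_survival_curve(risk_samples: Sequence[float], h0: Sequence[float]) -> List[float]:
    """Average the curves from several stochastic predictions of one patient,
    keeping the baseline hazard fixed (simplifying assumption)."""
    curves = [cox_survival_curve(r, h0) for r in risk_samples]
    return [sum(c[k] for c in curves) / len(curves) for k in range(len(h0))]


if __name__ == "__main__":
    train_risk = [0.0, 0.5, -0.5, 1.0]
    train_time = [2.0, 3.0, 5.0, 1.0]
    train_event = [1, 0, 1, 1]
    grid = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    h0 = breslow_baseline_cumhazard(train_risk, train_time, train_event, grid)
    print(mean_survival_curve([0.2, 0.4, 0.3], h0))
```

The spread of the individual sampled curves around the mean is one simple way to visualize predictive uncertainty.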
Citation
@inproceedings{lillelund_uncertainty_2023,
author={Lillelund, Christian Marius and Magris, Martin and Pedersen, Christian Fischer},
booktitle={2023 IEEE EMBS International Conference on Biomedical and Health Informatics (BHI)},
title={Uncertainty Estimation in Deep Bayesian Survival Models},
year={2023},
pages={1-4},
doi={10.1109/BHI58575.2023.10313466}
}
File details
Details for the file bnnsurv-0.1.3.tar.gz.
File metadata
- Download URL: bnnsurv-0.1.3.tar.gz
- Upload date:
- Size: 10.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 822843d59b73743d9d97865bc2782bac2092dac37ba85f787ec8175db6045a68 |
| MD5 | aeb4767e1eb9059c7351ac53c9f0208f |
| BLAKE2b-256 | 59ee2148113cc2906e3bcd1759056ac59d59bf53e700a669c998f5c06dc6d122 |
File details
Details for the file bnnsurv-0.1.3-py3-none-any.whl.
File metadata
- Download URL: bnnsurv-0.1.3-py3-none-any.whl
- Upload date:
- Size: 9.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5e3186a0af3b657ebddb7ec99f3f0044d7fc97e576aa2e748dd08f0eb1f4ac68 |
| MD5 | ae483b4c7cf7c61a1e9f6819b875d7aa |
| BLAKE2b-256 | bec66ba1edf07f3a1497e7ff452f21d2af0171caa7b7e706480fe173f91abe49 |