
TensorFlow 2.x Bayesian Neural Network for Survival Analysis

Project description

Uncertainty Estimation in Deep Bayesian Survival Models

UPDATE 11/16/23: A pip package is now available; install it with "pip install bnnsurv". It was tested with TensorFlow 2.13 and TensorFlow Probability 0.21. See the test file for usage examples.

This repository is the official TensorFlow implementation of "Uncertainty Estimation in Deep Bayesian Survival Models" (BHI 2023).

The proposed method is implemented with TensorFlow Probability.
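As background on the variational part: TensorFlow Probability's variational layers (for example tfp.layers.DenseReparameterization) place a distribution over each weight and train it by sampling. The pure-Python sketch below shows the two core ingredients for a single mean-field Gaussian weight: a reparameterized weight sample and the closed-form KL penalty against the prior. The softplus-parameterized scale and the standard normal prior are assumptions for illustration, not details taken from this repository's code.

```python
import math
import random


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x)); keeps sigma positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_weight(mu: float, rho: float, rng: random.Random) -> float:
    """Reparameterization trick: w = mu + sigma * eps with eps ~ N(0, 1).

    Assumption: sigma = softplus(rho), so rho is unconstrained during training.
    """
    sigma = softplus(rho)
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps


def kl_gaussian_to_std_normal(mu: float, sigma: float) -> float:
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for one weight.

    Assumption: a standard normal prior; summed over all weights, this is the
    regularization term added to the data loss in variational inference.
    """
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0) - math.log(sigma)


if __name__ == "__main__":
    rng = random.Random(0)
    mu, rho = 0.3, -2.0
    print([sample_weight(mu, rho, rng) for _ in range(5)])
    print(kl_gaussian_to_std_normal(mu, softplus(rho)))
```

Because the noise enters only through eps, gradients flow to mu and rho, which is what makes sampling-based training of the weight distribution possible.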

The full paper is available on IEEE Xplore: https://ieeexplore.ieee.org/document/10313466

In this work, we introduce Bayesian inference techniques for survival analysis with neural networks that rely on the Cox proportional hazards assumption, and we discuss a new, flexible and effective architecture for this setting. We implement three architectures: a fully deterministic neural network that serves as a baseline, a Bayesian model using variational inference, and a Bayesian model using Monte Carlo dropout.
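Networks built on the Cox proportional hazards assumption are typically trained by minimizing the negative Cox log partial likelihood of their risk scores. The sketch below computes that loss in pure Python for illustration; it assumes the network outputs one log-hazard-ratio per subject and handles tied times with the Breslow approximation. It is not the repository's own loss function, whose exact form (tie handling, reduction) is not stated here.

```python
import math
from typing import Sequence


def cox_neg_log_partial_likelihood(
    risk: Sequence[float], time: Sequence[float], event: Sequence[int]
) -> float:
    """Negative Cox log partial likelihood, averaged over observed events.

    risk[i]  -- network output (log hazard ratio) for subject i
    time[i]  -- observed time (event or censoring)
    event[i] -- 1 if the event was observed, 0 if censored
    Ties use the Breslow approximation: the risk set of i is {j : time[j] >= time[i]}.
    """
    n_events = sum(event)
    if n_events == 0:
        raise ValueError("need at least one observed event")
    total = 0.0
    for i in range(len(risk)):
        if not event[i]:
            continue  # censored subjects only appear inside risk sets
        at_risk = [risk[j] for j in range(len(risk)) if time[j] >= time[i]]
        # log-sum-exp over the risk set, shifted by the max for numerical stability
        m = max(at_risk)
        log_denominator = m + math.log(sum(math.exp(r - m) for r in at_risk))
        total += risk[i] - log_denominator
    return -total / n_events
```

For example, with all risk scores equal to zero and three uncensored subjects at times 1, 2 and 3, the loss equals log(6)/3. Assigning higher risk to subjects who fail earlier lowers the loss, which is the ranking behavior the Cox model rewards.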

Experiments show that the Bayesian models improve predictive performance over state-of-the-art (SOTA) neural networks on a test dataset with few samples (WHAS500, 500 samples) and perform comparably on two larger ones (SEER and SUPPORT, with 4024 and 8873 samples, respectively).
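The Bayesian models yield an uncertainty estimate alongside each prediction by repeating stochastic forward passes. For the Monte Carlo dropout variant, dropout stays active at prediction time, and the mean and spread of the sampled risk scores summarize the prediction. The sketch below illustrates this on a toy linear risk model in pure Python; the model, the dropout rate and the number of samples are hypothetical choices, not the repository's settings.

```python
import random
import statistics
from typing import List, Sequence, Tuple


def dropout(values: Sequence[float], rate: float, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each unit with probability `rate`, rescale survivors."""
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]


def mc_predict(
    features: Sequence[float],
    weights: Sequence[float],
    rate: float = 0.25,
    n_samples: int = 100,
    seed: int = 0,
) -> Tuple[float, float]:
    """Monte Carlo dropout for a hypothetical linear risk model.

    Runs `n_samples` stochastic passes with dropout on the inputs and returns
    the mean risk score and its standard deviation (the uncertainty estimate).
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        dropped = dropout(features, rate, rng)
        samples.append(sum(w * x for w, x in zip(weights, dropped)))
    return statistics.mean(samples), statistics.stdev(samples)
```

With rate=0.0 every pass is identical, so the standard deviation is zero; a nonzero rate spreads the samples, and a wider spread signals a less certain prediction.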

License

To view the license for this work, visit https://github.com/thecml/UE-BNNSurv/blob/main/LICENSE

Requirements

To run the models, please refer to Requirements.txt.

Install auton-survival manually from Git:

pip install git+https://github.com/autonlab/auton-survival.git

The code was tested in a virtual environment with Python 3.8, TensorFlow 2.11 and TensorFlow Probability 0.19.
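Since the README lists two tested combinations (TensorFlow 2.11 with TensorFlow Probability 0.19, and TensorFlow 2.13 with TensorFlow Probability 0.21 for the pip package), it can help to check what is installed before running anything. The sketch below uses the standard-library importlib.metadata module (Python 3.8+); the helper name and the TESTED table are illustrative, not part of this project.

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional

# Versions this README reports as tested for the repository code.
TESTED = {"tensorflow": "2.11", "tensorflow-probability": "0.19"}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if missing."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, version in installed_versions(TESTED).items():
        print(f"{name}: installed={version}, tested={TESTED[name]}")
```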

Training

  • Create the output directories: mkdir results and mkdir models.

  • Set the appropriate paths in paths.py. By default, results are saved in the results directory and models in the models directory.

  • Network configurations with the best hyperparameters are found in configs/*.

  • Run the training code:

# SOTA models
python train_sota_models.py

# BNN Models
python train_bnn_models.py
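The directory-creation step from the training instructions can also be scripted in Python, for example when launching runs from a notebook. The sketch below is a hypothetical helper equivalent to mkdir results and mkdir models under a chosen root; it does not read or reproduce paths.py, so if you changed the paths there, pass the matching root instead.

```python
from pathlib import Path
from typing import List


def make_output_dirs(root: Path = Path(".")) -> List[Path]:
    """Create results/ and models/ under `root` (hypothetical helper).

    Equivalent to `mkdir results` and `mkdir models`, but does not fail if
    the directories already exist.
    """
    created = []
    for name in ("results", "models"):
        path = root / name
        path.mkdir(parents=True, exist_ok=True)
        created.append(path)
    return created
```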

Evaluation

  • After model training, view the results in the results folder.

Visualization

  • Run the visualization notebooks:
jupyter notebook plot_survival_curves.ipynb
jupyter notebook plot_survival_time.ipynb
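Under the Cox proportional hazards assumption, an individual survival curve follows from a baseline cumulative hazard and the network's risk score: S(t | x) = exp(-H0(t) * exp(f(x))). The sketch below estimates H0(t) with the Breslow estimator and turns a risk score into a step survival curve. It is an illustration only; whether the notebooks use the Breslow estimator or another method is an assumption not confirmed here.

```python
import math
from typing import List, Sequence, Tuple


def breslow_baseline(
    risk: Sequence[float], time: Sequence[float], event: Sequence[int]
) -> List[Tuple[float, float]]:
    """Breslow estimate of the cumulative baseline hazard H0(t) at each event time.

    At each distinct event time t: H0 += (#events at t) / sum_{j: time_j >= t} exp(risk_j).
    """
    event_times = sorted({t for t, e in zip(time, event) if e})
    cumulative, steps = 0.0, []
    for t in event_times:
        deaths = sum(1 for ti, e in zip(time, event) if e and ti == t)
        at_risk = sum(math.exp(r) for r, ti in zip(risk, time) if ti >= t)
        cumulative += deaths / at_risk
        steps.append((t, cumulative))
    return steps


def survival_curve(
    baseline: Sequence[Tuple[float, float]], risk_score: float
) -> List[Tuple[float, float]]:
    """Individual survival S(t | x) = exp(-H0(t) * exp(risk_score)) at each step."""
    scale = math.exp(risk_score)
    return [(t, math.exp(-h0 * scale)) for t, h0 in baseline]
```

A higher risk score scales the baseline hazard up, so its survival curve lies below that of a lower-risk subject at every time point, which is what the proportional hazards assumption implies.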

Citation

@inproceedings{lillelund_uncertainty_2023,
  author={Lillelund, Christian Marius and Magris, Martin and Pedersen, Christian Fischer},
  booktitle={2023 IEEE EMBS International Conference on Biomedical and Health Informatics (BHI)}, 
  title={Uncertainty Estimation in Deep Bayesian Survival Models}, 
  year={2023},
  pages={1-4},
  doi={10.1109/BHI58575.2023.10313466}
}

