
Deep survival analysis made easy with PyTorch


Deep survival analysis made easy


TorchSurv is a Python package that serves as a companion tool to perform deep survival modeling within the PyTorch environment. Unlike existing libraries that impose specific parametric forms on users, TorchSurv enables the use of custom PyTorch-based deep survival models. With its lightweight design, minimal input requirements, full PyTorch backend, and freedom from restrictive survival model parameterizations, TorchSurv facilitates efficient survival model implementation, particularly beneficial for high-dimensional input data scenarios.

If you find this repository useful, please consider giving a star! ⭐

This package was developed by Novartis and the US Food and Drug Administration as part of a research collaboration agreement on radiogenomics.

TL;DR

Our idea is to keep things simple. You are free to use any model architecture you want! Our code has a 100% PyTorch backend, and its functions (losses and metrics) behave like any other PyTorch functions you may be familiar with.

Our functions are designed to support you, not to make you jump through hoops. Here is pseudocode illustrating how easy it is to use TorchSurv to fit and evaluate a Cox proportional hazards model:

from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Pseudo training loop (model, optimizer and dataloader are defined by you)
for data in dataloader:
    x, event, time = data  # event: bool tensor, time: float tensor
    optimizer.zero_grad()
    estimate = model(x)  # shape = torch.Size([64, 1]), if batch size is 64
    loss = cox.neg_partial_log_likelihood(estimate, event, time)
    loss.backward()  # native torch backend
    optimizer.step()

# You can check model performance using our evaluation metrics, e.g., the concordance index with
cindex = ConcordanceIndex()
cindex(estimate, event, time)

# You can obtain the confidence interval of the c-index
cindex.confidence_interval()

# You can test whether the observed c-index differs from 0.5 (random estimator)
cindex.p_value(method="noether", alternative="two_sided")

# You can even compare the metrics between two models (e.g., vs. model B)
cindexB = ConcordanceIndex()
cindexB(estimateB, event, time)  # estimateB: predictions of model B on the same data
cindex.compare(cindexB)
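With `method="noether"`, the p-value above presumably relies on a normal approximation around an estimated standard error of the c-index. The pure-Python sketch below shows the general shape of such a Wald-type interval and test once a standard error is available. It is an illustration, not TorchSurv's implementation: the helper `cindex_wald_summary` is hypothetical, and the standard error is taken as an input rather than derived with Noether's method.

```python
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal: mean 0, standard deviation 1


def cindex_wald_summary(cindex: float, se: float, alpha: float = 0.05) -> dict:
    """Hypothetical helper: Wald-type CI and p-values for a c-index.

    `se` (the standard error of the c-index) is assumed to be supplied by
    the caller; it is not estimated here.
    """
    if se <= 0:
        raise ValueError("se must be positive")
    # z-statistic for the null hypothesis c-index == 0.5 (random estimator)
    z = (cindex - 0.5) / se
    # Critical value for a two-sided (1 - alpha) confidence interval
    z_crit = STD_NORMAL.inv_cdf(1.0 - alpha / 2.0)
    return {
        "ci": (cindex - z_crit * se, cindex + z_crit * se),
        "p_two_sided": 2.0 * (1.0 - STD_NORMAL.cdf(abs(z))),
        "p_greater": 1.0 - STD_NORMAL.cdf(z),
    }


# Example with made-up numbers: c-index 0.62 with standard error 0.05
print(cindex_wald_summary(0.62, se=0.05))
```

When the c-index is above 0.5, the one-sided p-value (`p_greater`) is exactly half of the two-sided one, which is why the choice of `alternative` matters.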

Installation and dependencies

First, install the package using either Conda or PyPI:

  • Using Conda (recommended)
conda install conda-forge::torchsurv
  • Using PyPI
pip install torchsurv
  • For a local installation (latest version)
git clone <repo>
cd <repo>
pip install -e .
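After installing, one way to confirm which version of the package is available in the current environment is the standard-library `importlib.metadata` module. The helper name `installed_version` is illustrative only and is not part of TorchSurv.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str = "torchsurv") -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print(installed_version("torchsurv"))  # prints None if TorchSurv is not installed
```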

To build the documentation (notebooks, Sphinx) or to develop the package (e.g., run the tests), please see the development notes and dev/environment.yml. This step is not required to use TorchSurv in your projects; it is only needed for these optional tasks.

Getting started

We recommend starting with the introductory guide, where you'll find an overview of the package's functionalities.

Survival data

We simulate a random batch of 64 subjects. Each subject is associated with a binary event status (True if the event occurred), a time to event or censoring, and 16 covariates.

>>> import torch
>>> _ = torch.manual_seed(52)
>>> n = 64
>>> x = torch.randn((n, 16))
>>> event = torch.randint(low=0, high=2, size=(n,)).bool()
>>> time = torch.randint(low=1, high=100, size=(n,)).float()
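In the simulated batch above, the event status and the time are drawn independently, which is sufficient for illustrating the API. Real right-censored data arise differently: each subject has a latent event time and a latent censoring time, and only the earlier of the two is observed. The pure-Python sketch below simulates data this way; the exponential event times (mean 50), the uniform censoring times on [1, 100] and the name `simulate_right_censored` are arbitrary choices made for illustration.

```python
import random
from typing import List, Tuple


def simulate_right_censored(n: int, seed: int = 52) -> Tuple[List[bool], List[float]]:
    """Simulate right-censored survival data from latent event/censoring times."""
    rng = random.Random(seed)
    event: List[bool] = []
    time: List[float] = []
    for _ in range(n):
        t_event = rng.expovariate(1 / 50)  # latent event time, mean 50
        t_censor = rng.uniform(1, 100)  # latent censoring time
        # Only the earlier of the two times is observed ...
        time.append(min(t_event, t_censor))
        # ... and the event status records which of the two it was
        event.append(t_event <= t_censor)
    return event, time


event, time = simulate_right_censored(64)
print(sum(event), "events out of", len(event))
```

The resulting lists can be converted with `torch.tensor(event)` (a boolean tensor) and `torch.tensor(time)` (a float tensor) before being passed to TorchSurv's functions.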

Cox proportional hazards model

The user is expected to have defined a model that outputs the estimated log relative hazard for each subject. For illustrative purposes, we define a simple linear model that generates a linear combination of the covariates.

>>> from torch import nn
>>> model_cox = nn.Sequential(nn.Linear(16, 1))
>>> log_hz = model_cox(x)
>>> print(log_hz.shape)
torch.Size([64, 1])

Given the estimated log relative hazard and the survival data, we calculate the current loss for the batch with:

>>> from torchsurv.loss.cox import neg_partial_log_likelihood
>>> loss = neg_partial_log_likelihood(log_hz, event, time)
>>> print(loss)
tensor(4.1723, grad_fn=<DivBackward0>)
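For intuition about the value above, the negative partial log-likelihood can be written out directly: for each subject i who experienced the event, the partial likelihood compares its log relative hazard with those of all subjects still at risk at its event time. The pure-Python sketch below averages the resulting terms over the number of events. It makes simplifying assumptions (Breslow-style risk sets, no special handling of tied event times) and is not TorchSurv's implementation, which may treat ties differently (e.g., with Efron's method) and therefore need not return identical values.

```python
import math
from typing import Sequence


def cox_neg_partial_log_likelihood(
    log_hz: Sequence[float], event: Sequence[bool], time: Sequence[float]
) -> float:
    """Mean negative Cox partial log-likelihood (no tie correction)."""
    n = len(log_hz)
    total, n_events = 0.0, 0
    for i in range(n):
        if not event[i]:
            continue  # censored subjects only enter through the risk sets
        # Risk set: every subject still under observation at time[i]
        risk = [log_hz[j] for j in range(n) if time[j] >= time[i]]
        # log(sum(exp(r))) computed with the log-sum-exp trick for stability
        m = max(risk)
        log_denom = m + math.log(sum(math.exp(r - m) for r in risk))
        total += log_hz[i] - log_denom
        n_events += 1
    if n_events == 0:
        raise ValueError("at least one event is required")
    return -total / n_events


print(cox_neg_partial_log_likelihood([0.3, -0.1, 0.8], [True, False, True], [5.0, 8.0, 12.0]))
```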

We obtain the concordance index for this batch with:

>>> from torchsurv.metrics.cindex import ConcordanceIndex
>>> with torch.no_grad(): log_hz = model_cox(x)
>>> cindex = ConcordanceIndex()
>>> print(cindex(log_hz, event, time))
tensor(0.4872)
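The concordance index measures how well the predicted risks rank subjects. The sketch below computes Harrell's classic estimator in pure Python: a pair (i, j) is comparable when subject i experienced the event before subject j's observed time, and it is concordant when subject i also has the higher predicted log hazard (ties in prediction count as one half). It is an unweighted illustration only; TorchSurv's `ConcordanceIndex` may offer additional options (such as weighting) that are not reproduced here, and `harrell_cindex` is a hypothetical name.

```python
from typing import Sequence


def harrell_cindex(
    risk: Sequence[float], event: Sequence[bool], time: Sequence[float]
) -> float:
    """Harrell's concordance index (unweighted)."""
    concordant, comparable = 0.0, 0
    n = len(risk)
    for i in range(n):
        if not event[i]:
            continue  # the earlier time of a comparable pair must be an event
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # higher risk failed first: concordant
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


print(harrell_cindex([2.0, 1.5, 0.2, 0.9], [True, True, False, True], [3.0, 5.0, 6.0, 9.0]))
```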

We obtain the Area Under the Receiver Operating Characteristic Curve (AUC) at a new time t = 50 for this batch with:

>>> from torchsurv.metrics.auc import Auc
>>> new_time = torch.tensor(50.)
>>> auc = Auc()
>>> print(auc(log_hz, event, time, new_time=new_time))
tensor([0.4737])
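The AUC at time t asks how well the predicted risk separates subjects who experienced the event by t (cases) from those still event-free after t (controls). The pure-Python sketch below implements this cumulative/dynamic definition without any censoring weights, so subjects censored before t are simply dropped. TorchSurv's `Auc` may weight pairs (for example, with inverse probability of censoring weights), so this illustrates the definition rather than reimplementing the library; `cumulative_dynamic_auc` is a hypothetical name.

```python
from typing import Sequence


def cumulative_dynamic_auc(
    risk: Sequence[float], event: Sequence[bool], time: Sequence[float], t: float
) -> float:
    """Unweighted cumulative/dynamic AUC at time t."""
    # Cases: subjects with an observed event at or before t
    cases = [i for i in range(len(time)) if event[i] and time[i] <= t]
    # Controls: subjects still event-free (and under observation) after t
    controls = [j for j in range(len(time)) if time[j] > t]
    if not cases or not controls:
        raise ValueError("need at least one case and one control at time t")
    score = 0.0
    for i in cases:
        for j in controls:
            if risk[i] > risk[j]:
                score += 1.0  # case ranked above control
            elif risk[i] == risk[j]:
                score += 0.5  # tie in predicted risk
    return score / (len(cases) * len(controls))


print(cumulative_dynamic_auc([1.2, 0.4, 0.9, -0.3], [True, True, False, True], [20.0, 45.0, 60.0, 80.0], t=50.0))
```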

Weibull accelerated failure time (AFT) model

The user is expected to have defined a model that outputs, for each subject, the estimated log scale and, optionally, the log shape of the Weibull distribution that the event density follows. If the model has a single output, TorchSurv assumes that the shape is equal to 1, in which case the event density reduces to an exponential distribution parametrized solely by the scale.
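Under the usual Weibull parametrization with scale λ and shape ρ, the survival function is S(t) = exp(-(t/λ)^ρ) and the hazard is h(t) = (ρ/λ)(t/λ)^(ρ-1); with ρ = 1 the hazard is the constant 1/λ, which is the exponential case. The pure-Python sketch below evaluates both quantities from the log scale and log shape. We assume, without having checked TorchSurv's source, that it uses this same parametrization; the function names are hypothetical.

```python
import math


def weibull_log_hazard(log_scale: float, log_shape: float, t: float) -> float:
    """log h(t) for h(t) = (shape / scale) * (t / scale) ** (shape - 1)."""
    shape = math.exp(log_shape)
    return log_shape - log_scale + (shape - 1.0) * (math.log(t) - log_scale)


def weibull_survival(log_scale: float, log_shape: float, t: float) -> float:
    """S(t) = exp(-(t / scale) ** shape)."""
    scale, shape = math.exp(log_scale), math.exp(log_shape)
    return math.exp(-((t / scale) ** shape))


# With log_shape = 0 (shape = 1) the hazard is constant and equal to 1 / scale
print(math.exp(weibull_log_hazard(math.log(20.0), 0.0, 7.0)))  # ≈ 0.05
```

The two functions are consistent with each other: the hazard equals minus the derivative of log S(t) with respect to t.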

For illustrative purposes, we define a simple linear model that estimates two linear combinations of the covariates (the log scale and log shape parameters).

>>> from torch import nn
>>> model_weibull = nn.Sequential(nn.Linear(16, 2))
>>> log_params = model_weibull(x)
>>> print(log_params.shape)
torch.Size([64, 2])

Given the estimated log scale and log shape and the survival data, we calculate the current loss for the batch with:

>>> from torchsurv.loss.weibull import neg_log_likelihood
>>> loss = neg_log_likelihood(log_params, event, time)
>>> print(loss)
tensor(82931.5078, grad_fn=<DivBackward0>)
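For right-censored data, each subject contributes log S(t_i) to the log-likelihood, plus log h(t_i) if the event was observed. The pure-Python sketch below writes this out, assuming a Weibull model with scale exp(log_scale) and shape exp(log_shape) and averaging over subjects. Both the parametrization and the averaging are assumptions about TorchSurv, so its values need not match this sketch; `weibull_neg_log_likelihood` is a hypothetical name.

```python
import math
from typing import Sequence, Tuple


def weibull_neg_log_likelihood(
    log_params: Sequence[Tuple[float, float]],
    event: Sequence[bool],
    time: Sequence[float],
) -> float:
    """Mean negative log-likelihood of right-censored data under a Weibull model."""
    total = 0.0
    for (log_scale, log_shape), d, t in zip(log_params, event, time):
        shape = math.exp(log_shape)
        # log h(t) = log(shape) - log(scale) + (shape - 1) * (log t - log scale)
        log_h = log_shape - log_scale + (shape - 1.0) * (math.log(t) - log_scale)
        # log S(t) = -(t / scale) ** shape
        log_s = -((t / math.exp(log_scale)) ** shape)
        # Events contribute log h + log S; censored subjects contribute log S only
        total += (log_h if d else 0.0) + log_s
    return -total / len(time)


print(weibull_neg_log_likelihood([(0.0, 0.0), (0.5, 0.2)], [True, False], [2.0, 3.0]))
```

One plausible reason for the large loss printed above is that the term (t/scale)^shape grows quickly when the initial scale is small relative to the observed times (here up to 99).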

To evaluate the predictive performance of the model, we calculate the subject-specific log hazard and survival function, evaluated at all observed times, with:

>>> from torchsurv.loss.weibull import log_hazard
>>> from torchsurv.loss.weibull import survival_function
>>> with torch.no_grad(): log_params = model_weibull(x)
>>> log_hz = log_hazard(log_params, time)
>>> print(log_hz.shape)
torch.Size([64, 64])
>>> surv = survival_function(log_params, time)
>>> print(surv.shape)
torch.Size([64, 64])

We obtain the concordance index for this batch with:

>>> from torchsurv.metrics.cindex import ConcordanceIndex
>>> cindex = ConcordanceIndex()
>>> print(cindex(log_hz, event, time))
tensor(0.4062)

We obtain the AUC at a new time t = 50 for this batch with:

>>> from torchsurv.metrics.auc import Auc
>>> new_time = torch.tensor(50.)
>>> log_hz_t = log_hazard(log_params, time=new_time)
>>> auc = Auc()
>>> print(auc(log_hz_t, event, time, new_time=new_time))
tensor([0.3509])

We obtain the integrated Brier score with:

>>> from torchsurv.metrics.brier_score import BrierScore
>>> brier_score = BrierScore()
>>> bs = brier_score(surv, event, time)
>>> print(brier_score.integral())
tensor(0.4447)
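The Brier score at time t is the mean squared difference between the observed status at t (1 if still event-free, 0 if the event occurred) and the predicted survival probability S(t); integrating it over time gives a single summary. The sketch below is a simplified pure-Python version: it omits the inverse-probability-of-censoring weights used by the usual estimator (so subjects censored before t contribute nothing) and normalizes the trapezoidal integral by the length of the time range. TorchSurv's `BrierScore` may weight and normalize differently; the function names are hypothetical.

```python
from typing import Sequence


def brier_score_at(
    surv_t: Sequence[float], event: Sequence[bool], time: Sequence[float], t: float
) -> float:
    """Brier score at time t with all censoring weights set to 1."""
    total = 0.0
    for s, d, ti in zip(surv_t, event, time):
        if ti > t:
            total += (1.0 - s) ** 2  # still event-free at t: target is 1
        elif d:
            total += (0.0 - s) ** 2  # event occurred by t: target is 0
        # subjects censored before t contribute 0 in this unweighted sketch
    return total / len(time)


def integrated_brier_score(times: Sequence[float], scores: Sequence[float]) -> float:
    """Trapezoidal integral of Brier scores, divided by the time range."""
    if len(times) < 2:
        raise ValueError("need at least two time points")
    area = 0.0
    for k in range(1, len(times)):
        area += 0.5 * (scores[k] + scores[k - 1]) * (times[k] - times[k - 1])
    return area / (times[-1] - times[0])


print(brier_score_at([0.2, 0.9, 0.6], [True, False, True], [10.0, 30.0, 40.0], t=20.0))
```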

Related Packages

The table below compares the functionalities of TorchSurv with those of auton-survival, pycox, torchlife, scikit-survival, lifelines, and deepsurv. While several libraries offer survival modelling functionalities, no existing library provides the flexibility to use a custom PyTorch-based neural network to define the survival model parameters.

The outputs of the log-likelihood functions and of the evaluation metric functions have been thoroughly compared against benchmarks generated with Python and R packages. The comparisons (at the time of publication) are summarised in the Related packages summary.

[Table: Survival analysis libraries in Python]

Survival analysis libraries in R. To compute evaluation metrics, the packages survival, riskRegression, SurvMetrics and pec require the fitted model object as input (in a package-specific format), and RisksetROC imposes a smoothing method. The packages timeROC, riskRegression and pec force the user to choose a form for subject-specific weights (e.g., inverse probability of censoring weighting (IPCW)). The packages survcomp and SurvivalROC do not implement the general AUC but rather the censoring-adjusted AUC estimator proposed by Heagerty et al. (2000).

[Table: Survival analysis libraries in R]

Contributing

We value contributions from the community to enhance and improve this project. If you'd like to contribute, please consider the following:

  1. Create Issues: If you encounter bugs, have feature requests, or want to suggest improvements, please create an issue in the GitHub repository. Provide detailed information about the problem (including code to reproduce it) or about the enhancement you are proposing.

  2. Fork and Pull Requests: If you're willing to address an existing issue or contribute a new feature, fork the repository, create a new branch, make your changes, and then submit a pull request. Please ensure your code follows our coding conventions and includes tests for any new functionality.

By contributing to this project, you agree to license your contributions under the same license as this project.

Contacts

If you have any questions, suggestions, or feedback, feel free to reach out to the development team.

Cite

If you use this project in academic work or publications, please cite it using the following BibTeX entry:

@article{Monod2024,
    doi = {10.21105/joss.07341},
    url = {https://doi.org/10.21105/joss.07341},
    year = {2024},
    publisher = {The Open Journal},
    volume = {9},
    number = {104},
    pages = {7341},
    author = {Mélodie Monod and Peter Krusche and Qian Cao and Berkman Sahiner and Nicholas Petrick and David Ohlssen and Thibaud Coroller},
    title = {TorchSurv: A Lightweight Package for Deep Survival Analysis},
    journal = {Journal of Open Source Software}
}

