Extended functionality for univariate probability distributions in PyTorch

survival_distributions

This package extends the functionality of univariate distributions in torch.distributions by implementing several new methods:

  • sf: survival function (complementary CDF)
  • logsf: logarithm of the survival function (negative cumulative hazard function)
  • logcdf: logarithm of the CDF
  • log_hazard: logarithm of the hazard function (logarithm of the failure rate)
  • isf: inverse of the survival function
  • sample_cond: generate samples restricted to the interval between lower_bound and upper_bound instead of the full support of the distribution
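To make the quantities in the list above concrete, here is a minimal sketch for the exponential distribution, where each one has a closed form. It uses plain Python floats rather than tensors, and the function signatures are illustrative assumptions rather than the package's API; only the names sf, logsf, logcdf, log_hazard, isf, sample_cond, lower_bound and upper_bound come from the list.

```python
import math
import random

# Closed-form quantities for an Exponential(rate) distribution.
# Names mirror the list above; signatures and bodies are illustrative only.

def sf(x, rate):
    """Survival function S(x) = 1 - F(x) = exp(-rate * x)."""
    return math.exp(-rate * x)

def logsf(x, rate):
    """log S(x) = -rate * x, i.e. the negative cumulative hazard."""
    return -rate * x

def logcdf(x, rate):
    """log F(x) = log(1 - exp(-rate * x)), using expm1 for accuracy (x > 0)."""
    return math.log(-math.expm1(-rate * x))

def log_hazard(x, rate):
    """log h(x) = log f(x) - log S(x); constant log(rate) for the exponential."""
    return math.log(rate)

def isf(q, rate):
    """Inverse survival function: the x for which S(x) = q."""
    return -math.log(q) / rate

def sample_cond(rate, lower_bound, upper_bound, rng=random):
    """Draw one sample from [lower_bound, upper_bound] by inverting S."""
    s_hi = sf(upper_bound, rate)  # smallest survival probability in the interval
    s_lo = sf(lower_bound, rate)  # largest survival probability in the interval
    u = rng.uniform(s_hi, s_lo)
    return isf(u, rate)
```

The conditional sampler relies on the fact that a uniform draw between S(upper_bound) and S(lower_bound), mapped through isf, lands inside the requested interval. This is one common way to implement truncated sampling and is not necessarily how the package does it.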

These methods are especially useful when working with temporal point processes or survival analysis.
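As an illustration of the survival-analysis use case, the sketch below computes a right-censored log-likelihood for an exponential model. An observed event at time t contributes log f(t) = log_hazard(t) + logsf(t), while a censored observation contributes only logsf(t). The helper names and the toy data are hypothetical, and plain floats stand in for tensors.

```python
import math

def exponential_log_hazard(t, rate):
    # The hazard of an exponential distribution is constant and equal to rate.
    return math.log(rate)

def exponential_logsf(t, rate):
    # log S(t) = -rate * t
    return -rate * t

def censored_log_likelihood(times, observed, rate):
    """Sum log f(t) over observed events and log S(t) over censored ones."""
    total = 0.0
    for t, is_event in zip(times, observed):
        total += exponential_logsf(t, rate)  # the subject survived up to t
        if is_event:
            total += exponential_log_hazard(t, rate)  # the event occurred at t
    return total

# Toy data: two observed events and one right-censored observation.
times = [0.5, 1.2, 2.0]
observed = [True, False, True]
log_likelihood = censored_log_likelihood(times, observed, rate=2.0)
print(log_likelihood)
```

Both terms are available directly in log space, so the likelihood never needs to form 1 - cdf(t) explicitly.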

A naive implementation built on existing PyTorch functionality (e.g., torch.log(1.0 - dist.cdf(x)) for logsf) is often less accurate and less numerically stable than the implementation provided by survival_distributions. Hopefully these methods will be added to PyTorch in the future; until then, this package provides an alternative.
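The loss of accuracy can be reproduced without PyTorch. The sketch below compares the naive formula with the closed form for an exponential distribution using double-precision Python floats. The function names are hypothetical, and the closed form shown is the textbook expression, not necessarily the package's code.

```python
import math

def naive_logsf(x, rate):
    """log(1 - cdf(x)): loses precision once cdf(x) rounds towards 1."""
    cdf = 1.0 - math.exp(-rate * x)
    survival = 1.0 - cdf
    # math.log(0.0) raises ValueError, so return -inf explicitly in that case.
    return math.log(survival) if survival > 0.0 else float("-inf")

def stable_logsf(x, rate):
    """Closed form log S(x) = -rate * x; never forms 1 - cdf(x)."""
    return -rate * x

for x in (1.5, 10.0, 30.0):
    print(x, naive_logsf(x, 2.0), stable_logsf(x, 2.0))
```

At x = 30 with rate 2 the CDF rounds to exactly 1.0, so the naive version returns -inf while the closed form gives -60. In single precision, which PyTorch uses by default, the breakdown occurs at smaller x.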

See DISTRIBUTIONS.md for more details about the implemented functions and supported distributions.

Installation

  1. Install the latest version of PyTorch.
  2. Install survival_distributions:
     pip install git+https://github.com/shchur/survival-distributions.git

Supported distributions

Numerically stable implementation

For these distributions we provide a numerically stable implementation of logsf.

  • Exponential
  • Logistic
  • LogLogistic
  • MixtureSameFamily
  • TransformedDistribution
  • Uniform
  • Weibull

Naive implementation

For these distributions we implement logsf(x) as log(1.0 - dist.cdf(x)), which is less numerically stable because it loses precision when cdf(x) is close to 1.

  • LogNormal
  • Normal

Usage

The package provides a drop-in replacement for torch.distributions, so existing code needs only small changes, as shown below.

Old code

import torch

dist = torch.distributions.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)

log_survival_proba = torch.log(1.0 - dist.cdf(x))

New code

import torch
import survival_distributions as sd

dist = sd.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)

log_survival_proba = dist.logsf(x)
