
Extended functionality for univariate probability distributions in PyTorch


survival_distributions

This package extends the functionality of univariate distributions in torch.distributions by implementing several new methods:

  • sf: survival function (complementary CDF)
  • logsf: logarithm of the survival function (negative cumulative hazard function)
  • logcdf: logarithm of the CDF
  • log_hazard: logarithm of the hazard function (logarithm of the failure rate)
  • isf: inverse of the survival function
  • sample_cond: generate samples restricted to the range between lower_bound and upper_bound, instead of sampling from the full support of the distribution
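These methods rest on a few standard identities, which are easiest to see for a distribution with closed forms. The class below is a minimal, hypothetical sketch for an Exponential(rate) distribution. It uses plain Python floats instead of tensors and is not the package's code. sample_cond is shown with inverse-transform sampling through isf, which is one possible technique, and its rng argument is an assumption added so the sketch is reproducible.

```python
import math
import random


class ExponentialSketch:
    """Hypothetical float-based Exponential(rate) illustrating the extra methods (x >= 0)."""

    def __init__(self, rate: float):
        if rate <= 0.0:
            raise ValueError("rate must be positive")
        self.rate = rate

    def log_prob(self, x: float) -> float:
        # log f(x) = log(rate) - rate * x
        return math.log(self.rate) - self.rate * x

    def sf(self, x: float) -> float:
        # S(x) = P(X > x) = 1 - F(x) = exp(-rate * x)
        return math.exp(-self.rate * x)

    def logsf(self, x: float) -> float:
        # log S(x) = -H(x) = -rate * x, computed without an exp/log round trip
        return -self.rate * x

    def logcdf(self, x: float) -> float:
        # log F(x) = log(1 - exp(-rate * x)); expm1 avoids cancellation for small x
        if x <= 0.0:
            return float("-inf")
        return math.log(-math.expm1(-self.rate * x))

    def log_hazard(self, x: float) -> float:
        # h(x) = f(x) / S(x), hence log h(x) = log f(x) - log S(x)
        return self.log_prob(x) - self.logsf(x)

    def isf(self, q: float) -> float:
        # Solve S(x) = q for x, with 0 < q <= 1
        return -math.log(q) / self.rate

    def sample_cond(self, lower_bound: float, upper_bound: float, rng: random.Random) -> float:
        # Draw q in (S(upper_bound), S(lower_bound)] and map it back through isf,
        # so the sample lies between lower_bound and upper_bound.
        s_low = self.sf(lower_bound)
        s_high = self.sf(upper_bound)
        q = s_low - (s_low - s_high) * rng.random()
        return self.isf(q)


dist = ExponentialSketch(rate=2.0)
rng = random.Random(0)
print(dist.logsf(1.5), dist.log_hazard(1.5), dist.sample_cond(1.0, 2.0, rng))
```

For the Exponential the log hazard is the constant log(rate), which is a quick sanity check on the log_hazard identity.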

This is especially useful when working with temporal point processes or survival analysis.

A naive implementation built from existing PyTorch functionality (e.g., torch.log(1.0 - dist.cdf(x)) for logsf) is often less accurate and less numerically stable than the implementations provided by survival_distributions. Hopefully these methods will be added to PyTorch in the future; until then, this package provides an alternative.
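The loss of accuracy is easy to reproduce. The sketch below uses Python's math module in double precision as a stand-in for torch, for an Exponential distribution with rate 2 (an assumed example). Once rate * x is large, the CDF rounds to exactly 1.0, so the naive formula returns -inf, while the closed-form log survival function stays finite. With float32, PyTorch's default dtype, this happens at much smaller x.

```python
import math


def exponential_cdf(x: float, rate: float) -> float:
    # F(x) = 1 - exp(-rate * x)
    return 1.0 - math.exp(-rate * x)


def naive_logsf(x: float, rate: float) -> float:
    # Mirrors torch.log(1.0 - dist.cdf(x)); returns -inf where torch.log(0) would
    s = 1.0 - exponential_cdf(x, rate)
    return math.log(s) if s > 0.0 else float("-inf")


def stable_logsf(x: float, rate: float) -> float:
    # log S(x) = -rate * x, computed directly in log space
    return -rate * x


for x in (1.5, 10.0, 20.0):
    print(x, naive_logsf(x, 2.0), stable_logsf(x, 2.0))
```

At x = 20, exp(-40) is about 4e-18, far below double-precision resolution near 1.0, so 1.0 - cdf(x) is exactly zero.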

See DISTRIBUTIONS.md for more details about the implemented functions and supported distributions.

Installation

  1. Install the latest version of PyTorch.
  2. Install survival_distributions
    pip install survival_distributions
    

Supported distributions

Numerically stable implementation

For these distributions we provide a numerically stable implementation of logsf.

  • Exponential
  • Logistic
  • LogLogistic
  • MixtureSameFamily
  • TransformedDistribution
  • Uniform
  • Weibull
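To illustrate what a numerically stable logsf can look like, here is a sketch of closed-form log survival functions for some of the listed families, written with plain Python floats. The parameter names (loc, scale, concentration, low, high) follow common conventions and are assumptions, not necessarily the package's signatures. LogLogistic and TransformedDistribution are left out. The mixture case stays in log space via log-sum-exp, so it never exponentiates the very small component survival probabilities.

```python
import math
from typing import Sequence


def softplus(z: float) -> float:
    # log(1 + exp(z)) without overflow for large z
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def logistic_logsf(x: float, loc: float, scale: float) -> float:
    # S(x) = 1 / (1 + exp(z)) with z = (x - loc) / scale, so log S(x) = -softplus(z)
    return -softplus((x - loc) / scale)


def weibull_logsf(x: float, scale: float, concentration: float) -> float:
    # S(x) = exp(-(x / scale) ** concentration) for x >= 0
    return -((x / scale) ** concentration)


def uniform_logsf(x: float, low: float, high: float) -> float:
    # S(x) = (high - x) / (high - low) for low <= x < high
    if x >= high:
        return float("-inf")
    if x <= low:
        return 0.0
    return math.log(high - x) - math.log(high - low)


def logsumexp(values: Sequence[float]) -> float:
    # log(sum(exp(v))) shifted by the maximum to avoid underflow and overflow
    m = max(values)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def mixture_logsf(log_weights: Sequence[float], component_logsfs: Sequence[float]) -> float:
    # S(x) = sum_k w_k * S_k(x), evaluated entirely in log space
    return logsumexp([lw + ls for lw, ls in zip(log_weights, component_logsfs)])


print(logistic_logsf(50.0, 0.0, 1.0))
```

A naive log(1 - sigmoid(50)) would give -inf, since sigmoid(50) rounds to 1.0 in double precision.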

Naive implementation

For these distributions we implement logsf(x) as log(1.0 - dist.cdf(x)), which is less numerically stable.

  • LogNormal
  • Normal
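For the Normal, the weakness of the naive route shows up well within the tails. In the sketch below (plain Python, standard Normal assumed), the erf-based CDF at x = 9 rounds to exactly 1.0 in double precision, so log(1.0 - cdf) is -inf. Writing the survival function with the complementary error function keeps the value, about 1e-19, representable. This only illustrates why the naive formula is less stable; it is not a description of what the package does.

```python
import math


def normal_cdf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # F(x) = 0.5 * (1 + erf(z / sqrt(2))) with z = (x - loc) / scale
    z = (x - loc) / scale
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def naive_normal_logsf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # log(1 - F(x)); the subtraction cancels to 0.0 once F(x) rounds to 1.0
    s = 1.0 - normal_cdf(x, loc, scale)
    return math.log(s) if s > 0.0 else float("-inf")


def erfc_normal_logsf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # S(x) = 0.5 * erfc(z / sqrt(2)) avoids the subtraction; erfc itself
    # still underflows to 0.0 for z beyond roughly 38
    z = (x - loc) / scale
    s = 0.5 * math.erfc(z / math.sqrt(2.0))
    return math.log(s) if s > 0.0 else float("-inf")


print(naive_normal_logsf(9.0), erfc_normal_logsf(9.0))
```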

Usage

The package provides a drop-in replacement for torch.distributions, so you can just modify your code as follows.

Old code

import torch

dist = torch.distributions.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)

log_survival_proba = torch.log(1.0 - dist.cdf(x))

New code

import torch
import survival_distributions as sd

dist = sd.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)

log_survival_proba = dist.logsf(x)
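As an illustration of where logsf fits in survival analysis, here is a sketch of the right-censored log-likelihood: an observed event at time t contributes log_prob(t), and an observation censored at t contributes logsf(t). The function only assumes an object with log_prob and logsf methods, such as the dist above (with a survival_distributions object you would pass tensors rather than floats). ExponentialStandIn is a hypothetical pure-Python class so the sketch runs without PyTorch.

```python
import math
from typing import Sequence


def censored_log_likelihood(dist, times: Sequence[float], observed: Sequence[bool]) -> float:
    # Observed event at t: density f(t). Right-censored at t: survival S(t) = P(T > t).
    total = 0.0
    for t, is_event in zip(times, observed):
        total = total + (dist.log_prob(t) if is_event else dist.logsf(t))
    return total


class ExponentialStandIn:
    # Hypothetical stand-in exposing the same two methods as the package's distributions
    def __init__(self, rate: float):
        self.rate = rate

    def log_prob(self, t: float) -> float:
        return math.log(self.rate) - self.rate * t

    def logsf(self, t: float) -> float:
        return -self.rate * t


dist = ExponentialStandIn(rate=2.0)
print(censored_log_likelihood(dist, times=[0.5, 1.0], observed=[True, False]))
```

With tensors, the loop could instead be written as torch.where(observed, dist.log_prob(times), dist.logsf(times)).sum(), keeping in mind that torch.where evaluates both branches.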



Download files


Source Distribution

survival_distributions-0.0.3.tar.gz (8.4 kB)

File details

Details for the file survival_distributions-0.0.3.tar.gz.

File metadata

  • Download URL: survival_distributions-0.0.3.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.11.3 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.9.12

File hashes

Hashes for survival_distributions-0.0.3.tar.gz
  • SHA256: b0b01d242cb950cc12aed50bef61b1bed1b4a060316b0a037a245547afbd4d92
  • MD5: 458e1a8dac810992570a48cf67778cde
  • BLAKE2b-256: e2e8dfe6d26a7d5f073c389fcd3ebedcf9bbf57b635facaffecf8811ca53735c
