Extended functionality for univariate probability distributions in PyTorch
survival_distributions
This package extends the functionality of univariate distributions in torch.distributions
by implementing several new methods:
- sf: survival function (complementary CDF)
- logsf: logarithm of the survival function (negative cumulative hazard function)
- logcdf: logarithm of the CDF
- log_hazard: logarithm of the hazard function (logarithm of the failure rate)
- isf: inverse of the survival function
- sample_cond: instead of sampling from the full support of the distribution, generate samples between lower_bound and upper_bound
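For intuition, here is a minimal plain-Python sketch of what each of these methods computes for an Exponential distribution with rate λ, where S(x) = exp(-λx). The class name ExponentialSketch and its method signatures are hypothetical and chosen only for illustration. This is not the survival_distributions implementation, and sample_cond here assumes inverse-transform sampling through isf.

```python
import math
import random


class ExponentialSketch:
    """Hypothetical illustration of the survival methods for Exponential(rate)."""

    def __init__(self, rate: float) -> None:
        self.rate = rate

    def log_prob(self, x: float) -> float:
        # log f(x) = log(rate) - rate * x for x >= 0
        return math.log(self.rate) - self.rate * x

    def logsf(self, x: float) -> float:
        # log S(x) = -rate * x, i.e. the negative cumulative hazard
        return -self.rate * x

    def sf(self, x: float) -> float:
        # S(x) = 1 - F(x) = exp(-rate * x)
        return math.exp(self.logsf(x))

    def logcdf(self, x: float) -> float:
        # log F(x) = log(1 - exp(-rate * x)); expm1 avoids cancellation for small x
        return math.log(-math.expm1(-self.rate * x))

    def log_hazard(self, x: float) -> float:
        # h(x) = f(x) / S(x), so log h(x) = log f(x) - log S(x) (constant here)
        return self.log_prob(x) - self.logsf(x)

    def isf(self, q: float) -> float:
        # Solve S(x) = q for x
        return -math.log(q) / self.rate

    def sample_cond(self, lower_bound: float, upper_bound: float, rng: random.Random) -> float:
        # Draw u uniformly between S(upper_bound) and S(lower_bound), then invert S.
        # Because S is decreasing, isf(u) lies in [lower_bound, upper_bound]
        # up to floating-point rounding.
        lo, hi = self.sf(upper_bound), self.sf(lower_bound)
        u = rng.uniform(lo, hi)
        return self.isf(u)
```

Sampling through the survival function rather than the CDF keeps u away from values near 1.0, which is where 1 - F(x) loses precision in the upper tail.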
This is especially useful when working with temporal point processes or survival analysis.
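As an example of where logsf appears, consider right-censored survival data, or equivalently the last, unfinished interval of a temporal point process. An observed event contributes log f(t) to the log-likelihood, while a censored interval contributes log S(t). The sketch below assumes exponential inter-event times (a homogeneous Poisson process); the function names are hypothetical and it uses plain-Python closed forms rather than the package.

```python
import math


def exp_log_prob(t: float, rate: float) -> float:
    # log f(t) for an Exponential(rate) inter-event time
    return math.log(rate) - rate * t


def exp_logsf(t: float, rate: float) -> float:
    # log S(t): probability that no event occurred in an interval of length t
    return -rate * t


def censored_log_likelihood(event_times: list[float], censor_time: float, rate: float) -> float:
    """Log-likelihood of a homogeneous Poisson process observed on [0, censor_time]."""
    ll = 0.0
    previous = 0.0
    for t in event_times:
        ll += exp_log_prob(t - previous, rate)  # this gap ended with an observed event
        previous = t
    ll += exp_logsf(censor_time - previous, rate)  # last gap is still open at censor_time
    return ll
```

For a Poisson process the result reduces to n log λ - λT, which is a quick sanity check: censored_log_likelihood([1.0, 3.0], 3.5, 2.0) equals 2 log 2 - 7.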
A naive implementation based on existing PyTorch functionality (e.g.,
torch.log(1.0 - dist.cdf(x)) for logsf) is often less accurate and less numerically
stable than the implementation provided by survival_distributions.
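The failure mode can be reproduced with plain Python floats (float64) for an Exponential distribution. In the upper tail F(x) rounds to exactly 1.0, so log(1 - F(x)) becomes -inf, and for tiny x exp(-λx) rounds to 1.0, so log F(x) becomes -inf as well. PyTorch's default float32 hits the same problem at much smaller arguments. The helper names are hypothetical; safe_log mimics torch.log returning -inf for zero instead of raising.

```python
import math


def safe_log(v: float) -> float:
    # Mirror torch.log, which returns -inf for 0 instead of raising an error
    return math.log(v) if v > 0.0 else -math.inf


def exp_cdf(x: float, rate: float) -> float:
    return 1.0 - math.exp(-rate * x)


def naive_logsf(x: float, rate: float) -> float:
    # log(1 - F(x)): F(x) rounds to 1.0 in the tail, so the result becomes -inf
    return safe_log(1.0 - exp_cdf(x, rate))


def stable_logsf(x: float, rate: float) -> float:
    # Closed form: log S(x) = -rate * x
    return -rate * x


def naive_logcdf(x: float, rate: float) -> float:
    # log F(x): exp(-rate * x) rounds to 1.0 for tiny x, so F(x) becomes 0.0
    return safe_log(exp_cdf(x, rate))


def stable_logcdf(x: float, rate: float) -> float:
    # expm1 keeps full precision when rate * x is tiny
    return math.log(-math.expm1(-rate * x))


print(naive_logsf(20.0, 2.0), stable_logsf(20.0, 2.0))  # -inf -40.0
print(naive_logcdf(1e-20, 2.0), stable_logcdf(1e-20, 2.0))  # -inf, about -45.36
```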
Hopefully, these methods will be implemented in PyTorch sometime in the future,
but this package provides an alternative for the time being.
See DISTRIBUTIONS.md for more details about the implemented functions and supported distributions.
Installation
- Install the latest version of PyTorch.
- Install survival_distributions:
    pip install survival_distributions
Supported distributions
Numerically stable implementation
For these distributions we provide a numerically stable implementation of logsf.
- Exponential
- Logistic
- LogLogistic
- MixtureSameFamily
- TransformedDistribution
- Uniform
- Weibull
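The standard identities below illustrate how logsf can be computed without forming 1 - F(x) for a few of these families. They are written in plain Python with hypothetical function names; the Logistic parameterization (loc, scale) and the Weibull parameter names (scale, concentration, as in torch.distributions.Weibull) are assumptions, and the package may compute these differently. For a mixture, S(x) is the weighted sum of component survival functions, which is evaluated in log space with logsumexp.

```python
import math


def softplus(z: float) -> float:
    # log(1 + exp(z)) without overflow for large |z|
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def logistic_logsf(x: float, loc: float, scale: float) -> float:
    # S(x) = 1 / (1 + exp(z)) with z = (x - loc) / scale, so log S(x) = -softplus(z)
    return -softplus((x - loc) / scale)


def weibull_logsf(x: float, scale: float, concentration: float) -> float:
    # S(x) = exp(-(x / scale) ** k), so log S(x) = -(x / scale) ** k for x >= 0
    return -((x / scale) ** concentration)


def logsumexp(values: list[float]) -> float:
    # log(sum(exp(v))) computed relative to the maximum to avoid under/overflow
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def mixture_logsf(weights: list[float], component_logsf: list[float]) -> float:
    # S(x) = sum_i w_i * S_i(x), evaluated in log space
    return logsumexp([math.log(w) + s for w, s in zip(weights, component_logsf)])
```

For example, logistic_logsf(1000.0, 0.0, 1.0) returns -1000.0, whereas 1 - 1 / (1 + exp(-1000)) evaluates to 0.0 and its logarithm to -inf.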
Naive implementation
For these distributions we implement logsf(x) as log(1.0 - dist.cdf(x)), which is less
numerically stable.
- LogNormal
- Normal
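Both of these reduce to the standard normal CDF (LogNormal via log x). The plain-Python comparison below shows how the naive formula breaks down in the upper tail, together with one common workaround based on the complementary error function. It is an illustration only, with hypothetical function names; as stated above, survival_distributions uses the naive formula for these two distributions. Note that erfc itself underflows to zero roughly 38 standard deviations into the tail, where an asymptotic expansion would be needed.

```python
import math


def normal_naive_logsf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # log(1 - F(x)) with F(x) = 0.5 * (1 + erf(z / sqrt(2)))
    cdf = 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))
    return math.log(1.0 - cdf) if cdf < 1.0 else -math.inf


def normal_erfc_logsf(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    # S(x) = 0.5 * erfc(z / sqrt(2)) avoids the subtraction 1 - F(x)
    sf = 0.5 * math.erfc((x - loc) / (scale * math.sqrt(2.0)))
    return math.log(sf) if sf > 0.0 else -math.inf


print(normal_naive_logsf(10.0))  # -inf
print(normal_erfc_logsf(10.0))  # about -53.23
```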
Usage
The package provides a drop-in replacement for torch.distributions, so you can just modify your code as follows.
Old code
import torch
dist = torch.distributions.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)
# Naive: log(1 - CDF), which loses precision (or returns -inf) in the far tail
log_survival_proba = torch.log(1.0 - dist.cdf(x))
New code
import torch
import survival_distributions as sd
dist = sd.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)
# Numerically stable log survival function
log_survival_proba = dist.logsf(x)
Source Distribution
File details
Details for the file survival_distributions-0.0.3.tar.gz.
File metadata
- Download URL: survival_distributions-0.0.3.tar.gz
- Upload date:
- Size: 8.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.11.3 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.9.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b0b01d242cb950cc12aed50bef61b1bed1b4a060316b0a037a245547afbd4d92 |
| MD5 | 458e1a8dac810992570a48cf67778cde |
| BLAKE2b-256 | e2e8dfe6d26a7d5f073c389fcd3ebedcf9bbf57b635facaffecf8811ca53735c |