A numerically-stable and differentiable implementation of the Truncated Gaussian distribution in PyTorch.

Project description

Stable Truncated Gaussian

A differentiable implementation of the Truncated Gaussian (Normal) distribution using Python and PyTorch that remains numerically stable even when the μ parameter lies outside the interval [a,b] given by the bounds of the distribution. In this situation, a naive evaluation of the mean, variance and log-probability can suffer from catastrophic cancellation, because the normalizing constant Φ((b - μ)/σ) - Φ((a - μ)/σ), where Φ is the standard normal CDF and σ the scale parameter, becomes a difference of two nearly equal numbers (for example, both extremely close to 1 when μ lies far below a). Our code is inspired by TruncatedNormal.jl and torch_truncnorm.

Currently, we provide numerically stable methods for computing the mean, variance, log-probability and KL-divergence of the distribution, as well as for sampling from it. Our implementation of icdf (which is used for sampling) still needs some work when the [a,b] interval is small. For a comparison between our icdf implementation and the one provided by SciPy, take a look at the images_cdf_comparison folder.
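The cancellation problem can be reproduced in a few lines. The sketch below uses only Python's standard library; it is not this package's code or algorithm, and the helper names mass_naive and mass_tail are hypothetical. With α = (a - μ)/σ and β = (b - μ)/σ, it compares the naive CDF difference with a version that works with upper-tail probabilities via math.erfc, using the same parameters as the example further down (μ=0, σ=1, a=10, b=11). Note that the tail-based version still loses precision when [a,b] is very narrow.

```python
import math

SQRT2 = math.sqrt(2.0)


def std_normal_cdf(x: float) -> float:
    """Standard normal CDF written with erf, as a naive implementation would."""
    return 0.5 * (1.0 + math.erf(x / SQRT2))


def mass_naive(alpha: float, beta: float) -> float:
    """Phi(beta) - Phi(alpha): cancels when both CDF values are nearly equal."""
    return std_normal_cdf(beta) - std_normal_cdf(alpha)


def mass_tail(alpha: float, beta: float) -> float:
    """Same probability, computed from whichever tail keeps the operands small."""
    if alpha >= 0.0:
        # Interval above the mean: difference of upper-tail masses erfc(x/sqrt2)/2.
        return 0.5 * (math.erfc(alpha / SQRT2) - math.erfc(beta / SQRT2))
    if beta <= 0.0:
        # Interval below the mean: mirror it onto the upper tail.
        return 0.5 * (math.erfc(-beta / SQRT2) - math.erfc(-alpha / SQRT2))
    # Interval contains the mean: neither CDF value sits in an extreme tail.
    return mass_naive(alpha, beta)


# mu=0, sigma=1, a=10, b=11, so alpha=10 and beta=11
print(mass_naive(10.0, 11.0))  # 0.0, because both CDF values round to 1.0
print(mass_tail(10.0, 11.0))   # roughly 7.62e-24
```

With the naive value, every quantity that divides by the normalizer (mean, variance, log-probability) breaks down, while the tail-based value is still a perfectly ordinary float64.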

Installation

Simply install with pip:

pip install stable-trunc-gaussian

Example

Run the following code in Python:

from stable_trunc_gaussian import TruncatedGaussian as TG
from torch import tensor as t

# Create a Truncated Gaussian with mu=0, sigma=1, a=10, b=11
# Notice how mu is outside the interval [a,b]
dist = TG(t(0),t(1),t(10),t(11))

print("Mean:", dist.mean)
print("Variance:", dist.variance)
print("Log-prob(10.5):", dist.log_prob(t(10.5)))

Result:

Mean: tensor(10.0981)
Variance: tensor(0.0094)
Log-prob(10.5): tensor(-2.8126)
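As a sanity check on these numbers, the standard closed-form expressions for the mean, variance and log-density of a truncated Gaussian can be evaluated in float64 with the standard library. This is a sketch, not the package's implementation: the function truncated_gaussian_stats is hypothetical, and it writes the normalizer with math.erfc, which only avoids cancellation because the interval [10, 11] lies entirely above μ. It also stops working once the normalizer underflows to 0 in float64 (roughly when a - μ exceeds 38σ).

```python
import math

SQRT2 = math.sqrt(2.0)
LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def std_normal_logpdf(x: float) -> float:
    """Log-density of the standard normal distribution."""
    return -0.5 * x * x - LOG_SQRT_2PI


def truncated_gaussian_stats(mu, sigma, a, b, x):
    """Mean, variance and log-prob(x) of N(mu, sigma^2) truncated to [a, b].

    Z = Phi(beta) - Phi(alpha) is written with erfc; this is well conditioned
    here only because the whole interval lies above mu.
    """
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    z = 0.5 * (math.erfc(alpha / SQRT2) - math.erfc(beta / SQRT2))
    pdf_a = math.exp(std_normal_logpdf(alpha))
    pdf_b = math.exp(std_normal_logpdf(beta))
    ratio = (pdf_a - pdf_b) / z  # lambda in the formulas below
    mean = mu + sigma * ratio
    variance = sigma**2 * (1.0 + (alpha * pdf_a - beta * pdf_b) / z - ratio**2)
    log_prob = std_normal_logpdf((x - mu) / sigma) - math.log(sigma) - math.log(z)
    return mean, variance, log_prob


mean, variance, log_prob = truncated_gaussian_stats(0.0, 1.0, 10.0, 11.0, 10.5)
print(f"{mean:.4f} {variance:.4f} {log_prob:.4f}")  # 10.0981 0.0094 -2.8126
```

The formulas are the usual ones: with λ = (φ(α) - φ(β))/Z, the mean is μ + σλ and the variance is σ²[1 + (αφ(α) - βφ(β))/Z - λ²]. The variance is itself a small difference of terms close to 100 in this example, which float64 handles comfortably here.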

Parallel vs Sequential Implementation

The class obtained by doing from stable_trunc_gaussian import TruncatedGaussian corresponds to a parallel implementation of the Truncated Gaussian, which makes it possible to compute several values (means, variances and log-probs) in parallel. If you only need to compute values sequentially, i.e., one distribution at a time, we also provide a sequential implementation, which is more efficient only in that case. To use this sequential implementation, simply do from stable_trunc_gaussian import SeqTruncatedGaussian. Here is an example:

from stable_trunc_gaussian import TruncatedGaussian, SeqTruncatedGaussian
from torch import tensor as t

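# Parallel computation: each argument (mu, sigma, a, b) holds one value
# per distribution, so both truncated Gaussians are evaluated at once
means = TruncatedGaussian(t([0., 0.5]), t([1., 1.]), t([-1., 2.]), t([1., 5.])).mean

# Sequential computation: one distribution at a time
# Note: the 'TruncatedGaussian' class can also be used for this sequential case
mean_0 = SeqTruncatedGaussian(t([0.]), t(1.), t(-1.), t(1.)).mean
mean_1 = SeqTruncatedGaussian(t([0.5]), t(1.), t(2.), t(5.)).mean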

Acknowledgements

We would like to thank users KFrank and ptrblck for their help in fixing a bug in the gradient computation of the parallel version (fixed in version 1.1.1).

Download files

Source Distribution

stable-trunc-gaussian-1.3.1.tar.gz (15.2 kB)

Uploaded Source

Built Distribution

stable_trunc_gaussian-1.3.1-py3-none-any.whl (17.0 kB)

Uploaded Python 3

File details

Details for the file stable-trunc-gaussian-1.3.1.tar.gz.

File metadata

  • Download URL: stable-trunc-gaussian-1.3.1.tar.gz
  • Upload date:
  • Size: 15.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for stable-trunc-gaussian-1.3.1.tar.gz
Algorithm Hash digest
SHA256 7fdb249a8ca1000e40b9167832ee46f1aee55ad30e9978f0eb2aabe9060ff5f3
MD5 f7035c62851a11487fd28444d7f0797c
BLAKE2b-256 9c503114976ae7226d100ee63a282dbbf9e2d0ba1b56309081afb1607a81ac4d

File details

Details for the file stable_trunc_gaussian-1.3.1-py3-none-any.whl.

File hashes

Hashes for stable_trunc_gaussian-1.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 afc3b0c77b5304ea2eae5b4e858df92ac8ad7cafcf3db4c3c7a4f2102187de7c
MD5 feb2dd4cb847a3e8deaa3d9825440a66
BLAKE2b-256 87ca7c2c3cad605badc1ec3c8d7d4c52c84628ca2eb9fe630490fdbc2e5e2071
