
A numerically stable and differentiable implementation of the Truncated Gaussian distribution in PyTorch.


Stable Truncated Gaussian

A differentiable implementation of the Truncated Gaussian (Normal) distribution using Python and PyTorch, which is numerically stable even when the μ parameter lies outside the interval [a,b] given by the bounds of the distribution. In this situation, a naive evaluation of the mean, variance and log-probability could otherwise suffer catastrophic cancellation. All three depend on the normalizing constant Z = Φ(β) - Φ(α), with α = (a - μ)/σ and β = (b - μ)/σ. When μ lies well below a, for example, Φ(α) and Φ(β) both round to 1, so Z evaluates to 0 or loses most of its significant digits. Our code is inspired by TruncatedNormal.jl and torch_truncnorm. Currently, we only provide functionality for computing the mean, variance and log-probability, not for computing the entropy or sampling from the distribution.
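To see why the naive formula breaks, here is a minimal sketch (not the package's internal code) that evaluates the log-probability of the distribution from the Example section (μ=0, σ=1, a=10, b=11) in two ways. The helpers naive_log_z, stable_log_z and trunc_normal_log_prob are hypothetical. The stable variant assumes α < β, reflects an interval lying above 0 into the lower tail using Φ(β) - Φ(α) = Φ(-α) - Φ(-β), and then works in log space with torch.special.log_ndtr, which requires a reasonably recent PyTorch release.

```python
import math

import torch
from torch.special import log_ndtr, ndtr


def naive_log_z(alpha, beta):
    # Direct difference of CDFs: for alpha, beta >> 0 both ndtr values round
    # to 1.0, so the difference is exactly 0 and the log is -inf.
    return torch.log(ndtr(beta) - ndtr(alpha))


def stable_log_z(alpha, beta):
    # Assumes alpha < beta elementwise. If the interval lies above 0,
    # reflect it: Phi(beta) - Phi(alpha) = Phi(-alpha) - Phi(-beta).
    flip = alpha > 0
    lo = torch.where(flip, -beta, alpha)
    hi = torch.where(flip, -alpha, beta)
    log_hi, log_lo = log_ndtr(hi), log_ndtr(lo)
    # log(exp(log_hi) - exp(log_lo)) = log_hi + log(1 - exp(log_lo - log_hi))
    return log_hi + torch.log(-torch.expm1(log_lo - log_hi))


def trunc_normal_log_prob(x, mu, sigma, a, b, log_z_fn):
    # log p(x) = log phi(z) - log(sigma) - log Z, for x inside [a, b]
    alpha, beta, z = (a - mu) / sigma, (b - mu) / sigma, (x - mu) / sigma
    log_phi = -0.5 * z**2 - 0.5 * math.log(2 * math.pi)
    return log_phi - torch.log(sigma) - log_z_fn(alpha, beta)


mu = torch.tensor(0.0, requires_grad=True)
sigma, a, b = torch.tensor(1.0), torch.tensor(10.0), torch.tensor(11.0)
x = torch.tensor(10.5)

naive = trunc_normal_log_prob(x, mu, sigma, a, b, naive_log_z)
stable = trunc_normal_log_prob(x, mu, sigma, a, b, stable_log_z)
stable.backward()  # gradients flow through log_ndtr as well

print(naive.item())    # inf: the normalizer collapsed to 0
print(stable.item())   # approximately -2.8126
print(mu.grad.item())  # finite, approximately 0.402
```

Working with log Φ instead of Φ keeps the normalizer representable even when Φ itself would underflow, which is the same concern the package addresses.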

Installation

Simply install with pip:

pip install stable-trunc-gaussian

Example

Run the following code in Python:

from stable_trunc_gaussian import TruncatedGaussian as TG
from torch import tensor as t

# Create a Truncated Gaussian with mu=0, sigma=1, a=10, b=11
# Notice how mu is outside the interval [a,b]
dist = TG(t(0),t(1),t(10),t(11))

print("Mean:", dist.mean)
print("Variance:", dist.variance)
print("Log-prob(10.5):", dist.log_prob(t(10.5)))

Result:

Mean: tensor(10.0981)
Variance: tensor(0.0094)
Log-prob(10.5): tensor(-2.8126)
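The printed values can be cross-checked independently of the package with the textbook closed forms for the truncated Gaussian, evaluated in double precision. This sketch uses only Python's math module. It writes the normalizer as Z = (erfc(α/√2) - erfc(β/√2))/2, which is accurate here only because both standardized bounds lie in the upper tail (α=10, β=11). The variance line still subtracts two numbers close to 102, so treat this as a float64 sanity check, not as a numerically stable implementation.

```python
import math

mu, sigma, a, b, x = 0.0, 1.0, 10.0, 11.0, 10.5
alpha, beta = (a - mu) / sigma, (b - mu) / sigma


def phi(t):
    # Standard normal density
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)


# Z = Phi(beta) - Phi(alpha), written with erfc so that the upper-tail
# masses (around 1e-23 here) are not rounded away against 1.0.
Z = 0.5 * (math.erfc(alpha / math.sqrt(2.0)) - math.erfc(beta / math.sqrt(2.0)))

ratio = (phi(alpha) - phi(beta)) / Z
mean = mu + sigma * ratio
variance = sigma**2 * (1.0 + (alpha * phi(alpha) - beta * phi(beta)) / Z - ratio**2)
log_prob = math.log(phi((x - mu) / sigma)) - math.log(sigma) - math.log(Z)

print(f"Mean: {mean:.4f}")
print(f"Variance: {variance:.4f}")
print(f"Log-prob(10.5): {log_prob:.4f}")
```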

Parallel vs Sequential Implementation

The class imported with from stable_trunc_gaussian import TruncatedGaussian is a parallel implementation of the Truncated Gaussian, which makes it possible to compute several values (means, variances and log-probabilities) at once. If you only need to compute values sequentially, i.e., one at a time, we also provide a sequential implementation, which is more efficient in that case only. To use it, simply do from stable_trunc_gaussian import SeqTruncatedGaussian. Here is an example:

from stable_trunc_gaussian import TruncatedGaussian, SeqTruncatedGaussian
from torch import tensor as t

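# Parallel computation: each argument holds one value per distribution,
# i.e. (mu, sigma, a, b) = (0, 1, -1, 1) and (0.5, 1, 2, 5)
means = TruncatedGaussian(t([0., 0.5]), t([1., 1.]), t([-1., 2.]), t([1., 5.])).mean

# Sequential computation: one distribution per object
# Note: the 'TruncatedGaussian' class can also be used for this sequential case
mean_0 = SeqTruncatedGaussian(t([0.]), t(1.), t(-1.), t(1.)).mean
mean_1 = SeqTruncatedGaussian(t([0.5]), t(1.), t(2.), t(5.)).mean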
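The parallel case boils down to ordinary tensor broadcasting: each parameter is a tensor whose elements describe a separate distribution, so one call evaluates all of them. The sketch below illustrates this with a hypothetical trunc_normal_mean helper (not part of the package), applied to the same two distributions as above, once as a batch and once element by element. For brevity it only reflects intervals lying above 0 with torch.special.ndtr and skips log-space arithmetic. That is adequate for moderate bounds like these, but not for the extreme case in the first example.

```python
import math

import torch
from torch.special import ndtr


def std_normal_pdf(t):
    # Standard normal density, elementwise
    return torch.exp(-0.5 * t**2) / math.sqrt(2 * math.pi)


def trunc_normal_mean(mu, sigma, a, b):
    # Hypothetical helper: mean = mu + sigma * (phi(alpha) - phi(beta)) / Z
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    # Z = Phi(beta) - Phi(alpha); when the interval lies above 0, use the
    # reflected form Phi(-alpha) - Phi(-beta) to avoid subtracting values near 1.
    z = torch.where(alpha > 0, ndtr(-alpha) - ndtr(-beta), ndtr(beta) - ndtr(alpha))
    return mu + sigma * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / z


def f64(values):
    return torch.tensor(values, dtype=torch.float64)


# Parallel: every argument has shape (2,), one entry per distribution
means = trunc_normal_mean(f64([0.0, 0.5]), f64([1.0, 1.0]), f64([-1.0, 2.0]), f64([1.0, 5.0]))

# Sequential: one scalar distribution per call
mean_0 = trunc_normal_mean(f64(0.0), f64(1.0), f64(-1.0), f64(1.0))
mean_1 = trunc_normal_mean(f64(0.5), f64(1.0), f64(2.0), f64(5.0))

print(means)  # approximately [0.0000, 2.4385]
print(torch.allclose(means, torch.stack([mean_0, mean_1])))  # True
```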



Download files


Source Distribution

stable-trunc-gaussian-1.1.0.tar.gz (9.5 kB)

Uploaded Source

Built Distribution

stable_trunc_gaussian-1.1.0-py3-none-any.whl (11.3 kB)

Uploaded Python 3

File details

Details for the file stable-trunc-gaussian-1.1.0.tar.gz.

File metadata

  • Download URL: stable-trunc-gaussian-1.1.0.tar.gz
  • Upload date:
  • Size: 9.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for stable-trunc-gaussian-1.1.0.tar.gz
Algorithm Hash digest
SHA256 f77a68dbcd911d06daf83109068aa92144530a79c4cd5417b0c814b9ab8bd1b2
MD5 d18044f8898341ef54550e706d492247
BLAKE2b-256 b3578607abb9ab9a1d7a61f6da6dc6dd4c7c84985fee340846edb980b64bc30b


File details

Details for the file stable_trunc_gaussian-1.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for stable_trunc_gaussian-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6e1e4f6a3fb71c46cd3a6cc72e2045c963ac3948024a7c088c95d534abb0475a
MD5 62984ec1e287a9c11414c1c3bf8c21b8
BLAKE2b-256 acacec658b6ef71f11825c693789dae7faacee434dc23e2c09f48416f8f04621

