A numerically stable and differentiable implementation of the Truncated Gaussian distribution in PyTorch.

Stable Truncated Gaussian

A differentiable implementation of the Truncated Gaussian (Normal) distribution using Python and PyTorch that is numerically stable even when the μ parameter lies outside the interval [a,b] given by the bounds of the distribution. In this situation, a naive evaluation of the mean, variance and log-probability can suffer from catastrophic cancellation, because all three depend on the normalizing constant Z = Φ(β) - Φ(α), where Φ is the standard normal CDF, α = (a - μ)/σ and β = (b - μ)/σ. For example, with μ = 0, σ = 1, a = 10 and b = 11, both Φ(10) and Φ(11) round to 1 in floating point, so Z evaluates to 0. Our code is inspired by TruncatedNormal.jl and torch_truncnorm. Currently, we only provide functionality for calculating the mean, variance and log-probability; calculating the entropy and sampling from the distribution are not supported.
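The cancellation problem described above can be illustrated with a minimal sketch. This is not the package's actual implementation, only one common way to avoid the issue. It computes the log-normalizer log Z = log(Φ(β) - Φ(α)) from standardized bounds α = (a - μ)/σ and β = (b - μ)/σ. The helper names `naive_log_z`, `log1mexp` and `stable_log_z` are hypothetical. The stable variant reflects intervals in the right tail using Φ(β) - Φ(α) = Φ(-α) - Φ(-β) and then stays in log space with `torch.special.log_ndtr`, which assumes a PyTorch version recent enough to provide that function.

```python
import math

import torch
from torch.special import log_ndtr, ndtr


def naive_log_z(alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # Direct difference of CDFs: both terms round to 1.0 when alpha, beta >> 0
    return torch.log(ndtr(beta) - ndtr(alpha))


def log1mexp(x: torch.Tensor) -> torch.Tensor:
    # log(1 - exp(x)) for x < 0, picking the branch that avoids cancellation
    return torch.where(
        x > -math.log(2.0),
        torch.log(-torch.expm1(x)),
        torch.log1p(-torch.exp(x)),
    )


def stable_log_z(alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # Reflect intervals in the right tail: Phi(b) - Phi(a) = Phi(-a) - Phi(-b)
    flip = alpha > 0
    lo = torch.where(flip, -beta, alpha)
    hi = torch.where(flip, -alpha, beta)
    # log(Phi(hi) - Phi(lo)) = log Phi(hi) + log(1 - Phi(lo) / Phi(hi))
    return log_ndtr(hi) + log1mexp(log_ndtr(lo) - log_ndtr(hi))


# Standardized bounds for mu=0, sigma=1, a=10, b=11
alpha = torch.tensor(10.0, requires_grad=True)
beta = torch.tensor(11.0)
print(naive_log_z(alpha, beta))  # -inf: Z underflows to exactly 0
log_z = stable_log_z(alpha, beta)
log_z.backward()
print(log_z, alpha.grad)
```

Because every step is an ordinary PyTorch operation, autograd can differentiate `stable_log_z` with respect to the bounds, whereas the naive version returns -inf for these inputs. Subtracting log Z from the Gaussian log-density then gives the truncated log-probability.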

Installation

Simply install with pip:

pip install stable-trunc-gaussian

Example

Run the following code in Python:

from stable_trunc_gaussian import TruncatedGaussian as TG
from torch import tensor as t

# Create a Truncated Gaussian with mu=0, sigma=1, a=10, b=11
# Notice how mu is outside the interval [a,b]
dist = TG(t(0), t(1), t(10), t(11))

print("Mean:", dist.mean)
print("Variance:", dist.variance)
print("Log-prob(10.5):", dist.log_prob(t(10.5)))

Result:

Mean: tensor(10.0981)
Variance: tensor(0.0094)
Log-prob(10.5): tensor(-2.8126)
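As an independent cross-check of the Result above, the closed-form moments of a truncated Gaussian can be evaluated in double precision with Python's standard `math` module. This sketch uses a hypothetical helper, `truncnorm_summary`, and the standard formulas Z = Φ(β) - Φ(α), mean = μ + σ(φ(α) - φ(β))/Z and variance = σ²[1 + (αφ(α) - βφ(β))/Z - ((φ(α) - φ(β))/Z)²], where φ is the standard normal density. It sidesteps 1 - Φ by computing upper-tail probabilities with `math.erfc`. That is enough for this example, but unlike a log-space implementation it underflows once α exceeds roughly 38 and still cancels when the interval sits deep in the left tail.

```python
import math


def truncnorm_summary(mu: float, sigma: float, a: float, b: float):
    """Mean, variance and log-density of N(mu, sigma^2) truncated to [a, b].

    Hypothetical cross-checking helper: it avoids 1 - Phi by using erfc, but is
    only reliable while Z does not underflow (roughly alpha < 38) and the
    interval is not deep in the left tail.
    """
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    # Z = Phi(beta) - Phi(alpha) = Q(alpha) - Q(beta), with Q(x) = erfc(x / sqrt(2)) / 2
    z = 0.5 * (math.erfc(alpha / math.sqrt(2)) - math.erfc(beta / math.sqrt(2)))

    def pdf(x: float) -> float:
        # Standard normal density
        return math.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)

    ratio = (pdf(alpha) - pdf(beta)) / z
    mean = mu + sigma * ratio
    variance = sigma**2 * (1 + (alpha * pdf(alpha) - beta * pdf(beta)) / z - ratio**2)

    def log_prob(x: float) -> float:
        # Gaussian log-density minus log Z; only meaningful for x in [a, b]
        xi = (x - mu) / sigma
        return -0.5 * xi * xi - 0.5 * math.log(2 * math.pi) - math.log(sigma) - math.log(z)

    return mean, variance, log_prob


mean, variance, log_prob = truncnorm_summary(0.0, 1.0, 10.0, 11.0)
print(f"Mean: {mean:.4f}  Variance: {variance:.4f}  Log-prob(10.5): {log_prob(10.5):.4f}")
```

Run as is, it should print values that agree with the Result above to four decimal places.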

Download files


Source Distribution

stable-trunc-gaussian-1.0.1.tar.gz (7.3 kB)

Built Distribution

stable_trunc_gaussian-1.0.1-py3-none-any.whl (7.8 kB)

File details

Details for the file stable-trunc-gaussian-1.0.1.tar.gz.

File metadata

  • Download URL: stable-trunc-gaussian-1.0.1.tar.gz
  • Upload date:
  • Size: 7.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for stable-trunc-gaussian-1.0.1.tar.gz
Algorithm Hash digest
SHA256 2923ce77a2b89609e321c17bc34608386d43a005e9d9bc4e3401d9e11cad6aa4
MD5 a98d415bbd1501b9a53a04726f759c7a
BLAKE2b-256 23549b0fb5ab4323afff572079da4100c553dbefff3164351ce7b7789464c03f


File details

Details for the file stable_trunc_gaussian-1.0.1-py3-none-any.whl.

File metadata

File hashes

Hashes for stable_trunc_gaussian-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 c900d7d00750de770df932c5ed3027f3fc10dbdd4a5ff9d48b09f6923f874c9c
MD5 7f8b8c6ac9900bf47e72a5729b94547f
BLAKE2b-256 4979a841130ab61be4abd8aa326281615fc0ef68b89d595fdcbd9c7ba305bfbc

