A numerically stable and differentiable implementation of the Truncated Gaussian distribution in PyTorch.
Stable Truncated Gaussian
A differentiable implementation of the Truncated Gaussian (Normal) distribution using Python and PyTorch that remains numerically stable even when the μ parameter lies outside the interval [a, b] given by the bounds of the distribution. In this situation, a naive evaluation of the mean, variance and log-probability of the distribution would be prone to catastrophic cancellation. Our code is inspired by TruncatedNormal.jl and torch_truncnorm. Currently, we only provide functionality for calculating the mean, variance and log-probability; calculating the entropy and sampling from the distribution are not supported.
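To illustrate the cancellation problem, the following plain-Python sketch (standard library only, scalar floats rather than PyTorch tensors; the function names are hypothetical and not part of this package) computes the log normalizing constant log(Φ(β) - Φ(α)), with α = (a - μ)/σ and β = (b - μ)/σ, in two ways. Subtracting two CDF values that both round to 1.0 gives exactly zero far in the upper tail, whereas rewriting the difference with the complementary error function `math.erfc` keeps the tiny tail mass. This is only one simple remedy, assumed here for illustration: it still underflows once the bounds sit around 38 standard deviations from μ, so it should not be read as the method this package uses.

```python
import math

SQRT2 = math.sqrt(2.0)


def naive_log_z(mu, sigma, a, b):
    """log(Phi(beta) - Phi(alpha)) via a direct CDF difference (prone to cancellation)."""
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma

    def cdf(x):
        return 0.5 * (1.0 + math.erf(x / SQRT2))

    z = cdf(beta) - cdf(alpha)
    # Far in the upper tail both CDF values round to 1.0, so z can become exactly 0.0
    return math.log(z) if z > 0.0 else float("-inf")


def tail_log_z(mu, sigma, a, b):
    """Same quantity, evaluated with erfc whenever both bounds lie in one tail."""
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    if alpha >= 0.0:
        # Upper tail: Phi(beta) - Phi(alpha) = Q(alpha) - Q(beta), Q(x) = erfc(x / sqrt2) / 2
        z = 0.5 * (math.erfc(alpha / SQRT2) - math.erfc(beta / SQRT2))
    elif beta <= 0.0:
        # Lower tail: mirror the interval using Phi(x) = Q(-x)
        z = 0.5 * (math.erfc(-beta / SQRT2) - math.erfc(-alpha / SQRT2))
    else:
        # mu lies inside (a, b): the erf values have opposite signs, so nothing cancels
        z = 0.5 * (math.erf(beta / SQRT2) - math.erf(alpha / SQRT2))
    return math.log(z)


print(naive_log_z(0.0, 1.0, 10.0, 11.0))
print(tail_log_z(0.0, 1.0, 10.0, 11.0))
```

For the bounds used in the example (μ = 0, σ = 1, a = 10, b = 11), the naive version returns -inf while the erfc-based version returns a finite value close to -53.23.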
Installation
Simply install with pip:
pip install stable-trunc-gaussian
Example
Run the following code in Python:
from stable_trunc_gaussian import TruncatedGaussian as TG
from torch import tensor as t
# Create a Truncated Gaussian with mu=0, sigma=1, a=10, b=11
# Notice how mu is outside the interval [a,b]
dist = TG(t(0),t(1),t(10),t(11))
print("Mean:", dist.mean)
print("Variance:", dist.variance)
print("Log-prob(10.5):", dist.log_prob(t(10.5)))
Result:
Mean: tensor(10.0981)
Variance: tensor(0.0094)
Log-prob(10.5): tensor(-2.8126)
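As a cross-check of this output, the following plain-Python sketch evaluates the standard closed-form expressions for the truncated normal mean, variance and log-density in double precision. It is a hypothetical reference helper, not this package's API or implementation, and the tail-mass helper `interval_mass` is an assumption chosen to avoid cancellation in Φ(β) - Φ(α). With μ = 0, σ = 1, a = 10, b = 11 it should agree with the printed values to the displayed precision. Note that the variance expression subtracts two quantities that are each about 101 to obtain about 0.0094, so roughly four significant digits are lost; with fewer bits (float32) or bounds farther from μ, this formulation loses accuracy and eventually divides by an underflowed zero.

```python
import math

SQRT2 = math.sqrt(2.0)
LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def std_normal_pdf(x):
    """Standard normal density phi(x); underflows harmlessly to 0.0 far in the tails."""
    return math.exp(-0.5 * x * x - LOG_SQRT_2PI)


def interval_mass(alpha, beta):
    """Phi(beta) - Phi(alpha) for alpha < beta, using erfc in the tails."""
    if alpha >= 0.0:
        # Upper tail: Q(alpha) - Q(beta), with Q(x) = erfc(x / sqrt2) / 2
        return 0.5 * (math.erfc(alpha / SQRT2) - math.erfc(beta / SQRT2))
    if beta <= 0.0:
        # Lower tail: mirror the interval using Phi(x) = Q(-x)
        return 0.5 * (math.erfc(-beta / SQRT2) - math.erfc(-alpha / SQRT2))
    # mu inside (a, b): erf values have opposite signs, so no cancellation
    return 0.5 * (math.erf(beta / SQRT2) - math.erf(alpha / SQRT2))


def truncated_gaussian_moments(mu, sigma, a, b):
    """Mean and variance of N(mu, sigma^2) truncated to [a, b]."""
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    z = interval_mass(alpha, beta)
    pa, pb = std_normal_pdf(alpha), std_normal_pdf(beta)
    ratio = (pa - pb) / z
    mean = mu + sigma * ratio
    # 1 + (alpha*pa - beta*pb)/z and ratio**2 nearly cancel when mu is far outside [a, b]
    variance = sigma**2 * (1.0 + (alpha * pa - beta * pb) / z - ratio**2)
    return mean, variance


def truncated_gaussian_log_prob(x, mu, sigma, a, b):
    """Log-density of the truncated distribution at x (-inf outside [a, b])."""
    if not a <= x <= b:
        return float("-inf")
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    xi = (x - mu) / sigma
    return (-0.5 * xi * xi - LOG_SQRT_2PI - math.log(sigma)
            - math.log(interval_mass(alpha, beta)))


mean, var = truncated_gaussian_moments(0.0, 1.0, 10.0, 11.0)
print(f"Mean: {mean:.4f}")
print(f"Variance: {var:.4f}")
print(f"Log-prob(10.5): {truncated_gaussian_log_prob(10.5, 0.0, 1.0, 10.0, 11.0):.4f}")
```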
Download files
File details
Details for the file stable-trunc-gaussian-1.0.0.tar.gz.
File metadata
- Download URL: stable-trunc-gaussian-1.0.0.tar.gz
- Upload date:
- Size: 7.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6f3dbef9e7a1b5145744b1a4152f07c64ad2829edaa31943717c5500df7ef9aa
MD5 | 75dbd93654c55262435b9ad5ccaeb52f
BLAKE2b-256 | 41f18cc932f70d91b38d3338bb81016b4b82a9fabac13e833928031d1073c37e
File details
Details for the file stable_trunc_gaussian-1.0.0-py3-none-any.whl.
File metadata
- Download URL: stable_trunc_gaussian-1.0.0-py3-none-any.whl
- Upload date:
- Size: 7.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | e23935b2f44720a55e8f049efe41e84ec8cf8a569de6fdd44be3dbe311a8e5b3
MD5 | e8e39064451127f764e9bbee1d93c8a4
BLAKE2b-256 | 7603346cdaa42c4585fb9b59056f3b9bb390d6089eb347648cd92dbc0e35bcde
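The SHA256 digests listed for the two release files can be checked after a manual download. The following is a minimal sketch using only `hashlib` from the standard library; the helper names and the idea of keying digests by file name are illustrative assumptions, not part of the package or of PyPI's tooling.

```python
import hashlib
from pathlib import Path

# SHA256 digests as published for the 1.0.0 release files
EXPECTED_SHA256 = {
    "stable-trunc-gaussian-1.0.0.tar.gz":
        "6f3dbef9e7a1b5145744b1a4152f07c64ad2829edaa31943717c5500df7ef9aa",
    "stable_trunc_gaussian-1.0.0-py3-none-any.whl":
        "e23935b2f44720a55e8f049efe41e84ec8cf8a569de6fdd44be3dbe311a8e5b3",
}


def sha256_of(path, chunk_size=65536):
    """Hash a file in chunks so that large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    """Return True if the file's SHA256 matches the published digest for its name."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no published digest for {path.name}")
    return sha256_of(path) == expected
```

pip can also enforce digests itself in its hash-checking mode, by adding `--hash=sha256:<digest>` to a requirement line; in that mode pip then requires hashes for every dependency as well, including PyTorch.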