A Simple Statistical Distribution Library In JAX
FenbuX
A Simple Probabilistic Distribution Library in JAX
fenbu-X (fenbu is 分布, "distribution", pronounced /fen'bu:/) is a simple probabilistic distribution library in JAX. fenbux provides:
- A simple, easy-to-use interface in the style of Distributions.jl
- Bijectors in the style of TensorFlow Probability and Bijectors.jl
- PyTree input/output
- Multiple dispatch over distribution types, powered by plum-dispatch
- Support for JAX features such as vmap, pmap, jit and autograd
See the documentation for details.
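To illustrate what dispatching one generic function on the distribution type means, here is a minimal sketch. It is a stand-in, not fenbux code: it uses the standard library's `functools.singledispatch` instead of plum-dispatch, and the `Normal` and `Uniform` dataclasses are hypothetical. Unlike plum, `singledispatch` only looks at the type of the first argument.

```python
# Stand-in sketch: fenbux uses plum-dispatch; this uses functools.singledispatch
# to show the idea of one generic function with per-distribution implementations.
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class Normal:  # hypothetical minimal distribution type, not fenbux's class
    mu: float
    sigma: float


@dataclass
class Uniform:  # hypothetical minimal distribution type
    low: float
    high: float


@singledispatch
def variance(dist):
    # Fallback when no implementation is registered for this type
    raise NotImplementedError(f"variance is not defined for {type(dist).__name__}")


@variance.register
def _(dist: Normal):
    return dist.sigma ** 2


@variance.register
def _(dist: Uniform):
    return (dist.high - dist.low) ** 2 / 12


print(variance(Normal(0.0, 2.0)))   # 4.0
print(variance(Uniform(0.0, 6.0)))  # 3.0
```

Adding a new distribution only requires registering new implementations; callers keep using the same `variance(dist)` call.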
Examples
Statistics of Distributions 🤔
```python
import jax.numpy as jnp
from fenbux import variance, skewness, mean
from fenbux.univariate import Normal

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}
dist = Normal(μ, σ)

mean(dist) # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist) # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist) # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
```
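PyTree parameters are handled leaf by leaf. The sketch below reproduces the variance and skewness outputs above with plain JAX and `jax.tree_util.tree_map`; the `normal_*` helpers are hypothetical and only assume that each leaf of `μ` is paired with the matching leaf of `σ`.

```python
import jax
import jax.numpy as jnp

# Same parameters as above, as matching PyTrees (dicts of arrays)
mu = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
sigma = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}


def normal_mean(mu, sigma):
    # Hypothetical helper: the Normal mean is the location parameter, leaf by leaf
    return jax.tree_util.tree_map(lambda m, s: m, mu, sigma)


def normal_variance(mu, sigma):
    # Hypothetical helper: the Normal variance is sigma**2, leaf by leaf
    return jax.tree_util.tree_map(lambda m, s: s ** 2, mu, sigma)


def normal_skewness(mu, sigma):
    # The Normal distribution is symmetric, so its skewness is zero everywhere
    return jax.tree_util.tree_map(lambda m, s: jnp.zeros_like(s), mu, sigma)


print(normal_variance(mu, sigma))
```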
Random Variable Generation
```python
import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal

key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
dist = Normal(x, y)

rand(dist, key, shape=(3, )) # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
```
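A minimal sketch of how leafwise sampling can work in plain JAX, assuming one independent subkey per leaf. `sample_normal_tree` is a hypothetical helper, and fenbux's internal key handling may differ, so the sampled values will not necessarily match the output above.

```python
import jax
import jax.random as jr


def sample_normal_tree(key, mu, sigma, shape):
    """Draw Normal samples for every leaf of matching (mu, sigma) PyTrees.

    Hypothetical helper: assumes mu and sigma have identical tree structure,
    and splits the key so that each leaf gets its own independent subkey.
    """
    leaves_mu, treedef = jax.tree_util.tree_flatten(mu)
    leaves_sigma = jax.tree_util.tree_leaves(sigma)
    keys = jr.split(key, len(leaves_mu))
    samples = [
        m + s * jr.normal(k, shape)  # location-scale transform of standard normals
        for k, m, s in zip(keys, leaves_mu, leaves_sigma)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)


x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
out = sample_normal_tree(jr.PRNGKey(0), x, y, (3,))
print(out)
```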
Evaluation of Distributions 👩‍🎓
CDF, PDF, and more...
```python
import jax.numpy as jnp
from fenbux import cdf, logpdf
from fenbux.univariate import Normal

μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])
dist = Normal(μ, σ)

cdf(dist, jnp.array([1., 2., 3.])) # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.])) # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
```
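The values above follow from the closed-form Normal density: the CDF equals 0.5 at the mean, and at the mean the log-density reduces to -log(σ) - 0.5·log(2π). The sketch checks this against `jax.scipy.stats.norm`, independently of fenbux.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

mu = jnp.array([1., 2., 3.])
sigma = jnp.array([4., 5., 6.])
x = jnp.array([1., 2., 3.])

# Closed-form Normal log-density:
# -log(sigma) - 0.5 * log(2 * pi) - 0.5 * ((x - mu) / sigma) ** 2
manual_logpdf = -jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * ((x - mu) / sigma) ** 2

print(norm.cdf(x, loc=mu, scale=sigma))  # 0.5 at the mean
print(manual_logpdf)
print(norm.logpdf(x, loc=mu, scale=sigma))
```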
Nested Transformations of Distributions 🤖
```python
import fenbux as fbx
from fenbux.univariate import Normal

d = Normal(0, 1)
# Compose transformations: truncate, then censor, then apply an affine transform
d_transformed = fbx.affine(fbx.censor(fbx.truncate(d, 0, 1), 0, 1), 0, 1)

fbx.logpdf(d, 0.5)              # Array(-1.0439385, dtype=float32)
fbx.logpdf(d_transformed, 0.5)  # evaluate the transformed distribution
```
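Truncation renormalises the density by the probability mass kept inside the bounds. The sketch below computes this for a standard Normal truncated to [0, 1] with `jax.scipy.stats.norm`; it assumes that `truncate(d, 0, 1)` means lower bound 0 and upper bound 1, and `truncated_normal_logpdf` is a hypothetical helper, not fenbux's implementation. Censoring, which moves the mass outside the bounds onto the bounds, is left out.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def truncated_normal_logpdf(x, lower, upper, loc=0., scale=1.):
    """Log-density of Normal(loc, scale) truncated to [lower, upper].

    Hypothetical helper. Inside the interval the density is divided by the
    retained mass; outside it the density is zero (log-density -inf).
    """
    log_mass = jnp.log(norm.cdf(upper, loc, scale) - norm.cdf(lower, loc, scale))
    inside = (x >= lower) & (x <= upper)
    return jnp.where(inside, norm.logpdf(x, loc, scale) - log_mass, -jnp.inf)


print(truncated_normal_logpdf(0.5, 0., 1.))  # about 0.0309
```

Compared with the untruncated value of about -1.0439 at 0.5, the truncated log-density is larger because only about 34% of the original mass lies in [0, 1].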
Compatible with JAX transformations 😃
- vmap
```python
import jax.numpy as jnp
from jax import vmap
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal({'a': jnp.zeros((2, 3))}, {'a': jnp.ones((2, 3, 5))})
x = jnp.zeros((2, 3, 5))

# Set use_batch=True so the Normal can serve as the in_axes spec for vmap:
# the mean is not mapped (None) and the scale is mapped along axis 2,
# so each mapped batch has shape (2, 3).
vmap(logpdf, in_axes=(Normal(None, {'a': 2}, use_batch=True), 2))(dist, x)
```
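For comparison, the same batching pattern with a plain function and `jax.vmap`: the mean is shared (`in_axes=None`) while the scale and the input are mapped along axis 2. This uses `jax.scipy.stats.norm.logpdf` rather than fenbux.

```python
import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import norm

mu = jnp.zeros((2, 3))
sigma = jnp.ones((2, 3, 5))
x = jnp.zeros((2, 3, 5))

# Map over axis 2 of sigma and x; mu (in_axes=None) is shared by every batch.
batched_logpdf = vmap(lambda m, s, v: norm.logpdf(v, m, s), in_axes=(None, 2, 2))
out = batched_logpdf(mu, sigma, x)
print(out.shape)  # (5, 2, 3): vmap stacks the mapped axis first by default
```

Passing `out_axes=2` to `vmap` would put the mapped axis back in the last position instead.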
- grad
```python
import jax.numpy as jnp
from jax import grad
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0., 1.)
grad(logpdf)(dist, 0.)  # gradient of logpdf with respect to the parameters of dist
```
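Because JAX differentiates with respect to any PyTree, gradients of a log-density come back with the same structure as the parameters. A plain-JAX sketch with a dict of parameters (the `normal_logpdf` helper is hypothetical):

```python
import jax
import jax.numpy as jnp


def normal_logpdf(params, x):
    # params is a PyTree (a dict here), so jax.grad returns a dict of the same shape
    loc, scale = params['loc'], params['scale']
    z = (x - loc) / scale
    return -jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * z ** 2


grads = jax.grad(normal_logpdf)({'loc': 0., 'scale': 1.}, 0.)
# Analytically: d/dloc = (x - loc) / scale**2 = 0
#               d/dscale = -1/scale + (x - loc)**2 / scale**3 = -1
print(grads)
```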
Bijectors 🧙‍♂️
Evaluate a bijector
```python
import jax.numpy as jnp
from fenbux.bijector import Exp, evaluate

bij = Exp()
x = jnp.array([1., 2., 3.])
evaluate(bij, x)
```
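A bijector bundles an invertible map with its inverse and the log-determinant of its Jacobian. The class below is a hypothetical minimal sketch of the exponential bijector, not fenbux's `Exp`; the method names follow a common convention and are an assumption.

```python
import jax.numpy as jnp


class ExpBijector:
    """Hypothetical minimal bijector y = exp(x); not fenbux's Exp class."""

    def forward(self, x):
        return jnp.exp(x)

    def inverse(self, y):
        return jnp.log(y)

    def forward_log_det_jacobian(self, x):
        # d/dx exp(x) = exp(x), so log|dy/dx| = x
        return x


bij = ExpBijector()
x = jnp.array([1., 2., 3.])
y = bij.forward(x)
print(y)               # exp([1, 2, 3])
print(bij.inverse(y))  # recovers x up to floating-point error
```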
Apply a bijector to a distribution
```python
import jax.numpy as jnp
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
from fenbux import logpdf

dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)

x = jnp.array([1., 2., 3.])
logpdf(log_normal, x)
```
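Transforming Normal(0, 1) with Exp gives a log-normal distribution, whose log-density follows from the change-of-variables formula log p_Y(y) = log p_X(log y) - log y. A plain-JAX sketch of that formula (`log_normal_logpdf` is a hypothetical helper), useful as a reference value when checking `logpdf(log_normal, x)`:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_normal_logpdf(y, loc=0., scale=1.):
    # Change of variables for Y = exp(X), X ~ Normal(loc, scale):
    # log p_Y(y) = log p_X(log y) + log|d(log y)/dy| = log p_X(log y) - log y
    x = jnp.log(y)
    return norm.logpdf(x, loc, scale) - x


y = jnp.array([1., 2., 3.])
print(log_normal_logpdf(y))
```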
Speed 🔦
- Common Evaluations
```python
import numpy as np
from scipy.stats import norm
from jax import jit
from fenbux import logpdf
from fenbux.univariate import Normal
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2

dist = Normal(0, 1)    # fenbux
dist2 = Normal2(0, 1)  # TensorFlow Probability (JAX substrate)
dist3 = norm(0, 1)     # SciPy
x = np.random.normal(size=100000)

%timeit jit(logpdf)(dist, x).block_until_ready()
# 51.2 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit jit(dist2.log_prob)(x).block_until_ready()
# 11.1 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit dist3.logpdf(x)
# 1.12 ms ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
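Outside IPython, a comparable measurement needs two precautions that the `%timeit` calls above hint at with `block_until_ready()`: exclude compilation by warming up the jitted function once, and block on the result because JAX dispatches work asynchronously. `time_jitted` is a hypothetical helper, and absolute numbers depend on hardware.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm


def time_jitted(fn, *args, repeats=100):
    """Average seconds per call of jax.jit(fn), excluding compilation time."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so the timer covers the computation
        jitted(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.asarray(np.random.normal(size=100000))
print(time_jitted(lambda v: norm.logpdf(v, 0., 1.), x))
```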
- Evaluations with Bijector-Transformed Distributions
```python
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit
from fenbux import logpdf
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal

x = jnp.asarray(np.random.uniform(size=100000))

# fenbux log-normal
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)

# TensorFlow Probability log-normal
dist2 = tfd.Normal(loc=0, scale=1)
bij2 = tfb.Exp()
log_normal2 = tfd.TransformedDistribution(dist2, bij2)

def log_prob(d, x):
    return d.log_prob(x)

%timeit jit(logpdf)(log_normal, x).block_until_ready()
# 131 µs ± 514 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit jit(log_prob)(log_normal2, x).block_until_ready()
# 375 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Installation
- Install from source on your local device:
```shell
git clone https://github.com/JiaYaobo/fenbux.git
cd fenbux
pip install -e .
```
- Install from PyPI:
```shell
pip install -U fenbux
```
Reference
Citation
@software{fenbux,
author = {Jia, Yaobo},
title = {fenbux: A Simple Probalistic Distribution Library in JAX},
url = {https://github.com/JiaYaobo/fenbux},
year = {2024}
}
Download files
Source Distribution
fenbux-0.1.0.tar.gz (48.1 kB)
Built Distribution
fenbux-0.1.0-py3-none-any.whl (76.5 kB)
File details
Details for the file fenbux-0.1.0.tar.gz.
File metadata
- Download URL: fenbux-0.1.0.tar.gz
- Upload date:
- Size: 48.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | ecb4f6c38d3957fa7109e85582b1ae1fb68292b5433b368b9be6b9fb213e532f
MD5 | eebba5d59507efd0a419702fe440e600
BLAKE2b-256 | ced0224d3d81f8c77c2b12b21c183c4772bd605a4741a0ba3d90417bdf1c9185
File details
Details for the file fenbux-0.1.0-py3-none-any.whl.
File metadata
- Download URL: fenbux-0.1.0-py3-none-any.whl
- Upload date:
- Size: 76.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0981043e3e0ea97c038efb4d1545e3f7cf92b0b3806bf427085b70fc0a6dba5b
MD5 | 45180d7eba11b617cef680943b8c44c9
BLAKE2b-256 | 21da2cff213a8135f781299e94211a93084a3108f8bfe7ecfea6b4d77ec8fce6