A Simple Statistical Distribution Library In JAX
FenbuX
A Simple Probabilistic Distribution Library in JAX
fenbu (分布, Chinese for "distribution", pronounced /fen'bu:/)-X is a simple probabilistic distribution library in JAX. fenbux provides:
- A simple, easy-to-use interface similar to Distributions.jl
- Bijectors similar to those in TensorFlow Probability and Bijectors.jl
- PyTree input/output
- Multiple dispatch over distribution types, based on plum-dispatch
- Compatibility with JAX transformations (vmap, pmap, jit, autograd, etc.)
See the documentation.
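A Distributions.jl-style interface means statistics are free functions, such as `mean(dist)`, selected by the type of the distribution. fenbux does this with plum-dispatch; the sketch below swaps in the standard library's `functools.singledispatch`, which dispatches on the first argument only, purely to illustrate the pattern. The `Normal` and `Exponential` classes here are hypothetical stand-ins, not fenbux types.

```python
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical stand-in classes for illustration only (not fenbux.univariate types).
@dataclass
class Normal:
    mu: float
    sigma: float


@dataclass
class Exponential:
    rate: float


@singledispatch
def mean(dist):
    # Fallback when no implementation is registered for this distribution type.
    raise NotImplementedError(f"mean is not defined for {type(dist).__name__}")


@mean.register
def _(dist: Normal):
    return dist.mu


@mean.register
def _(dist: Exponential):
    return 1.0 / dist.rate


@singledispatch
def variance(dist):
    raise NotImplementedError(f"variance is not defined for {type(dist).__name__}")


@variance.register
def _(dist: Normal):
    return dist.sigma ** 2


@variance.register
def _(dist: Exponential):
    return 1.0 / dist.rate ** 2
```

Adding a distribution then means registering new implementations rather than editing existing functions; plum-dispatch extends the same idea to the types of several arguments.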
Examples
Statistics of Distributions 🤔
import jax.numpy as jnp
from fenbux import variance, skewness, mean
from fenbux.univariate import Normal
μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}
dist = Normal(μ, σ)
mean(dist) # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist) # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist) # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
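Each statistic above is applied leaf by leaf to the parameter pytrees. A minimal sketch of that idea with `jax.tree_util.tree_map` and the closed-form Normal moments (mean μ, variance σ², skewness 0); it assumes, rather than reproduces, how fenbux organizes the computation internally.

```python
import jax
import jax.numpy as jnp

mu = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
sigma = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}

# Closed-form Normal moments, applied to matching leaves of the two pytrees.
normal_mean = jax.tree_util.tree_map(lambda m, s: m, mu, sigma)
normal_variance = jax.tree_util.tree_map(lambda m, s: s ** 2, mu, sigma)
normal_skewness = jax.tree_util.tree_map(lambda m, s: jnp.zeros_like(m), mu, sigma)
```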
Random Variable Generation
import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal
key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
dist = Normal(x, y)
rand(dist, key, shape=(3, )) # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
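Sampling a Normal with pytree parameters can be sketched as the location-scale transform μ + σ·z of standard normal draws, with one PRNG subkey per leaf. The helper `sample_normal_tree` is hypothetical, and since fenbux's key-splitting scheme is not described here, its draws are not expected to match the numbers above.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def sample_normal_tree(mu, sigma, key, shape):
    """Draw mu + sigma * z leaf by leaf; mu and sigma must share one pytree structure."""
    mu_leaves, treedef = jax.tree_util.tree_flatten(mu)
    sigma_leaves = jax.tree_util.tree_leaves(sigma)
    # One independent subkey per leaf so that leaves never reuse randomness.
    keys = jr.split(key, len(mu_leaves))
    draws = [
        jnp.asarray(m) + jnp.asarray(s) * jr.normal(k, shape)
        for m, s, k in zip(mu_leaves, sigma_leaves, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, draws)


x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
samples = sample_normal_tree(x, y, jr.PRNGKey(0), shape=(3,))
```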
Evaluations of Distributions 👩‍🎓
CDF, PDF, and more...
import jax.numpy as jnp
from fenbux import cdf, logpdf
from fenbux.univariate import Normal
μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])
dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.])) # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.])) # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
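The values above follow from the closed-form Normal log-density and CDF. A short check in plain JAX, comparing hand-written formulas with `jax.scipy.stats.norm`, which serves here as an independent reference rather than as fenbux's implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import erf
from jax.scipy.stats import norm

mu = jnp.array([1., 2., 3.])
sigma = jnp.array([4., 5., 6.])
x = jnp.array([1., 2., 3.])

# log N(x | mu, sigma) = -log(sigma) - 0.5 * log(2 * pi) - (x - mu)^2 / (2 * sigma^2)
manual_logpdf = -jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi) - (x - mu) ** 2 / (2 * sigma ** 2)

# Phi((x - mu) / sigma) = 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))
manual_cdf = 0.5 * (1 + erf((x - mu) / (sigma * jnp.sqrt(2.0))))
```

Because each x equals its μ, the CDF is 0.5 and the log-density reduces to -log σ - ½ log 2π; for the first entry that is -log 4 - 0.9189 ≈ -2.3052.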
Nested Transformations of Distributions 🤖
import fenbux as fbx
import jax.numpy as jnp
from fenbux.univariate import Normal
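# truncate, then censor, then apply an affine transformation
d = Normal(0, 1)
d2 = fbx.affine(fbx.censor(fbx.truncate(d, 0, 1), 0, 1), 0, 1)
fbx.logpdf(d, 0.5)  # Array(-1.0439385, dtype=float32), log-density of the untransformed Normal(0, 1)
fbx.logpdf(d2, 0.5)  # log-density of the transformed distribution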
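Truncation renormalizes the base density over the retained interval: for a Normal truncated to [a, b], log p(x) = log φ(x) - log(Φ(b) - Φ(a)) inside [a, b] and -∞ outside. A minimal sketch of truncation alone in plain JAX; `truncated_normal_logpdf` is a hypothetical helper, and the meaning of the arguments to fenbux's `censor` and `affine` is not modelled here.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def truncated_normal_logpdf(x, mu, sigma, lower, upper):
    # Probability mass of the base Normal inside [lower, upper].
    log_mass = jnp.log(norm.cdf(upper, mu, sigma) - norm.cdf(lower, mu, sigma))
    inside = (x >= lower) & (x <= upper)
    # Renormalize inside the interval; zero density (log = -inf) outside it.
    return jnp.where(inside, norm.logpdf(x, mu, sigma) - log_mass, -jnp.inf)


base = norm.logpdf(0.5, 0., 1.)  # same quantity as fbx.logpdf(d, 0.5) for d = Normal(0, 1)
truncated = truncated_normal_logpdf(0.5, 0., 1., 0., 1.)
```

For N(0, 1) truncated to [0, 1], Φ(1) - Φ(0) ≈ 0.3413, so the log-density at 0.5 rises from about -1.0439 to about 0.0309.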
Compatible with JAX transformations 😃
- vmap
import jax.numpy as jnp
from jax import vmap
from fenbux import logpdf
from fenbux.univariate import Normal
dist = Normal({'a': jnp.zeros((2, 3))}, {'a':jnp.ones((2, 3, 5))}) # each batch shape is (2, 3)
x = jnp.zeros((2, 3, 5))
# pass use_batch=True when using a Normal as the in_axes spec for vmap
vmap(logpdf, in_axes=(Normal(None, {'a': 2}, use_batch=True), 2))(dist, x)
- grad
import jax.numpy as jnp
from jax import jit, grad
from fenbux import logpdf
from fenbux.univariate import Normal
dist = Normal(0., 1.)
grad(logpdf)(dist, 0.)
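Both examples above rely on standard JAX behaviour: `vmap` accepts a pytree prefix in `in_axes`, and `grad` differentiates with respect to every leaf of a pytree argument. A minimal sketch of both in plain JAX, using `jax.scipy.stats.norm.logpdf` in place of fenbux's `logpdf` and hypothetical helpers (`logpdf_tree`, `scalar_logpdf`); it mirrors the shapes of the vmap example but is not fenbux's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def logpdf_tree(mu, sigma, x):
    # Evaluate the Normal log-density leaf by leaf over matching parameter pytrees.
    return jax.tree_util.tree_map(lambda m, s: norm.logpdf(x, m, s), mu, sigma)


mu = {'a': jnp.zeros((2, 3))}
sigma = {'a': jnp.ones((2, 3, 5))}
x = jnp.zeros((2, 3, 5))

# in_axes is a pytree prefix: mu is shared (None), sigma['a'] and x are mapped over axis 2.
batched = jax.vmap(logpdf_tree, in_axes=(None, {'a': 2}, 2))(mu, sigma, x)
# vmap stacks the mapped axis first by default, so batched['a'] has shape (5, 2, 3).


def scalar_logpdf(params, x):
    return norm.logpdf(x, params['mu'], params['sigma'])


# grad returns a pytree of gradients with the same structure as params.
grads = jax.grad(scalar_logpdf)({'mu': 0., 'sigma': 1.}, 0.)
```

At x = μ the analytic gradients are (x - μ) / σ² = 0 for μ and -1/σ + (x - μ)² / σ³ = -1 for σ, which is what `grads` should contain.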
Bijectors 🧙‍♂️
Evaluate a bijector
import jax.numpy as jnp
from fenbux.bijector import Exp, evaluate
bij = Exp()
x = jnp.array([1., 2., 3.])
evaluate(bij, x)
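An `Exp` bijector maps x to exp(x); its inverse is log(y), and because d/dx exp(x) = exp(x), the log-determinant of the Jacobian of the forward map is simply x. A minimal sketch with hypothetical helper functions (not the `fenbux.bijector` API), checking the log-determinant against autodiff. It assumes `evaluate` applies the forward map, which the example above does not state explicitly.

```python
import jax
import jax.numpy as jnp


# Hypothetical helpers for an elementwise exp bijector (not fenbux.bijector.Exp).
def exp_forward(x):
    return jnp.exp(x)


def exp_inverse(y):
    return jnp.log(y)


def exp_forward_log_det_jacobian(x):
    # d/dx exp(x) = exp(x), so log|det J| = x for each element.
    return x


x = jnp.array([1., 2., 3.])
y = exp_forward(x)

# Cross-check the analytic log-determinant with autodiff, element by element.
ldj_autodiff = jnp.log(jnp.abs(jax.vmap(jax.grad(exp_forward))(x)))
```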
Apply a bijector to a distribution
import jax.numpy as jnp
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
from fenbux import logpdf
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)
x = jnp.array([1., 2., 3.])
logpdf(log_normal, x)
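Transforming Normal(0, 1) by `Exp` yields the standard log-normal. By the change-of-variables formula, log p_Y(y) = log p_X(log y) - log y for y > 0. A minimal sketch comparing that formula with the closed-form log-normal density, using `jax.scipy.stats.norm` as the base density; the helper names are hypothetical and this is not fenbux's `transform` implementation.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def lognormal_logpdf_via_bijector(y):
    # x = log(y) is the inverse of Exp; log|d log(y) / dy| = -log(y).
    return norm.logpdf(jnp.log(y), 0., 1.) - jnp.log(y)


def lognormal_logpdf_closed_form(y):
    # Standard log-normal density: exp(-(log y)^2 / 2) / (y * sqrt(2 * pi)).
    return -(jnp.log(y) ** 2) / 2 - jnp.log(y * jnp.sqrt(2 * jnp.pi))


y = jnp.array([1., 2., 3.])
```

At y = 1 both expressions reduce to -½ log 2π ≈ -0.9189.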
Speed 🔦
- Common Evaluations
import numpy as np
from scipy.stats import norm
from jax import jit
from fenbux import logpdf, rand
from fenbux.univariate import Normal
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2
dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = np.random.normal(size=100000)
%timeit jit(logpdf)(dist, x).block_until_ready()
%timeit jit(dist2.log_prob)(x).block_until_ready()
%timeit dist3.logpdf(x)
fenbux: 51.2 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
TensorFlow Probability: 11.1 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
SciPy: 1.12 ms ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
- Evaluations with Bijector-Transformed Distributions
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit
from fenbux import logpdf
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
x = jnp.asarray(np.random.uniform(size=100000))
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)
dist2 = tfd.Normal(loc=0, scale=1)
bij2 = tfb.Exp()
log_normal2 = tfd.TransformedDistribution(dist2, bij2)
def log_prob(d, x):
    return d.log_prob(x)
%timeit jit(logpdf)(log_normal, x).block_until_ready()
%timeit jit(log_prob)(log_normal2, x).block_until_ready()
fenbux: 131 µs ± 514 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
TensorFlow Probability: 375 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
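A minimal sketch of timing a jitted JAX function so that compilation and host-to-device transfer stay outside the measured loop: the jitted function is built once, warmed up with one call, and the input is converted to a device array beforehand. It uses `jax.scipy.stats.norm.logpdf` as a stand-in workload rather than fenbux, and the names `logpdf_jit` and `n_runs` are illustrative.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm

# Convert once so the host-to-device copy is not part of each timed call.
x = jnp.asarray(np.random.normal(size=100_000), dtype=jnp.float32)

# Build the jitted function once, outside the timed loop.
logpdf_jit = jax.jit(lambda v: norm.logpdf(v, 0., 1.))
logpdf_jit(x).block_until_ready()  # warm-up call triggers compilation

n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    # block_until_ready waits for JAX's asynchronous dispatch to finish.
    out = logpdf_jit(x).block_until_ready()
elapsed_per_call = (time.perf_counter() - start) / n_runs
print(f"{elapsed_per_call * 1e6:.1f} µs per call")
```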
Installation
- Install from source on your local machine:
git clone https://github.com/JiaYaobo/fenbux.git
cd fenbux
pip install -e .
- Install from PyPI.
pip install -U fenbux
Reference
Citation
@software{fenbux,
author = {Jia, Yaobo},
title = {fenbux: A Simple Probabilistic Distribution Library in JAX},
url = {https://github.com/JiaYaobo/fenbux},
year = {2024}
}
File details
Details for the file fenbux-0.1.0.tar.gz.
File metadata
- Download URL: fenbux-0.1.0.tar.gz
- Upload date:
- Size: 48.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ecb4f6c38d3957fa7109e85582b1ae1fb68292b5433b368b9be6b9fb213e532f |
| MD5 | eebba5d59507efd0a419702fe440e600 |
| BLAKE2b-256 | ced0224d3d81f8c77c2b12b21c183c4772bd605a4741a0ba3d90417bdf1c9185 |
File details
Details for the file fenbux-0.1.0-py3-none-any.whl.
File metadata
- Download URL: fenbux-0.1.0-py3-none-any.whl
- Upload date:
- Size: 76.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0981043e3e0ea97c038efb4d1545e3f7cf92b0b3806bf427085b70fc0a6dba5b |
| MD5 | 45180d7eba11b617cef680943b8c44c9 |
| BLAKE2b-256 | 21da2cff213a8135f781299e94211a93084a3108f8bfe7ecfea6b4d77ec8fce6 |