A Simple PyTree-Based Statistical Distribution Library in JAX
FenbuX
A Simple Probabilistic Distribution Library in JAX
fenbu (分布, Chinese for "distribution", pronounced /fen'bu:/)-X is a simple probabilistic distribution library in JAX, inspired by Distributions.jl. fenbux provides:
- A simple and easy-to-use interface like Distributions.jl
- PyTree input/output
- Multiple dispatch for different distributions based on plum-dispatch
- Support for JAX features (vmap, pmap, jit, autograd, etc.)
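The dispatch-based interface listed above (one free function such as `variance` that works for every distribution type) can be illustrated with a small sketch. fenbux itself uses plum-dispatch; the sketch below swaps in the standard library's `functools.singledispatch`, which dispatches on the type of the first argument only. `NormalSketch` and `UniformSketch` are hypothetical stand-ins, not fenbux classes.

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass(frozen=True)
class NormalSketch:
    # Hypothetical stand-in for a Normal distribution: location and scale.
    mu: float
    sigma: float


@dataclass(frozen=True)
class UniformSketch:
    # Hypothetical stand-in for a Uniform distribution on [low, high].
    low: float
    high: float


@singledispatch
def variance(dist):
    # Fallback used when no implementation is registered for type(dist).
    raise NotImplementedError(f"variance is not defined for {type(dist).__name__}")


@variance.register
def _(dist: NormalSketch):
    # Variance of a Normal is sigma squared.
    return dist.sigma ** 2


@variance.register
def _(dist: UniformSketch):
    # Variance of a continuous Uniform is (high - low)^2 / 12.
    return (dist.high - dist.low) ** 2 / 12


print(variance(NormalSketch(0.0, 2.0)), variance(UniformSketch(0.0, 6.0)))
```

plum-dispatch generalises this idea to the types of all arguments, so a call such as `cdf(dist, x)` can in principle be specialised on both `dist` and `x`.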
Examples
- Extract Attributes of Distributions 🤔
import jax.numpy as jnp
from fenbux import Normal, variance, skewness, mean
μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}
dist = Normal(μ, σ)
mean(dist) # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist) # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist) # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
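Because the parameters are pytrees, each attribute is computed leaf by leaf and the result keeps the input's structure. A minimal sketch of that idea with `jax.tree_util.tree_map`, assuming the Normal closed forms (mean μ, variance σ², skewness 0); the helper names are hypothetical and this is not fenbux's implementation.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

# Pytree parameters, matching the example above.
mu = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
sigma = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}


def normal_mean(mu):
    # Hypothetical helper: the mean of a Normal is its location, leaf by leaf.
    return tree_map(jnp.asarray, mu)


def normal_variance(sigma):
    # Hypothetical helper: the variance is sigma ** 2, applied to every leaf.
    return tree_map(lambda s: s ** 2, sigma)


def normal_skewness(sigma):
    # Hypothetical helper: a Normal is symmetric, so every leaf's skewness is 0.
    return tree_map(jnp.zeros_like, sigma)


print(normal_variance(sigma))
```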
- Random Variable Generation
import jax.random as jr
from fenbux import Normal, rand
key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
dist = Normal(x, y)
rand(dist, key, shape=(3, )) # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
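Sampling a pytree-parameterised Normal can be sketched with the reparameterisation x = μ + σ·z, z ~ N(0, 1), using an independent PRNG key for every leaf. This is an assumption about how such sampling could work, not a description of fenbux's `rand`; `sample_normal_tree` is a hypothetical helper, and the values it draws will not match the output shown above.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.tree_util import tree_flatten, tree_unflatten


def sample_normal_tree(key, mu, sigma, shape=()):
    """Hypothetical helper: draw mu + sigma * z, z ~ N(0, 1), for every leaf."""
    mu_leaves, treedef = tree_flatten(mu)
    sigma_leaves, sigma_treedef = tree_flatten(sigma)
    if treedef != sigma_treedef:
        raise ValueError("mu and sigma must have the same tree structure")
    # One independent key per leaf, so different leaves are not correlated.
    keys = jr.split(key, len(mu_leaves))
    samples = []
    for k, m, s in zip(keys, mu_leaves, sigma_leaves):
        m, s = jnp.asarray(m), jnp.asarray(s)
        # Sample shape first, then the (broadcast) parameter shape.
        z = jr.normal(k, tuple(shape) + jnp.shape(m + s))
        samples.append(m + s * z)
    return tree_unflatten(treedef, samples)


key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
print(sample_normal_tree(key, x, y, shape=(3,)))
```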
- Functions of Distributions 👩‍🎓
CDF, PDF, and more...
import jax.numpy as jnp
from fenbux import Normal, cdf, logpdf
μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])
dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.])) # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.])) # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
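The values in the comments follow from the closed-form Normal CDF, Φ((x − μ)/σ), and log-density, −log σ − ½ log 2π − (x − μ)²/(2σ²). At x = μ the CDF is exactly 0.5 and the log-density reduces to −log σ − ½ log 2π. A standard-library sketch for checking these numbers by hand; `normal_cdf` and `normal_logpdf` are hypothetical helpers, not fenbux functions.

```python
import math


def normal_cdf(x, mu, sigma):
    # Phi((x - mu) / sigma), written with the error function.
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def normal_logpdf(x, mu, sigma):
    # log N(x; mu, sigma^2) = -log(sigma) - 0.5*log(2*pi) - (x - mu)^2 / (2*sigma^2)
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z


for mu, sigma in [(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)]:
    # Evaluate at x = mu, as in the example above.
    print(normal_cdf(mu, mu, sigma), normal_logpdf(mu, mu, sigma))
```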
- Compatible with JAX transformations 😃
import jax.numpy as jnp
from jax import jit, vmap
from fenbux import Normal, logpdf
dist = Normal(0, jnp.ones((3, )))
# set use_batch=True on the Normal used as the in_axes spec to enable vmap
vmap(jit(logpdf), in_axes=(Normal(None, 0, use_batch=True), 0))(dist, jnp.zeros((3, )))
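The `in_axes` specification above asks `vmap` to broadcast μ (`None`) and map over σ (axis 0). The same pattern in plain JAX, with the parameters passed as separate arrays, is sketched below; it does not use fenbux, and `normal_logpdf` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def normal_logpdf(mu, sigma, x):
    # Scalar Normal log-density.
    z = (x - mu) / sigma
    return -jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * z ** 2


# mu is shared across the batch (in_axes=None); sigma and x are mapped over axis 0.
batched_logpdf = jax.vmap(jax.jit(normal_logpdf), in_axes=(None, 0, 0))

out = batched_logpdf(0.0, jnp.ones((3,)), jnp.zeros((3,)))
print(out)
```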
- Speed 🔦
import jax.numpy as jnp
from scipy.stats import norm
from jax import jit
from fenbux import Normal, logpdf
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2
dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = jnp.linspace(-5, 5, 100000)
# %timeit is an IPython magic: run these lines in IPython or Jupyter
%timeit jit(logpdf)(dist, x).block_until_ready()
%timeit jit(dist2.log_prob)(x).block_until_ready()
%timeit dist3.logpdf(x)
# Results, in the order above (fenbux, TensorFlow Probability, SciPy):
34.4 µs ± 678 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
9.64 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.17 ms ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
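When reproducing timings like these, it helps to compile once and run a warm-up call before measuring, so that tracing and XLA compilation stay outside the timed region, and to call `block_until_ready()` so JAX's asynchronous dispatch is awaited. A minimal harness using only JAX and the standard library; `time_jitted` and `normal_logpdf` are hypothetical helpers, and the numbers reported depend on hardware.

```python
import timeit

import jax
import jax.numpy as jnp


def normal_logpdf(x, mu=0.0, sigma=1.0):
    # Standard Normal log-density by default.
    z = (x - mu) / sigma
    return -jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * z ** 2


def time_jitted(fn, *args, number=100, repeat=5):
    # Hypothetical helper: returns the best per-call time in seconds.
    compiled = jax.jit(fn)
    # Warm-up call: triggers tracing and compilation outside the timed region.
    compiled(*args).block_until_ready()
    # block_until_ready() waits for each asynchronous computation to finish.
    runs = timeit.repeat(lambda: compiled(*args).block_until_ready(),
                         number=number, repeat=repeat)
    return min(runs) / number


x = jnp.linspace(-5, 5, 100000)
per_call = time_jitted(normal_logpdf, x)
print(f"{per_call * 1e6:.1f} µs per call")
```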
Installation
git clone https://github.com/JiaYaobo/fenbux.git
cd fenbux
pip install -e .
Reference
Download files
Source Distribution
fenbux-0.0.1.tar.gz (39.6 kB)
Built Distribution
fenbux-0.0.1-py3-none-any.whl (56.8 kB)
File details
Details for the file fenbux-0.0.1.tar.gz
File metadata
- Download URL: fenbux-0.0.1.tar.gz
- Size: 39.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 21f1a77e2706e6eeca70c72a0dc077d951b68e1c6b4dde61c2cb02728f24b157
MD5 | 9559bd714f0f804acd0ca9fd1a022c87
BLAKE2b-256 | 5a961008bdbd5743cdc2d78337cc1f0b8b07401e8b31e333898d52f1c98f4f85
Details for the file fenbux-0.0.1-py3-none-any.whl
File metadata
- Download URL: fenbux-0.0.1-py3-none-any.whl
- Size: 56.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4647dc084b77bd17a8d9582a7d6992d02f461953d264d113db284172987930d1
MD5 | 8bc95d6b2965de5548a9c40828a6dec5
BLAKE2b-256 | df8752ed1e5b7e45384c1714ffbde9ccfe146a57fffefa0dee0439368eb212dd