A Simple PyTree-Based Statistical Distribution Library in JAX

FenbuX

A Simple Probability Distribution Library in JAX

fenbu (分布, Chinese for "distribution"; pronounced /fen'bu:/)-X is a simple probability distribution library in JAX, inspired by Distributions.jl. fenbux provides:

  • A simple, easy-to-use interface modeled on Distributions.jl
  • PyTree inputs and outputs
  • Multiple dispatch across distributions, built on plum-dispatch
  • Compatibility with JAX transformations (jit, vmap, pmap, grad, etc.)
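In the Distributions.jl style, statistics are free functions such as mean(dist) or variance(dist), and the implementation is selected by the distribution's type. fenbux uses plum-dispatch for this; the sketch below swaps in the standard library's functools.singledispatch (which dispatches on the first argument only) purely to illustrate the idea. NormalSketch and UniformSketch are hypothetical classes, not fenbux types.

```python
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical stand-in distribution types (not fenbux's own classes).
@dataclass
class NormalSketch:
    mu: float
    sigma: float  # standard deviation


@dataclass
class UniformSketch:
    a: float
    b: float


@singledispatch
def variance(dist):
    # Fallback used when no implementation is registered for type(dist).
    raise NotImplementedError(f"variance not defined for {type(dist).__name__}")


@variance.register
def _(dist: NormalSketch):
    # Var[Normal(mu, sigma)] = sigma**2
    return dist.sigma ** 2


@variance.register
def _(dist: UniformSketch):
    # Var[Uniform(a, b)] = (b - a)**2 / 12
    return (dist.b - dist.a) ** 2 / 12


print(variance(NormalSketch(0.0, 2.0)))   # 4.0
print(variance(UniformSketch(0.0, 12.0)))  # 12.0
```

Adding a new distribution means registering one more implementation; callers keep writing variance(dist).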

Examples

  • Extract Attributes of Distributions 🤔
import jax.numpy as jnp
from fenbux import Normal, variance, skewness, mean

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])} 
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}

dist = Normal(μ, σ)
mean(dist) # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist) # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist) # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
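PyTree support means each leaf of the parameter trees is processed independently and the result keeps the input's structure. The pure-Python sketch below reproduces the outputs above for that leafwise behaviour, assuming (as the variance output shows) that the second Normal argument is the standard deviation. tree_map here is a hypothetical helper (JAX's real counterpart is jax.tree_util.tree_map), and plain lists stand in for arrays.

```python
def tree_map(fn, *trees):
    """Apply fn leafwise across nested dicts that share the same keys."""
    first = trees[0]
    if isinstance(first, dict):
        # Recurse into every key, keeping the dict structure intact.
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    # Anything that is not a dict is treated as a leaf.
    return fn(*trees)


mu = {'a': [1., 2., 3.], 'b': [4., 5., 6.]}
sigma = {'a': [4., 5., 6.], 'b': [7., 8., 9.]}

# Normal moments, computed leaf by leaf: mean = mu, variance = sigma**2, skewness = 0.
mean_tree = tree_map(lambda m, s: list(m), mu, sigma)
var_tree = tree_map(lambda m, s: [x ** 2 for x in s], mu, sigma)
skew_tree = tree_map(lambda m, s: [0.0] * len(m), mu, sigma)

print(var_tree)  # {'a': [16.0, 25.0, 36.0], 'b': [49.0, 64.0, 81.0]}
```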
  • Random Variable Generation
import jax.random as jr
from fenbux import Normal, rand

key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}

dist = Normal(x, y)
rand(dist, key, shape=(3, )) # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
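A normal draw can be produced from a standard normal z via the location-scale step x = μ + σ·z. The stdlib sketch below applies that step to every leaf of the same nested dict. It uses Python's random module instead of JAX's key-based PRNG, so it does not reproduce the numbers above, and sample_normal and sample_tree are hypothetical helpers rather than fenbux's sampler.

```python
import random


def sample_normal(mu, sigma, n, rng):
    """Draw n samples from Normal(mu, sigma) via x = mu + sigma * z."""
    # z ~ Normal(0, 1); shifting by mu and scaling by sigma gives Normal(mu, sigma).
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]


def sample_tree(mu_tree, sigma_tree, n, rng):
    """Sample every leaf of matching nested dicts, preserving the structure."""
    if isinstance(mu_tree, dict):
        return {k: sample_tree(mu_tree[k], sigma_tree[k], n, rng) for k in mu_tree}
    return sample_normal(mu_tree, sigma_tree, n, rng)


x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}

out = sample_tree(x, y, 3, random.Random(0))  # seeded for reproducibility
print(out)  # {'a': {'c': {'d': {'e': [<3 floats>]}}}}
```

In JAX, independent draws would typically use separate keys obtained with jax.random.split; this sketch simply shares one seeded generator.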
  • Functions of Distributions 👩‍🎓

CDF, PDF, and more...

import jax.numpy as jnp
from fenbux import Normal, cdf, logpdf

μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])

dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.])) # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.])) # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
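The printed values can be checked against the closed forms of the normal distribution, with σ the standard deviation: log p(x) = -log σ - ½ log(2π) - (x - μ)² / (2σ²) and F(x) = ½[1 + erf((x - μ) / (σ√2))]. At x = μ the quadratic term vanishes and F(μ) = 0.5, which is why every cdf entry is 0.5. The stdlib sketch below evaluates both formulas; normal_logpdf and normal_cdf are illustrative helpers, not fenbux functions.

```python
import math


def normal_logpdf(x, mu, sigma):
    # log N(x | mu, sigma) = -log(sigma) - 0.5*log(2*pi) - (x - mu)^2 / (2*sigma^2)
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def normal_cdf(x, mu, sigma):
    # Standard normal CDF of the z-score, written with the error function.
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


mus, sigmas, xs = [1., 2., 3.], [4., 5., 6.], [1., 2., 3.]
print([normal_cdf(x, m, s) for x, m, s in zip(xs, mus, sigmas)])
# [0.5, 0.5, 0.5]
print([round(normal_logpdf(x, m, s), 6) for x, m, s in zip(xs, mus, sigmas)])
# [-2.305233, -2.528376, -2.710698]
```

These agree with the float32 outputs above to within float32 precision.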
  • Compatible with JAX transformations 😃
import jax.numpy as jnp
from jax import jit, vmap
from fenbux import Normal, logpdf

dist = Normal(0, jnp.ones((3, )))
# the in_axes spec mirrors dist: μ is not mapped (None), σ is mapped over axis 0; set use_batch=True on it to use vmap
vmap(jit(logpdf), in_axes=(Normal(None, 0, use_batch=True), 0))(dist, jnp.zeros((3, )))
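JAX treats in_axes as a PyTree prefix of the call's arguments, so Normal(None, 0, use_batch=True) lines up with dist: None keeps μ shared across the batch, 0 maps over the leading axis of σ, and the trailing 0 maps over x. (What use_batch does internally is not described here.) Semantically, the vmap call computes the same values as the explicit loop below, which uses a hypothetical stdlib normal_logpdf helper in place of fenbux's logpdf.

```python
import math


def normal_logpdf(x, mu, sigma):
    # Normal log-density with sigma as the standard deviation.
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


mu = 0.0                  # in_axes None: shared by every batch element
sigmas = [1.0, 1.0, 1.0]  # in_axes 0: one scale per batch element
xs = [0.0, 0.0, 0.0]      # in_axes 0: one input per batch element

# vmap semantics written as a loop: element i uses sigmas[i] and xs[i].
batched = [normal_logpdf(x, mu, s) for x, s in zip(xs, sigmas)]
print(batched)  # each entry equals -0.5*log(2*pi), about -0.9189385
```

vmap produces this result without a Python loop, by vectorizing the function over the mapped axes.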
  • Speed 🔦
import jax.numpy as jnp
from scipy.stats import norm
from jax import jit
from fenbux import Normal, logpdf
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2

dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = jnp.linspace(-5, 5, 100000)

%timeit jit(logpdf)(dist, x).block_until_ready()
%timeit jit(dist2.log_prob)(x).block_until_ready()
%timeit dist3.logpdf(x)
# fenbux (jit): 34.4 µs ± 678 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# TensorFlow Probability, JAX substrate (jit): 9.64 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# SciPy: 1.17 ms ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
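%timeit is an IPython/Jupyter magic, so the benchmark above only runs in a notebook or IPython shell. In a plain script, a comparable "mean ± std. dev. of 7 runs" summary can be produced with the standard-library timeit.repeat. In this sketch bench is a hypothetical helper with a fixed loop count (%timeit picks its own), and the timed lambda is a placeholder workload, not the fenbux benchmark.

```python
import statistics
import timeit


def bench(fn, repeat=7, number=100):
    """Return (mean, std) of seconds per call over `repeat` runs of `number` calls."""
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_loop = [t / number for t in totals]  # seconds per single call, per run
    return statistics.mean(per_loop), statistics.stdev(per_loop)


# For the JAX case: build and warm up the compiled function first, e.g.
#   f = jit(logpdf); f(dist, x).block_until_ready()
# then time lambda: f(dist, x).block_until_ready(), so compilation is excluded
# and block_until_ready() waits for the asynchronous computation to finish.
mean_s, std_s = bench(lambda: sum(range(1000)))
print(f"{mean_s * 1e6:.1f} µs ± {std_s * 1e6:.1f} µs per loop")
```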

Installation

git clone https://github.com/JiaYaobo/fenbux.git
cd fenbux
pip install -e .

Download files

Source Distribution

fenbux-0.0.1.tar.gz (39.6 kB)

Built Distribution

fenbux-0.0.1-py3-none-any.whl (56.8 kB)

File details

Details for the file fenbux-0.0.1.tar.gz.

File metadata

  • Download URL: fenbux-0.0.1.tar.gz
  • Upload date:
  • Size: 39.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for fenbux-0.0.1.tar.gz
Algorithm Hash digest
SHA256 21f1a77e2706e6eeca70c72a0dc077d951b68e1c6b4dde61c2cb02728f24b157
MD5 9559bd714f0f804acd0ca9fd1a022c87
BLAKE2b-256 5a961008bdbd5743cdc2d78337cc1f0b8b07401e8b31e333898d52f1c98f4f85
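The published digests let you confirm that a downloaded archive is intact before installing it. A minimal stdlib sketch, assuming the file sits in the current directory; sha256_of and verify are hypothetical helpers. (pip can enforce the same check itself in hash-checking mode, using --hash=sha256:<digest> entries in a requirements file.)

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    # hexdigest() is lowercase, so normalize the expected value before comparing.
    return sha256_of(path) == expected_hex.strip().lower()


# SHA256 digest listed for the source distribution.
EXPECTED = "21f1a77e2706e6eeca70c72a0dc077d951b68e1c6b4dde61c2cb02728f24b157"
# verify("fenbux-0.0.1.tar.gz", EXPECTED) returns True only for an unmodified download.
```

The same helper works for the wheel, using its SHA256 digest instead.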

File details

Details for the file fenbux-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: fenbux-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 56.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for fenbux-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4647dc084b77bd17a8d9582a7d6992d02f461953d264d113db284172987930d1
MD5 8bc95d6b2965de5548a9c40828a6dec5
BLAKE2b-256 df8752ed1e5b7e45384c1714ffbde9ccfe146a57fffefa0dee0439368eb212dd
