
A Simple Statistical Distribution Library In JAX

Project description

FenbuX

A Simple Probabilistic Distribution Library in JAX

fenbu (分布, pronounced /fen'bu:/)-X is a simple probabilistic distribution library in JAX. In fenbux, we provide:

  • A simple and easy-to-use interface like Distributions.jl
  • Bijectors like TensorFlow Probability and Bijectors.jl
  • PyTree input/output
  • Multiple dispatch for different distributions based on plum-dispatch
  • All JAX features (vmap, pmap, jit, autograd, etc.); a jit sketch follows this list
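
For example, the evaluation functions compose directly with jax.jit. A minimal sketch (it mirrors the jit(logpdf) pattern used in the Speed section below):

import jax.numpy as jnp
from jax import jit
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0., 1.)
x = jnp.linspace(-3., 3., 100)

# logpdf is a pure function of (distribution PyTree, values), so it jits as-is
jit(logpdf)(dist, x)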

See the documentation.

Examples

Statistics of Distributions 🤔

import jax.numpy as jnp
from fenbux import variance, skewness, mean
from fenbux.univariate import Normal

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])} 
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}

dist = Normal(μ, σ)
mean(dist) # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist) # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist) # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
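
Because the results are PyTrees, the standard JAX tree utilities apply directly. A small sketch reusing the parameters above: derive the per-leaf standard deviation from the variance with jax.tree_util.tree_map.

import jax
import jax.numpy as jnp
from fenbux import variance
from fenbux.univariate import Normal

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}
dist = Normal(μ, σ)

# take the square root of every leaf of the variance PyTree
jax.tree_util.tree_map(jnp.sqrt, variance(dist))
# {'a': Array([4., 5., 6.], dtype=float32), 'b': Array([7., 8., 9.], dtype=float32)}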

Random Variable Generation

import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal


key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}

dist = Normal(x, y)
rand(dist, key, shape=(3, )) # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
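
As a quick sanity check (a sketch, not in the original README; it assumes rand also accepts plain scalar parameters, as Normal(0., 1.) does in the gradient example below), the sample moments of a large draw should be close to the analytical mean and variance:

import jax.numpy as jnp
import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal

key = jr.PRNGKey(0)
dist = Normal(0., 1.)

samples = rand(dist, key, shape=(100_000,))
jnp.mean(samples), jnp.var(samples)  # both should be close to 0. and 1.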

Evaluations of Distributions 👩‍🎓

CDF, PDF, and more...

import jax.numpy as jnp
from fenbux import cdf, logpdf
from fenbux.univariate import Normal


μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])

dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.])) # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.])) # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
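
The values above can be cross-checked against the closed-form normal log-density in jax.scipy (a sketch that uses only jax.scipy.stats, not fenbux):

import jax.numpy as jnp
from jax.scipy.stats import norm

μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])

norm.logpdf(jnp.array([1., 2., 3.]), loc=μ, scale=σ)
# Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32), matching logpdf(dist, ...) above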

Nested Transformations of Distributions 🤖

import fenbux as fbx
import jax.numpy as jnp
from fenbux.univariate import Normal

# truncate, censor, and affine-transform a distribution
d = Normal(0, 1)
td = fbx.affine(fbx.censor(fbx.truncate(d, 0, 1), 0, 1), 0, 1)

# evaluation functions work on the base distribution and on td in the same way
fbx.logpdf(d, 0.5)
# Array(-1.0439385, dtype=float32)
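
Assuming truncate renormalizes the density to the interval (the standard convention; this is not stated explicitly here), the truncated log-density at 0.5 has a closed form that can be written with jax.scipy:

import jax.numpy as jnp
from jax.scipy.stats import norm

# log p_truncated(0.5) = log N(0.5; 0, 1) - log(Φ(1) - Φ(0)) ≈ 0.031
norm.logpdf(0.5) - jnp.log(norm.cdf(1.0) - norm.cdf(0.0))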

Compatible with JAX transformations 😃

  • vmap
import jax.numpy as jnp
from jax import vmap

from fenbux import logpdf
from fenbux.univariate import Normal


dist = Normal({'a': jnp.zeros((2, 3))}, {'a': jnp.ones((2, 3, 5))})  # each batch has shape (2, 3)
x = jnp.zeros((2, 3, 5))
# set use_batch=True when the distribution is used as an in_axes specification for vmap
vmap(logpdf, in_axes=(Normal(None, {'a': 2}, use_batch=True), 2))(dist, x)
  • grad
import jax.numpy as jnp
from jax import grad
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0., 1.)
grad(logpdf)(dist, 0.)
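
Since logpdf is an ordinary JAX function, it can also be differentiated with respect to the evaluation point (a sketch, not in the original README). For a standard normal, d/dx log p(x) = -x:

import jax.numpy as jnp
from jax import grad
from fenbux import logpdf
from fenbux.univariate import Normal

dist = Normal(0., 1.)

# expected value: -0.5, since d/dx log N(x; 0, 1) = -x
grad(logpdf, argnums=1)(dist, jnp.asarray(0.5))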

Bijectors 🧙‍♂️

Evaluate a bijector

import jax.numpy as jnp
from fenbux.bijector import Exp, evaluate

bij = Exp()
x = jnp.array([1., 2., 3.])

evaluate(bij, x)
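
For the Exp bijector, the forward evaluation is just the elementwise exponential, so the result agrees with jnp.exp (a one-line check continuing the snippet above):

jnp.exp(x)  # ≈ [2.7182817, 7.389056, 20.085537], the same as evaluate(bij, x)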

Apply a bijector to a distribution

import jax.numpy as jnp
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
from fenbux import logpdf

dist = Normal(0, 1)
bij = Exp()

log_normal = transform(dist, bij)

x = jnp.array([1., 2., 3.])
logpdf(log_normal, x)
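
As a sanity check (a sketch using only jax.scipy), the Exp-transformed normal is the log-normal distribution, so its log-density should follow the change-of-variables formula log p(y) = log N(log y; 0, 1) - log y:

import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.array([1., 2., 3.])
norm.logpdf(jnp.log(x), loc=0., scale=1.) - jnp.log(x)  # should match logpdf(log_normal, x) above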

Speed 🔦

  • Common Evaluations
import numpy as np
from scipy.stats import norm
from jax import jit
from fenbux import logpdf, rand
from fenbux.univariate import Normal
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2

dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = np.random.normal(size=100000)

%timeit jit(logpdf)(dist, x).block_until_ready()
%timeit jit(dist2.log_prob)(x).block_until_ready()
%timeit dist3.logpdf(x)
# fenbux:                  51.2 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# tensorflow_probability:  11.1 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# scipy:                   1.12 ms ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
  • Evaluations with Bijector Transformed Distributions
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit

from fenbux import logpdf
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal


x = jnp.asarray(np.random.uniform(size=100000))
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)

dist2 = tfd.Normal(loc=0, scale=1)
bij2 = tfb.Exp()
log_normal2 = tfd.TransformedDistribution(dist2, bij2)

def log_prob(d, x):
    return d.log_prob(x)

%timeit jit(logpdf)(log_normal, x).block_until_ready()
%timeit jit(log_prob)(log_normal2, x).block_until_ready()
# fenbux:                  131 µs ± 514 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# tensorflow_probability:  375 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Installation

  • Install from a local clone (editable install).
git clone https://github.com/JiaYaobo/fenbux.git
pip install -e .
  • Install from PyPI.
pip install -U fenbux
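
To verify the installation (a quick check, not from the original README):

python -c "import fenbux"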


Citation

@software{fenbux,
  author = {Jia, Yaobo},
  title = {fenbux: A Simple Probabilistic Distribution Library in JAX},
  url = {https://github.com/JiaYaobo/fenbux},
  year = {2024}
}

