
JAX package for Generalized Hyperbolic distributions as exponential families


normix


Built on Equinox with Float64 precision throughout.

Installation

pip install normix

Install optional plotting helpers with:

pip install "normix[plotting]"

For local development:

uv sync
pip install -e .

Quick Start

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

# Fit GH distribution to data via EM
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 3))

model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(3), gamma=jnp.zeros(3),
    sigma=jnp.eye(3), p=-0.5, a=2.0, b=1.0,
)
result = BatchEMFitter(max_iter=100).fit(model, X)

# Evaluate log-density (batched via vmap)
log_p = jax.vmap(result.model.log_prob)(X)   # shape (1000,)
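The examples here call jax.config.update("jax_enable_x64", True) before creating any arrays. This is standard JAX behaviour, not something specific to normix. Without the flag, JAX silently creates float32 arrays, so a one-line dtype check is a cheap way to confirm that 64-bit mode is active:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any arrays are created; arrays made earlier stay float32.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)
print(x.dtype)  # float64
```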

Distributions

Univariate (exponential family)

Class                              Parameters    Description
Gamma                              alpha, beta   Shape α > 0, rate β > 0
InverseGamma                       alpha, beta   Shape α > 0, rate β > 0
InverseGaussian                    mu, lam       Mean μ > 0, shape λ > 0
GIG / GeneralizedInverseGaussian   p, a, b       Generalized Inverse Gaussian
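The other three univariate families are special or boundary cases of the GIG. The sketch below checks each mapping numerically. It assumes the usual GIG(p, a, b) kernel y^(p-1) exp(-(a*y + b/y)/2), which matches the p, a, b naming but is an assumption here. It also assumes a Gamma kernel of y^(α-1) e^(-βy) and an InverseGamma kernel of y^(-α-1) e^(-β/y). The *_case helpers are hypothetical illustrations, not normix functions. If two log-kernels differ only by a constant, they describe the same distribution:

```python
import numpy as np

def gig_log_kernel(y, p, a, b):
    """Unnormalized GIG log-density, assuming the kernel y^(p-1) exp(-(a*y + b/y) / 2)."""
    return (p - 1.0) * np.log(y) - 0.5 * (a * y + b / y)

# Hypothetical helpers returning (family log-kernel, matching GIG parameters (p, a, b)).
def gamma_case(alpha, beta):
    # Gamma(shape alpha, rate beta): y^(alpha-1) exp(-beta*y)  ->  GIG(alpha, 2*beta, 0)
    return (lambda y: (alpha - 1.0) * np.log(y) - beta * y), (alpha, 2.0 * beta, 0.0)

def inverse_gamma_case(alpha, beta):
    # InverseGamma: y^(-alpha-1) exp(-beta/y)  ->  GIG(-alpha, 0, 2*beta)
    return (lambda y: (-alpha - 1.0) * np.log(y) - beta / y), (-alpha, 0.0, 2.0 * beta)

def inverse_gaussian_case(mu, lam):
    # InverseGaussian(mean mu, shape lam): y^(-3/2) exp(-lam*(y - mu)^2 / (2*mu^2*y))
    #   ->  GIG(-1/2, lam/mu^2, lam)
    kernel = lambda y: -1.5 * np.log(y) - lam * (y - mu) ** 2 / (2.0 * mu**2 * y)
    return kernel, (-0.5, lam / mu**2, lam)

y = np.linspace(0.1, 5.0, 50)
cases = {
    "Gamma": gamma_case(2.0, 1.5),
    "InverseGamma": inverse_gamma_case(3.0, 0.5),
    "InverseGaussian": inverse_gaussian_case(1.5, 2.0),
}
for name, (kernel, gig_params) in cases.items():
    diff = kernel(y) - gig_log_kernel(y, *gig_params)
    # A (numerically) constant difference means the normalized densities coincide.
    print(name, gig_params, float(diff.max() - diff.min()))
```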

Multivariate

Class                Parameters    Description
MultivariateNormal   mu, L_Sigma   Mean μ, Cholesky factor L_Sigma of Σ
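Parametrizing Σ through its Cholesky factor keeps Σ positive definite and makes the log-density cheap to evaluate. It needs one triangular solve plus the log-diagonal of L_Sigma. The plain-JAX sketch below is not normix's implementation. It assumes L_Sigma is lower-triangular with a positive diagonal and compares the result against jax.scipy.stats:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)

def mvn_log_prob(x, mu, L_Sigma):
    """log N(x; mu, Sigma) with Sigma = L_Sigma @ L_Sigma.T (L_Sigma lower-triangular)."""
    d = mu.shape[-1]
    # Whitened residual z = L^{-1} (x - mu): a triangular solve, no explicit inverse.
    z = solve_triangular(L_Sigma, x - mu, lower=True)
    # log|Sigma| = 2 * sum(log diag L), so half of it is sum(log diag L).
    half_logdet = jnp.sum(jnp.log(jnp.diag(L_Sigma)))
    return -0.5 * jnp.dot(z, z) - half_logdet - 0.5 * d * jnp.log(2.0 * jnp.pi)

mu = jnp.array([0.5, -1.0])
L_Sigma = jnp.array([[1.0, 0.0], [0.3, 0.8]])
x = jnp.array([1.0, 0.2])
print(mvn_log_prob(x, mu, L_Sigma))
print(multivariate_normal.logpdf(x, mu, L_Sigma @ L_Sigma.T))  # reference value
```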

Normal Variance-Mean Mixtures (marginal)

Class                   Subordinator      Parameters
VarianceGamma           Gamma             mu, gamma, L_Sigma, alpha, beta
NormalInverseGamma      InverseGamma      mu, gamma, L_Sigma, alpha, beta
NormalInverseGaussian   InverseGaussian   mu, gamma, L_Sigma, mu_ig, lam
GeneralizedHyperbolic   GIG               mu, gamma, L_Sigma, p, a, b
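The four marginals share the standard normal variance-mean mixture construction: X | Y = y ~ N(μ + γy, yΣ), with Σ = L_Sigma L_Sigmaᵀ and Y drawn from the subordinator in the second column. The sketch below assumes normix follows this form, which the shared mu, gamma, L_Sigma parameters suggest. It draws VarianceGamma samples with NumPy and checks them against E[X] = μ + γE[Y] and Cov[X] = E[Y]Σ + Var[Y]γγᵀ. The function sample_variance_gamma is a hypothetical helper:

```python
import numpy as np

def sample_variance_gamma(rng, n, mu, gamma, L_Sigma, alpha, beta):
    """n draws of X = mu + gamma*Y + sqrt(Y) * L_Sigma @ Z, Y ~ Gamma(alpha, rate=beta), Z ~ N(0, I)."""
    d = mu.shape[0]
    y = rng.gamma(shape=alpha, scale=1.0 / beta, size=n)  # subordinator; NumPy uses scale = 1/rate
    z = rng.standard_normal((n, d))
    return mu + y[:, None] * gamma + np.sqrt(y)[:, None] * (z @ L_Sigma.T)

rng = np.random.default_rng(0)
mu = np.array([0.0, 1.0])
gamma = np.array([0.5, -0.2])
L_Sigma = np.array([[1.0, 0.0], [0.4, 0.7]])
alpha, beta = 3.0, 2.0
X = sample_variance_gamma(rng, 200_000, mu, gamma, L_Sigma, alpha, beta)

# Gamma(alpha, rate beta): E[Y] = alpha / beta, Var[Y] = alpha / beta**2
mean_theory = mu + gamma * alpha / beta
cov_theory = (alpha / beta) * L_Sigma @ L_Sigma.T + (alpha / beta**2) * np.outer(gamma, gamma)
print(X.mean(axis=0), mean_theory)
print(np.cov(X, rowvar=False), cov_theory)
```

The γ term is what makes these mixtures skewed. With γ = 0, the mean reduces to μ and the covariance to E[Y]Σ.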

Joint distributions

The Joint* classes (e.g. JointGeneralizedHyperbolic) model the joint density $f(x, y)$ of the observation X and the latent mixing variable Y. They are exponential families and are used internally in the EM E-step.
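For the GH case, the joint (X, Y) has a well-known conditional: Y | X = x ~ GIG(p - d/2, a + γᵀΣ⁻¹γ, b + (x - μ)ᵀΣ⁻¹(x - μ)). The E-step therefore reduces to GIG moments, which are ratios of Bessel K functions. The sketch below uses SciPy's kve and is not normix's internal code. It assumes the GIG kernel y^(p-1) exp(-(a*y + b/y)/2), and gig_moment and gh_e_step are hypothetical names. E[log Y | x], which a full GIG update also needs, is omitted:

```python
import numpy as np
from scipy.special import kve

def gig_moment(r, p, a, b):
    """E[Y^r] for Y ~ GIG(p, a, b), assuming density ∝ y^(p-1) exp(-(a*y + b/y) / 2).

    E[Y^r] = (b/a)^(r/2) * K_{p+r}(w) / K_p(w) with w = sqrt(a*b). The exponentially
    scaled kve is used because its exp(w) factor cancels in the ratio.
    """
    w = np.sqrt(a * b)
    return (b / a) ** (r / 2.0) * kve(p + r, w) / kve(p, w)

def gh_e_step(X, mu, gamma, Sigma, p, a, b):
    """E[Y | x_i] and E[1/Y | x_i] for each row x_i of X under a GH model (hypothetical helper)."""
    d = mu.shape[0]
    Sigma_inv = np.linalg.inv(Sigma)  # an explicit inverse is acceptable for a small sketch
    R = X - mu
    # Posterior of the mixing variable: GIG(p - d/2, a + γᵀΣ⁻¹γ, b + (x-μ)ᵀΣ⁻¹(x-μ))
    p_post = p - d / 2.0
    a_post = a + gamma @ Sigma_inv @ gamma
    b_post = b + np.einsum("ni,ij,nj->n", R, Sigma_inv, R)
    return gig_moment(1, p_post, a_post, b_post), gig_moment(-1, p_post, a_post, b_post)

# Toy usage: d = 1 and p = 0 give p_post = -1/2, i.e. an inverse Gaussian posterior.
X = np.array([[1.0], [-2.0]])
e_y, e_inv_y = gh_e_step(X, mu=np.array([0.0]), gamma=np.array([0.5]),
                         Sigma=np.array([[2.0]]), p=0.0, a=1.0, b=1.0)
print(e_y, e_inv_y)
```

In the p_post = -1/2 case of the usage example, the Bessel ratio collapses. The posterior is inverse Gaussian with mean sqrt(b_post/a_post), which gives a convenient closed-form check.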

Exponential Family API

All univariate and joint distributions subclass ExponentialFamily(eqx.Module):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import Gamma

X = jnp.array([1.0, 1.5, 2.0, 2.5])
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))

# Log-density (single observation)
dist.log_prob(jnp.array(1.5))

# Three parametrizations
theta = dist.natural_params()       # natural parameters θ
eta   = dist.expectation_params()   # expectation parameters η = E[t(X)]
I     = dist.fisher_information()   # Fisher information I(θ) = ∇²ψ(θ)

# Constructors
dist2 = Gamma.from_natural(theta)
dist3 = Gamma.from_expectation(eta)
dist4 = Gamma.fit_mle(X)           # η̂ = mean t(xᵢ)
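All three parametrizations above are tied together by the log-partition function ψ(θ): η = ∇ψ(θ) and I(θ) = ∇²ψ(θ), and both can be obtained with JAX autodiff. The sketch below assumes the common Gamma(α, rate β) form t(x) = (log x, x), θ = (α - 1, -β), ψ(θ) = log Γ(θ₁ + 1) - (θ₁ + 1) log(-θ₂). normix may order or scale t(x) differently, so this illustrates the identities, not normix internals:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln, polygamma

jax.config.update("jax_enable_x64", True)

# Assumed Gamma(alpha, rate beta) exponential-family form:
#   t(x) = (log x, x),  theta = (alpha - 1, -beta),
#   psi(theta) = lgamma(theta_1 + 1) - (theta_1 + 1) * log(-theta_2)
def log_partition(theta):
    return gammaln(theta[0] + 1.0) - (theta[0] + 1.0) * jnp.log(-theta[1])

alpha, beta = 2.0, 1.0
theta = jnp.array([alpha - 1.0, -beta])

eta = jax.grad(log_partition)(theta)         # eta = grad psi(theta) = E[t(X)]
fisher = jax.hessian(log_partition)(theta)   # I(theta) = Hessian of psi = Cov[t(X)]

# Closed forms: E[log X] = digamma(alpha) - log(beta), E[X] = alpha / beta,
# Var[log X] = trigamma(alpha), Cov[log X, X] = 1 / beta, Var[X] = alpha / beta**2
print(eta, jnp.array([digamma(alpha) - jnp.log(beta), alpha / beta]))
print(fisher, polygamma(1, jnp.asarray(alpha)))
```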

EM Algorithm

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

d = 3
X = jax.random.normal(jax.random.PRNGKey(0), (1000, d))  # placeholder: replace with your (n, d) data array

# Initialise from classical parameters
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(d), gamma=jnp.zeros(d), sigma=jnp.eye(d),
    p=-0.5, a=2.0, b=1.0,
)

# Fit with the hybrid CPU/JAX backend (usually the fastest option for EM)
fitter = BatchEMFitter(max_iter=200, tol=1e-6,
                       e_step_backend='cpu', m_step_backend='cpu')
result = fitter.fit(model, X)
fitted = result.model

Bessel Functions

import jax

jax.config.update("jax_enable_x64", True)

from normix import log_kv        # or: from normix.utils.bessel import log_kv

# JIT-able, differentiable (backend='jax', default)
log_kv(v=0.5, z=2.0)

# Faster CPU path, used in the EM hot loop (not JIT-able)
log_kv(v=0.5, z=2.0, backend='cpu')
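normix's log_kv is described as having a custom JVP. One standard way to define the derivative in z is the recurrence K_v'(z) = -K_{v-1}(z) - (v/z)K_v(z). The plain-JAX sketch below illustrates that idea and is not normix's implementation. To stay self-contained, the primal uses a simple trapezoid quadrature of K_v(z) = ∫₀^∞ exp(-z cosh t) cosh(vt) dt. The name log_kv_sketch is hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)

# Trapezoid grid on [0, 12] for K_v(z) = ∫_0^∞ exp(-z cosh t) cosh(v t) dt.
# Adequate for moderate z and |v|; very small z or large |v| would need a longer range.
_T = jnp.linspace(0.0, 12.0, 4001)
_H = _T[1] - _T[0]
_LOG_W = jnp.log(jnp.full_like(_T, _H).at[0].set(_H / 2).at[-1].set(_H / 2))

def _log_kv_quad(v, z):
    log_cosh_vt = jnp.logaddexp(v * _T, -v * _T) - jnp.log(2.0)
    # Summing in log space avoids underflow of exp(-z cosh t) for large z.
    return logsumexp(-z * jnp.cosh(_T) + log_cosh_vt + _LOG_W)

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def log_kv_sketch(v, z):
    """log K_v(z), differentiable in z only (v is a static Python float here)."""
    return _log_kv_quad(v, z)

@log_kv_sketch.defjvp
def _log_kv_sketch_jvp(v, primals, tangents):
    (z,), (z_dot,) = primals, tangents
    out = _log_kv_quad(v, z)
    # d/dz log K_v(z) = -K_{v-1}(z) / K_v(z) - v / z
    ratio = jnp.exp(_log_kv_quad(v - 1.0, z) - out)
    return out, (-ratio - v / z) * z_dot

# Check against the closed form K_{1/2}(z) = sqrt(pi / (2 z)) * exp(-z).
z = 2.0
print(log_kv_sketch(0.5, z), 0.5 * jnp.log(jnp.pi / (2 * z)) - z)
print(jax.grad(lambda s: log_kv_sketch(0.5, s))(z), -1.0 - 0.5 / z)
```

Supplying the derivative through the recurrence means gradients reuse the same log K evaluator instead of differentiating through its internals. That matters when the primal comes from a non-differentiable routine such as a CPU library call.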

Package Layout

normix/
├── exponential_family.py         # ExponentialFamily base class
├── distributions/                # All distribution implementations
│   ├── gamma.py
│   ├── inverse_gamma.py
│   ├── inverse_gaussian.py
│   ├── generalized_inverse_gaussian.py
│   ├── normal.py
│   ├── variance_gamma.py
│   ├── normal_inverse_gamma.py
│   ├── normal_inverse_gaussian.py
│   └── generalized_hyperbolic.py
├── mixtures/                     # Joint and marginal base classes
├── fitting/em.py                 # BatchEMFitter, EMResult
└── utils/
    ├── bessel.py                 # log_kv with custom JVP
    ├── constants.py              # Shared numerical constants
    ├── plotting.py               # Notebook helpers
    └── validation.py             # EM validation helpers

Development

uv run pytest tests/              # run tests
uv run jupyter lab                # notebooks
make -C docs html                 # build docs


License

MIT — see LICENSE.
