normix
JAX package for Generalized Hyperbolic distributions as exponential families.
Built on Equinox with Float64 precision throughout.
Installation
```bash
pip install normix
```
Install optional plotting helpers with:
```bash
pip install "normix[plotting]"
```
For local development:
```bash
uv sync
pip install -e .
```
Quick Start
```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

# Fit GH distribution to data via EM
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 3))
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(3), gamma=jnp.zeros(3),
    sigma=jnp.eye(3), p=-0.5, a=2.0, b=1.0,
)
result = BatchEMFitter(max_iter=100).fit(model, X)

# Evaluate log-density (batched via vmap)
log_p = jax.vmap(result.model.log_prob)(X)  # shape (1000,)
```
Distributions
Univariate (exponential family)
| Class | Parameters | Description |
|---|---|---|
| Gamma | alpha, beta | Shape α > 0, rate β > 0 |
| InverseGamma | alpha, beta | Shape α > 0, rate β > 0 |
| InverseGaussian | mu, lam | Mean μ > 0, shape λ > 0 |
| GIG / GeneralizedInverseGaussian | p, a, b | Generalized Inverse Gaussian |
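For reference, the GIG row is commonly written in the $(p, a, b)$ parametrization below. That normix uses exactly this convention is an assumption here, so check the GeneralizedInverseGaussian docstring before relying on it. $K_p$ is the modified Bessel function of the second kind, which is why the package ships its own `log_kv`.

$$
f(y \mid p, a, b) = \frac{(a/b)^{p/2}}{2\,K_p\!\left(\sqrt{ab}\right)}\; y^{p-1} \exp\!\left(-\frac{1}{2}\left(a y + \frac{b}{y}\right)\right), \qquad y > 0.
$$

With sufficient statistics $t(y) = (\log y,\ y,\ 1/y)$ this is an exponential family with natural parameters $\theta = (p - 1,\ -a/2,\ -b/2)$. The other rows are limits or special cases: $b \to 0$ with $p > 0$ gives a Gamma with $\alpha = p$, $\beta = a/2$ (density $\propto y^{\alpha-1} e^{-\beta y}$); $a \to 0$ with $p < 0$ gives an InverseGamma with $\alpha = -p$, $\beta = b/2$ (density $\propto y^{-\alpha-1} e^{-\beta/y}$); and $p = -1/2$ gives an InverseGaussian with $a = \lambda/\mu^2$, $b = \lambda$.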
Multivariate
| Class | Parameters | Description |
|---|---|---|
| MultivariateNormal | mu, L_Sigma | Mean μ, Cholesky factor L_Sigma of Σ (Σ = L_Sigma L_Sigmaᵀ) |
Normal Variance-Mean Mixtures (marginal)
| Class | Subordinator | Parameters |
|---|---|---|
| VarianceGamma | Gamma | mu, gamma, L_Sigma, alpha, beta |
| NormalInverseGamma | InverseGamma | mu, gamma, L_Sigma, alpha, beta |
| NormalInverseGaussian | InverseGaussian | mu, gamma, L_Sigma, mu_ig, lam |
| GeneralizedHyperbolic | GIG | mu, gamma, L_Sigma, p, a, b |
Joint distributions
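The `Joint*` classes (e.g. `JointGeneralizedHyperbolic`) model the joint density $f(x, y)$ of an observation X and its latent mixing variable Y, where Y follows the subordinator and $X \mid Y = y \sim N(\mu + \gamma y,\ y\Sigma)$. Unlike the marginal mixtures, these joint models are exponential families, and they are used internally for the EM E-step.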
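To illustrate what the joint model describes, here is a minimal NumPy/SciPy sketch of the normal variance-mean mixture construction for the Variance-Gamma case: draw Y from the Gamma subordinator, then draw X given Y, and score a pair by adding the two log-densities. This is not normix code; the helper names `sample_joint_vg` and `joint_log_density_vg` are hypothetical, and the Gamma subordinator is assumed to use the shape/rate convention from the table above.

```python
import numpy as np
from scipy import stats


def sample_joint_vg(rng, n, mu, gamma, L_Sigma, alpha, beta):
    """Draw n pairs (x_i, y_i): Y ~ Gamma(alpha, rate=beta), X | Y=y ~ N(mu + gamma*y, y*Sigma)."""
    d = mu.shape[0]
    y = rng.gamma(shape=alpha, scale=1.0 / beta, size=n)  # mixing (subordinator) draws
    z = rng.standard_normal((n, d))                        # iid N(0, I) noise
    # Each row: mu + gamma * y_i + sqrt(y_i) * L_Sigma @ z_i
    x = mu + y[:, None] * gamma + np.sqrt(y)[:, None] * (z @ L_Sigma.T)
    return x, y


def joint_log_density_vg(x, y, mu, gamma, L_Sigma, alpha, beta):
    """log f(x, y) = log N(x; mu + gamma*y, y*Sigma) + log Gamma(y; alpha, rate=beta)."""
    Sigma = L_Sigma @ L_Sigma.T
    log_cond = stats.multivariate_normal(mean=mu + gamma * y, cov=y * Sigma).logpdf(x)
    log_mix = stats.gamma(a=alpha, scale=1.0 / beta).logpdf(y)
    return log_cond + log_mix


rng = np.random.default_rng(0)
mu, gamma, L_Sigma = np.zeros(2), np.array([0.5, -0.5]), np.eye(2)
x, y = sample_joint_vg(rng, 200_000, mu, gamma, L_Sigma, alpha=2.0, beta=1.0)
print(x.mean(axis=0))  # close to mu + gamma * E[Y] = mu + gamma * alpha / beta
print(joint_log_density_vg(x[0], y[0], mu, gamma, L_Sigma, alpha=2.0, beta=1.0))
```

Because $\log f(x, y)$ splits into a Gaussian term and a subordinator term, the joint density is linear in sufficient statistics such as $y$, $1/y$ and $x/y$, which is what makes the joint model an exponential family even though the marginal is not.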
Exponential Family API
All univariate and joint distributions subclass `ExponentialFamily(eqx.Module)`:
```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from normix import Gamma

X = jnp.array([1.0, 1.5, 2.0, 2.5])
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))

# Log-density (single observation)
dist.log_prob(jnp.array(1.5))

# Three parametrizations: classical (alpha, beta above), natural θ, expectation η
theta = dist.natural_params()        # natural parameters θ
eta = dist.expectation_params()      # expectation parameters η = E[t(X)] = ∇ψ(θ)

# Fisher information I(θ) = ∇²ψ(θ)
I = dist.fisher_information()

# Constructors
dist2 = Gamma.from_natural(theta)
dist3 = Gamma.from_expectation(eta)
dist4 = Gamma.fit_mle(X)             # MLE: η̂ = (1/n) Σᵢ t(xᵢ)
```
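To make the three parametrizations concrete, here is a standalone JAX sketch for the Gamma case. It is not normix's implementation: it assumes sufficient statistics $t(x) = (\log x,\ x)$, natural parameters $\theta = (\alpha - 1,\ -\beta)$ and log-partition $\psi(\theta) = \log\Gamma(\theta_1 + 1) - (\theta_1 + 1)\log(-\theta_2)$, and normix may order or scale these components differently. All function names are hypothetical. Expectation parameters and Fisher information then come from automatic differentiation of $\psi$, and the MLE inverts $\eta = \nabla\psi(\theta)$ at $\hat\eta = \frac{1}{n}\sum_i t(x_i)$.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln, polygamma

jax.config.update("jax_enable_x64", True)


def sufficient_stats(x):
    """t(x) = (log x, x) for the Gamma family (assumed ordering)."""
    return jnp.stack([jnp.log(x), x], axis=-1)


def log_partition(theta):
    """psi(theta) = log Gamma(theta1 + 1) - (theta1 + 1) * log(-theta2)."""
    shape = theta[0] + 1.0
    return gammaln(shape) - shape * jnp.log(-theta[1])


def natural_from_classical(alpha, beta):
    return jnp.array([alpha - 1.0, -beta])


def expectation_from_natural(theta):
    # eta = grad psi(theta) = (E[log X], E[X])
    return jax.grad(log_partition)(theta)


def fisher_from_natural(theta):
    # I(theta) = Hessian of psi = Cov[t(X)]
    return jax.hessian(log_partition)(theta)


def classical_from_expectation(eta, n_newton=20):
    """Invert eta = grad psi by solving log(a) - digamma(a) = log E[X] - E[log X] with Newton."""
    s = jnp.log(eta[1]) - eta[0]  # > 0 by Jensen's inequality
    # Minka's closed-form starting guess for the Gamma shape
    alpha = (3.0 - s + jnp.sqrt((s - 3.0) ** 2 + 24.0 * s)) / (12.0 * s)
    for _ in range(n_newton):
        f = jnp.log(alpha) - digamma(alpha) - s
        f_prime = 1.0 / alpha - polygamma(1, alpha)
        alpha = alpha - f / f_prime
    beta = alpha / eta[1]  # second moment condition: E[X] = alpha / beta
    return alpha, beta


X = jnp.array([1.0, 1.5, 2.0, 2.5])
theta = natural_from_classical(2.0, 1.0)
eta = expectation_from_natural(theta)   # (digamma(2) - log 1, 2 / 1)
fisher = fisher_from_natural(theta)
eta_hat = sufficient_stats(X).mean(axis=0)          # MLE moment condition
alpha_hat, beta_hat = classical_from_expectation(eta_hat)
```

Working through $\psi$ rather than the density directly is what lets one code path serve every family: only `log_partition` and `sufficient_stats` change, while gradients and Hessians come for free from JAX.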
EM Algorithm
```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

d = 3
X = ...  # (n, d) data array

# Initialise from classical parameters
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(d), gamma=jnp.zeros(d), sigma=jnp.eye(d),
    p=-0.5, a=2.0, b=1.0,
)

# Hybrid CPU/JAX fitting: the E- and M-steps use the faster,
# non-JIT-able CPU backend
fitter = BatchEMFitter(max_iter=200, tol=1e-6,
                       e_step_backend='cpu', m_step_backend='cpu')
result = fitter.fit(model, X)
fitted = result.model
```
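The E-step is where the Bessel functions enter. For a GH model the posterior of the mixing variable is again GIG: $Y \mid X = x \sim \mathrm{GIG}\big(p - d/2,\ a + \gamma^\top\Sigma^{-1}\gamma,\ b + (x-\mu)^\top\Sigma^{-1}(x-\mu)\big)$, and its moments are $E[Y^k] = (b/a)^{k/2} K_{p+k}(\sqrt{ab}) / K_p(\sqrt{ab})$. Below is a minimal NumPy/SciPy sketch of those two formulas, assuming the GIG density $\propto y^{p-1} e^{-(ay + b/y)/2}$ and $\Sigma = L L^\top$. It is not the `BatchEMFitter` code, and the helper names are hypothetical.

```python
import numpy as np
from scipy.linalg import solve_triangular
from scipy.special import kve


def gig_posterior_params(x, mu, gamma, L_Sigma, p, a, b):
    """Parameters of Y | X = x under a GH model (hypothetical helper).

    Y | X = x ~ GIG(p - d/2, a + gamma' Sigma^{-1} gamma, b + (x - mu)' Sigma^{-1} (x - mu)),
    with Sigma = L_Sigma @ L_Sigma.T.
    """
    d = x.shape[-1]
    # Whitening by the lower Cholesky factor turns Sigma^{-1} quadratic forms into squared norms.
    z = solve_triangular(L_Sigma, (x - mu).T, lower=True).T
    g = solve_triangular(L_Sigma, gamma, lower=True)
    return p - d / 2.0, a + g @ g, b + np.sum(z * z, axis=-1)


def gig_moment(k, p, a, b):
    """E[Y^k] for Y ~ GIG(p, a, b): (b/a)^(k/2) * K_{p+k}(sqrt(ab)) / K_p(sqrt(ab))."""
    w = np.sqrt(a * b)
    # kve(v, w) = kv(v, w) * exp(w); the exp(w) factors cancel in the ratio,
    # so large w does not underflow.
    return (b / a) ** (k / 2.0) * kve(p + k, w) / kve(p, w)


# E-step quantities for a batch of observations (same starting model as above).
rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 3))
p_post, a_post, b_post = gig_posterior_params(
    X, mu=np.zeros(3), gamma=np.zeros(3), L_Sigma=np.eye(3), p=-0.5, a=2.0, b=1.0
)
E_y = gig_moment(1, p_post, a_post, b_post)       # E[Y | x_i], shape (1000,)
E_inv_y = gig_moment(-1, p_post, a_post, b_post)  # E[1/Y | x_i], shape (1000,)
```

Each observation gets its own GIG posterior, so these Bessel ratios are evaluated n times per iteration; that per-iteration cost is the "EM hot path" the CPU backend targets.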
Bessel Functions
```python
import jax
jax.config.update("jax_enable_x64", True)
from normix import log_kv  # or: from normix.utils.bessel import log_kv

# JIT-able and differentiable (backend='jax', the default)
log_kv(v=0.5, z=2.0)

# Faster CPU path for the EM hot loop (not JIT-able)
log_kv(v=0.5, z=2.0, backend='cpu')
```
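For intuition about what `log_kv` computes, here is a NumPy/SciPy sketch (not normix's implementation; helper names are hypothetical) of a numerically stable $\log K_v(z)$ via the exponentially scaled `scipy.special.kve`, together with the identity $\frac{d}{dz}\log K_v(z) = -K_{v-1}(z)/K_v(z) - v/z$, which is the kind of rule a custom JVP can use. The sketch covers only the $z$-derivative; differentiating with respect to the order $v$ is harder and is not shown. The closed form $K_{1/2}(z) = \sqrt{\pi/(2z)}\,e^{-z}$ provides a check.

```python
import numpy as np
from scipy.special import kve


def log_kv_scipy(v, z):
    """log K_v(z) via the exponentially scaled kve(v, z) = K_v(z) * exp(z)."""
    # Taking the log of kve and subtracting z avoids underflow of K_v(z) for large z.
    return np.log(kve(v, z)) - z


def dlog_kv_dz(v, z):
    """d/dz log K_v(z) = -K_{v-1}(z) / K_v(z) - v / z."""
    # The exp(z) scaling cancels in the ratio, so kve can be used directly.
    return -kve(v - 1.0, z) / kve(v, z) - v / z


# Check against the closed form K_{1/2}(z) = sqrt(pi / (2 z)) * exp(-z).
z = 2.0
print(log_kv_scipy(0.5, z), 0.5 * np.log(np.pi / (2.0 * z)) - z)

# K_{1/2}(800) is far below the smallest positive double, but its log is finite.
print(log_kv_scipy(0.5, 800.0))

# Derivative a JVP rule could use, evaluated at the same point.
print(dlog_kv_dz(0.5, z))
```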
Package Layout
```
normix/
├── exponential_family.py     # ExponentialFamily base class
├── distributions/            # All distribution implementations
│   ├── gamma.py
│   ├── inverse_gamma.py
│   ├── inverse_gaussian.py
│   ├── generalized_inverse_gaussian.py
│   ├── normal.py
│   ├── variance_gamma.py
│   ├── normal_inverse_gamma.py
│   ├── normal_inverse_gaussian.py
│   └── generalized_hyperbolic.py
├── mixtures/                 # Joint and marginal base classes
├── fitting/em.py             # BatchEMFitter, EMResult
└── utils/
    ├── bessel.py             # log_kv with custom JVP
    ├── constants.py          # Shared numerical constants
    ├── plotting.py           # Notebook helpers
    └── validation.py         # EM validation helpers
```
Development
```bash
uv run pytest tests/    # run tests
uv run jupyter lab      # notebooks
make -C docs html       # build docs
```
License
MIT — see LICENSE.
Download files
- Source distribution: normix-0.2.0.tar.gz
- Built distribution: normix-0.2.0-py3-none-any.whl
File details
Details for the file normix-0.2.0.tar.gz.
File metadata
- Download URL: normix-0.2.0.tar.gz
- Size: 57.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ecbccd5117f0336e92c50c021f29e76a5c17d1ecd71d89b9e363d4df435d3603 |
| MD5 | 933681d576eb2e3437f7eb3ecd81f17a |
| BLAKE2b-256 | cd9a9df3cffb6e1fd2827c7839fd30e88f29bb39fffe12a21a3147c7af8a6e48 |
Provenance
The following attestation bundles were made for normix-0.2.0.tar.gz:
- Publisher: publish.yml on xshi19/normix
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: normix-0.2.0.tar.gz
- Subject digest: ecbccd5117f0336e92c50c021f29e76a5c17d1ecd71d89b9e363d4df435d3603
- Sigstore transparency entry: 1263167997
- Permalink: xshi19/normix@15c54f695a9cb3a2013b0e517dad29ba5d212308
- Branch / Tag: refs/heads/master
- Owner: https://github.com/xshi19
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@15c54f695a9cb3a2013b0e517dad29ba5d212308
- Trigger Event: workflow_dispatch
File details
Details for the file normix-0.2.0-py3-none-any.whl.
File metadata
- Download URL: normix-0.2.0-py3-none-any.whl
- Size: 74.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4d54635468ffbc5af750d5360ba40379618c98716cc14bd2aa89853e05042131 |
| MD5 | c848b4df7199cb2d77331fe31b4ce006 |
| BLAKE2b-256 | f566783a00dfabdeaac011b1dbde7a4d38aee81ee65a06c4c25f2ec8fcb5aa14 |
Provenance
The following attestation bundles were made for normix-0.2.0-py3-none-any.whl:
- Publisher: publish.yml on xshi19/normix
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: normix-0.2.0-py3-none-any.whl
- Subject digest: 4d54635468ffbc5af750d5360ba40379618c98716cc14bd2aa89853e05042131
- Sigstore transparency entry: 1263168063
- Permalink: xshi19/normix@15c54f695a9cb3a2013b0e517dad29ba5d212308
- Branch / Tag: refs/heads/master
- Owner: https://github.com/xshi19
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@15c54f695a9cb3a2013b0e517dad29ba5d212308
- Trigger Event: workflow_dispatch