Utility functions for extracting log-probabilities, parameter transforms, and Fisher information from NumPyro models.
numpyro-inferutils
Small utility functions for inference with NumPyro models.
This package provides lightweight helpers for:
- extracting log-prior and log-likelihood from NumPyro models,
- working with constrained / unconstrained parameter spaces,
- computing Fisher information matrices from NumPyro models with independent Gaussian likelihoods.
Installation
pip install numpyro-inferutils
Quick examples
Log-prior and log-likelihood
from numpyro_inferutils import build_logprob_functions
logprior, loglik = build_logprob_functions(model)
theta = {
"x": 0.0,
"y": 1.2,
}
lp = logprior(theta)
ll = loglik(theta)
- logprior(theta) sums log-probabilities from non-observed sample sites.
- loglik(theta) sums log-probabilities from observed sample sites.
- Contributions added via numpyro.factor are treated as part of the log-likelihood.
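The summation rule above can be illustrated without NumPyro. The sketch below is a standard-library stand-in, not the package's implementation: the toy model, the site records, and the names toy_trace and split_logprob are all hypothetical and exist only to show which terms land in which sum.

```python
import math

def normal_logpdf(x, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def toy_trace(theta, y_obs):
    """Hypothetical stand-in for a model trace: one record per site.

    Assumed toy model (not taken from the package):
        x ~ Normal(0, 1)             latent
        y ~ Normal(1, 2)             latent
        obs ~ Normal(x + y, 1)       observed
        penalty = -0.1 * y**2        factor-style term
    """
    mu = theta["x"] + theta["y"]
    return [
        {"name": "x", "kind": "sample", "observed": False,
         "log_prob": normal_logpdf(theta["x"], 0.0, 1.0)},
        {"name": "y", "kind": "sample", "observed": False,
         "log_prob": normal_logpdf(theta["y"], 1.0, 2.0)},
        {"name": "obs", "kind": "sample", "observed": True,
         "log_prob": sum(normal_logpdf(v, mu, 1.0) for v in y_obs)},
        {"name": "penalty", "kind": "factor", "observed": False,
         "log_prob": -0.1 * theta["y"] ** 2},
    ]

def split_logprob(sites):
    """Latent sample sites go to the prior; observed and factor sites to the likelihood."""
    logprior = sum(s["log_prob"] for s in sites
                   if s["kind"] == "sample" and not s["observed"])
    loglik = sum(s["log_prob"] for s in sites
                 if s["observed"] or s["kind"] == "factor")
    return logprior, loglik

theta = {"x": 0.0, "y": 1.2}
sites = toy_trace(theta, y_obs=[1.0, 1.5])
lp, ll = split_logprob(sites)
```

Every site lands in exactly one of the two sums, so lp + ll equals the total log-joint of the toy model.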
Constrained ↔ unconstrained parameters
from numpyro_inferutils.transforms import to_unconstrained_dict
params_constrained = {"sigma": 2.0}
params_unconstrained = to_unconstrained_dict(
model,
params_constrained,
keys=["sigma"],
)
This inspects the support of each of the model’s sample sites and applies the inverse of the bijection returned by
biject_to(site["fn"].support), which maps constrained values to unconstrained space.
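As a rough illustration of what these transforms do numerically, the sketch below hard-codes the bijections that NumPyro's biject_to yields for three common supports (identity for real, exp for positive, sigmoid for the unit interval). It uses only the standard library; the TRANSFORMS table and both helper functions are hypothetical, and leaving keys outside the keys list untouched is an assumption about the package's behavior, not something the description confirms.

```python
import math

# (forward, inverse) pairs per support type.
# forward: unconstrained -> constrained; inverse: constrained -> unconstrained.
TRANSFORMS = {
    "real": (lambda u: u, lambda x: x),
    "positive": (math.exp, math.log),
    "unit_interval": (
        lambda u: 1.0 / (1.0 + math.exp(-u)),     # sigmoid
        lambda x: math.log(x) - math.log1p(-x),   # logit
    ),
}

def to_unconstrained(params, supports, keys):
    """Apply the inverse transform to the selected keys (hypothetical helper)."""
    out = dict(params)
    for k in keys:
        _, inverse = TRANSFORMS[supports[k]]
        out[k] = inverse(params[k])
    return out

def to_constrained(params, supports, keys):
    """Apply the forward transform to the selected keys (hypothetical helper)."""
    out = dict(params)
    for k in keys:
        forward, _ = TRANSFORMS[supports[k]]
        out[k] = forward(params[k])
    return out

supports = {"sigma": "positive", "p": "unit_interval"}
params = {"sigma": 2.0, "p": 0.25}

u = to_unconstrained(params, supports, keys=["sigma"])   # sigma becomes log(2.0)
back = to_constrained(u, supports, keys=["sigma"])       # round trip recovers 2.0
```

A positive scale such as sigma = 2.0 maps to log(2.0) in unconstrained space, and the forward transform recovers the original value.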
Seeding and substituting parameters
from jax import random
from numpyro_inferutils.transforms import seed_and_substitute
rng_key = random.PRNGKey(0)
model_sub = seed_and_substitute(
model,
params_dict={"sigma": 0.5},
param_space="unconstrained",
rng_key=rng_key,
)
- If param_space="unconstrained", parameters are interpreted as living in unconstrained space and mapped to constrained space using NumPyro’s internal unconstraining reparameterization.
- If param_space="constrained", values are substituted directly.
Fisher information (independent Gaussian likelihood)
from numpyro_inferutils.fisher import information_from_model_independent_normal
info = information_from_model_independent_normal(
model=model,
pdic={"w": 1.0, "b": 0.0},
mu_name="mu",
observed=y_obs,
keys=["w", "b"],
sigma_sd=sigma,
)
F = info["fisher"]
The Fisher matrix is approximated as
F ≈ Jᵀ J,
where J_ij = ∂r_i / ∂θ_j and
r = (y − μ(θ)) / σ.
Both constrained and unconstrained parameterizations are supported.
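To make the formula concrete, here is a small worked example for an assumed linear mean μ(θ) = w·x + b with a fixed noise scale σ. It is a standard-library sketch that uses central finite differences for the Jacobian; the package itself presumably differentiates through the NumPyro model, so the functions and data below are hypothetical.

```python
def residuals(theta, x, y, sigma):
    """Standardized residuals r_i = (y_i - mu_i(theta)) / sigma for mu = w*x + b."""
    w, b = theta
    return [(yi - (w * xi + b)) / sigma for xi, yi in zip(x, y)]

def jacobian_fd(f, theta, eps=1e-6):
    """Central finite-difference Jacobian with J[i][j] = d f_i / d theta_j."""
    cols = []
    for j in range(len(theta)):
        up = list(theta)
        dn = list(theta)
        up[j] += eps
        dn[j] -= eps
        f_up, f_dn = f(up), f(dn)
        cols.append([(a - c) / (2.0 * eps) for a, c in zip(f_up, f_dn)])
    # cols holds one column per parameter; transpose into rows per data point.
    return [list(row) for row in zip(*cols)]

def fisher_from_jacobian(J):
    """Compute F = J^T J for a Jacobian stored as a list of rows."""
    p = len(J[0])
    return [[sum(row[a] * row[b] for row in J) for b in range(p)]
            for a in range(p)]

# Made-up data for illustration only.
x = [0.0, 1.0, 2.0, 3.0]
y = [0.1, 0.9, 2.2, 2.8]
sigma = 0.5
theta = [1.0, 0.0]  # (w, b)

J = jacobian_fd(lambda t: residuals(t, x, y, sigma), theta)
F = fisher_from_jacobian(J)
```

For this linear model the result has the closed form F = (1/σ²) [[Σx², Σx], [Σx, n]], which is [[56, 24], [24, 16]] for the values above. The matrix does not depend on y, because the mean is linear in θ.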
License
MIT License.