Utility functions for extracting log-probabilities, parameter transforms, and Fisher information from NumPyro models.
numpyro-inferutils
Small utility functions for inference with NumPyro models.
This package provides lightweight helpers for:
- extracting log-prior and log-likelihood from NumPyro models,
- working with constrained / unconstrained parameter spaces,
- computing Fisher information matrices from NumPyro models with independent Gaussian likelihoods,
- performing MAP estimation using stochastic variational inference (SVI).
Installation
pip install numpyro-inferutils
Quick examples
A minimal NumPyro model
All examples below assume a simple NumPyro model such as:
import numpyro
import numpyro.distributions as dist
import numpy as np
x = np.linspace(-5, 5, 100)
sigma = np.ones_like(x) * np.exp(0.01)
y = 0.5 * x + 1.0 + np.random.randn(len(x)) * sigma
def model(x, y):
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    b = numpyro.sample("b", dist.Normal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.LogNormal(0.0, 0.01))
    mu = w * x + b
    numpyro.deterministic("mu", mu)
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
Log-prior and log-likelihood
from numpyro_inferutils import build_logprob_functions
logprior, loglik = build_logprob_functions(model, model_kwargs={"x": x, "y": y})
theta = {
    "w": 0.0,
    "b": 1.2,
}
lp = logprior(theta)
ll = loglik(theta)
- logprior(theta) sums log-probabilities from non-observed sample sites.
- loglik(theta) sums log-probabilities from observed sample sites.
- Contributions added via numpyro.factor are treated as part of the log-likelihood.
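To make the split concrete, here is a hand-written sketch of the two sums for the example model, using only the standard library. It is not the package's implementation: the helper names (normal_logpdf, lognormal_logpdf, log_prior, log_likelihood) are hypothetical, the three-point dataset is made up, and theta is assumed to carry a value for every latent site, including sigma.

```python
import math

def normal_logpdf(value, loc, scale):
    """Log-density of Normal(loc, scale) at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def lognormal_logpdf(value, loc, scale):
    """Log-density of LogNormal(loc, scale): Normal density of log(value) minus log(value)."""
    return normal_logpdf(math.log(value), loc, scale) - math.log(value)

def log_prior(theta):
    # Sum over the non-observed sample sites of the example model: w, b, sigma.
    return (
        normal_logpdf(theta["w"], 0.0, 1.0)
        + normal_logpdf(theta["b"], 0.0, 1.0)
        + lognormal_logpdf(theta["sigma"], 0.0, 0.01)
    )

def log_likelihood(theta, xs, ys):
    # Sum over the observed site "obs": one Normal(mu_i, sigma) term per data point.
    return sum(
        normal_logpdf(y, theta["w"] * x + theta["b"], theta["sigma"])
        for x, y in zip(xs, ys)
    )

# Illustrative data lying exactly on y = 0.5 * x + 1 (an assumption for this sketch).
xs = [-1.0, 0.0, 1.0]
ys = [0.5, 1.0, 1.5]
theta = {"w": 0.5, "b": 1.0, "sigma": 1.0}

lp = log_prior(theta)
ll = log_likelihood(theta, xs, ys)
```

The log-posterior, up to a constant, is lp + ll, which is the quantity a MAP estimate maximizes.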
Constrained ↔ unconstrained parameters
from numpyro_inferutils.transforms import to_unconstrained_dict
params_constrained = {"sigma": 2.0}
params_unconstrained = to_unconstrained_dict(
    model,
    params_constrained,
    keys=["sigma"],
    x=x, y=y
)
This inspects the model’s sample-site supports and applies the appropriate inverse transforms using biject_to(site["fn"].support).
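As a rough picture of what such a transform does: for a site with positive support, such as sigma above, the bijection from the real line is an exponential, so moving to the unconstrained space amounts to taking a logarithm, while real-valued sites pass through unchanged. The sketch below mimics this with the standard library only. The SUPPORTS table and both helper functions are hypothetical stand-ins, since the package reads the supports from the model itself.

```python
import math

# Hypothetical support table for the example model (assumed, not read from NumPyro).
SUPPORTS = {"w": "real", "b": "real", "sigma": "positive"}

def to_unconstrained(params):
    """Map constrained values to the real line: log for positive sites, identity otherwise."""
    out = {}
    for name, value in params.items():
        if SUPPORTS[name] == "positive":
            out[name] = math.log(value)
        else:
            out[name] = value
    return out

def to_constrained(params):
    """Inverse of to_unconstrained: exp for positive sites, identity otherwise."""
    out = {}
    for name, value in params.items():
        if SUPPORTS[name] == "positive":
            out[name] = math.exp(value)
        else:
            out[name] = value
    return out

constrained = {"w": 0.3, "sigma": 2.0}
unconstrained = to_unconstrained(constrained)   # sigma becomes log(2.0)
roundtrip = to_constrained(unconstrained)       # recovers the original values
```

Working in the unconstrained space lets an optimizer or sampler take arbitrary real-valued steps without leaving the support.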
MAP estimation via SVI
For many applications, it is useful to obtain a fast maximum a posteriori (MAP) estimate, for example as an initial point for NUTS.
import jax
from numpyro_inferutils import find_map_svi
rng_key = jax.random.PRNGKey(0)
p_map = find_map_svi(
    model,
    step_size=1e-2,
    num_steps=5_000,
    rng_key=rng_key,
    x=x,
    y=y,
)
- The MAP estimate is obtained via stochastic variational inference (SVI) using a Laplace autoguide (AutoLaplaceApproximation).
- Only a MAP-like point estimate (the guide median) is returned; the covariance of the Laplace approximation is intentionally not used.
- Parameter constraints defined in the NumPyro model are handled automatically.
- The returned parameters are in the constrained space.
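One way to sanity-check an estimate for the example model is to compare it with a closed-form reference. The sketch below is not what find_map_svi does: it swaps SVI for a direct linear solve, and it assumes sigma is held fixed at a known value so that the posterior over (w, b) is Gaussian. The function name map_linear_fixed_sigma and the three-point dataset are hypothetical.

```python
def map_linear_fixed_sigma(xs, ys, sigma):
    """Posterior mode of (w, b) for y ~ Normal(w * x + b, sigma) with w, b ~ Normal(0, 1).

    Setting the gradient of the log-posterior to zero gives the 2x2 linear system
    (X^T X / sigma^2 + I) [w, b]^T = X^T y / sigma^2, solved here by Cramer's rule.
    """
    inv_var = 1.0 / (sigma * sigma)
    # Entries of the posterior precision matrix.
    a_ww = inv_var * sum(x * x for x in xs) + 1.0
    a_wb = inv_var * sum(xs)
    a_bb = inv_var * len(xs) + 1.0
    # Right-hand side X^T y / sigma^2.
    r_w = inv_var * sum(x * y for x, y in zip(xs, ys))
    r_b = inv_var * sum(ys)
    det = a_ww * a_bb - a_wb * a_wb
    w = (a_bb * r_w - a_wb * r_b) / det
    b = (a_ww * r_b - a_wb * r_w) / det
    return {"w": w, "b": b}

# Illustrative data on the line y = 0.5 * x + 1 (an assumption for this sketch).
xs = [-1.0, 0.0, 1.0]
ys = [0.5, 1.0, 1.5]
p_ref = map_linear_fixed_sigma(xs, ys, sigma=1.0)
```

With only three points, the Normal(0, 1) priors visibly shrink the mode toward zero; with the 100-point dataset above, the reference solution sits much closer to the generating values.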
Fisher information (independent Gaussian likelihood)
from numpyro_inferutils.fisher import information_from_model_independent_normal
info = information_from_model_independent_normal(
    model=model,
    pdic={"w": 1.0, "b": 0.5},
    mu_name="mu",
    observed=y,
    model_args=(x, y),
    keys=["w", "b"],
    sigma_sd=sigma,
)
F = info["fisher"]
The Fisher matrix for an independent Gaussian likelihood is computed as
F = Jᵀ J,
where J_ij = ∂r_i / ∂θ_j and
r = (y − μ(θ)) / σ.
Both constrained and unconstrained parameterizations are supported.
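For the linear example, the Jacobian can be written down by hand, which gives a small check of the formula. The sketch below uses only the standard library and an analytic Jacobian for mu = w * x + b. The function names and the three-point dataset are hypothetical, and the package itself may obtain J differently (for instance by automatic differentiation).

```python
def residual_jacobian(xs, sigmas):
    """Rows of J for r_i = (y_i - (w * x_i + b)) / sigma_i with parameters (w, b).

    dr_i/dw = -x_i / sigma_i and dr_i/db = -1 / sigma_i; neither depends on y.
    """
    return [[-x / s, -1.0 / s] for x, s in zip(xs, sigmas)]

def fisher_from_jacobian(J):
    """Compute F = J^T J as a nested list."""
    n_params = len(J[0])
    return [
        [sum(row[a] * row[b] for row in J) for b in range(n_params)]
        for a in range(n_params)
    ]

# Illustrative inputs (assumed for this sketch): three points, equal noise of 0.5.
xs = [-1.0, 0.0, 1.0]
sigmas = [0.5, 0.5, 0.5]

J = residual_jacobian(xs, sigmas)
F = fisher_from_jacobian(J)
# F[0][0] = sum(x_i^2 / sigma_i^2), F[1][1] = sum(1 / sigma_i^2),
# and the off-diagonal sum(x_i / sigma_i^2) vanishes because xs is symmetric about zero.
```

Inverting F gives the usual Cramér-Rao lower bound on the parameter covariance for this likelihood.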
License
MIT License.