evermore

Differentiable (binned) likelihoods in JAX.
Installation

```shell
python -m pip install evermore
```

From source:

```shell
git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .
```
Example - Model and Loss Definition
See more examples in the examples/ directory.
evermore in a nutshell:
```python
from typing import NamedTuple, TypeAlias

from flax import nnx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree, Scalar

import evermore as evm

jax.config.update("jax_enable_x64", True)

# type defs
Hist1D: TypeAlias = Float[Array, "..."]
Hists1D: TypeAlias = dict[str, Hist1D]
Args: TypeAlias = tuple[
    nnx.GraphDef,
    nnx.State,
    Hists1D,
    Hist1D,
]


# define a simple model with two processes and two parameters
def model(params: PyTree, hists: Hists1D) -> Array:
    mu_modifier = params.mu.scale()
    syst_modifier = params.syst.scale_log_asymmetric(up=1.1, down=0.9)
    return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


def loss(
    dynamic: nnx.State,
    args: Args,
) -> Float[Scalar, ""]:
    graphdef, static, hists, observation = args
    params = nnx.merge(graphdef, dynamic, static)
    expectation = model(params, hists)
    # Poisson NLL of the expectation and observation
    log_likelihood = (
        evm.pdf.PoissonContinuous(lamb=expectation).log_prob(observation).sum()
    )
    # Add parameter constraints from logpdfs
    constraints = evm.loss.get_log_probs(params)
    log_likelihood += evm.util.sum_over_leaves(constraints)
    return -jnp.sum(log_likelihood)


# setup data
hists: Hists1D = {"signal": jnp.array([3.0]), "bkg": jnp.array([10.0])}
observation: Hist1D = jnp.array([15.0])


# define parameters, can be any PyTree of evm.Parameters
class Params(NamedTuple):
    mu: evm.Parameter[Float[Scalar, ""]]
    syst: evm.NormalParameter[Float[Scalar, ""]]


params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))

# split tree of parameters in a differentiable part and a static part
graphdef, dynamic, static = nnx.split(params, evm.filter.is_dynamic_parameter, ...)
args = (graphdef, static, hists, observation)

# Calculate negative log-likelihood/loss
loss_val = loss(dynamic, args)

# gradients of negative log-likelihood w.r.t. dynamic parameters
grads = nnx.grad(loss)(dynamic, args)
nnx.display(nnx.pure(grads))
# State({
#   'mu': Array(-0.46153846, dtype=float64, weak_type=True),
#   'syst': Array(-0.15436207, dtype=float64, weak_type=True)
# })
```
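For intuition, the Poisson piece of the loss above can be checked by hand. With mu = 1 and syst = 0, the expectation is 1·3 + 1·10 = 13, and the continuous Poisson log-probability of observing 15 is obs·log(λ) − λ − log Γ(obs + 1). The `mu` gradient also has a closed form. A stdlib-only sketch of this arithmetic (independent of evermore, assuming the standard continuous-Poisson convention):

```python
import math


def poisson_log_prob(obs: float, lam: float) -> float:
    """Continuous Poisson log-probability: obs*log(lam) - lam - lgamma(obs + 1)."""
    return obs * math.log(lam) - lam - math.lgamma(obs + 1.0)


# expectation for mu = 1, syst = 0: 1 * 3.0 (signal) + 1 * 10.0 (bkg) = 13.0
lam = 13.0
obs = 15.0
print(poisson_log_prob(obs, lam))  # ≈ -2.425

# gradient of the Poisson NLL w.r.t. mu, analytically:
# d(-log p)/d(mu) = -(obs / lam - 1) * signal = -(15/13 - 1) * 3 = -6/13
print(-(obs / lam - 1.0) * 3.0)  # ≈ -0.46154, matching the 'mu' gradient above
```

This cross-checks the `'mu'` entry of the gradient output; the `'syst'` entry additionally involves the log-asymmetric interpolation and the Normal constraint, so it is not reproduced here.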
Contributing
See CONTRIBUTING.md for instructions on how to contribute.
License
Distributed under the terms of the BSD license.