evermore
Differentiable (binned) likelihoods in JAX.
Installation
python -m pip install evermore
From source:
git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .
Example - Model and Loss Definition
See more in examples/
evermore in a nutshell:
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class Model(eqx.Module):
    mu: evm.FreeFloating
    syst: evm.NormalConstrained

    def __call__(self, hists: dict[str, Array]) -> Array:
        mu_modifier = self.mu.unconstrained()
        syst_modifier = self.syst.log_normal(up=jnp.array([1.1]), down=jnp.array([0.9]))
        return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


nll = evm.loss.PoissonNLL()


def loss(model: Model, hists: dict[str, Array], observation: Array) -> Array:
    expectation = model(hists)
    # Poisson NLL of the expectation and observation
    log_likelihood = nll(expectation, observation)
    # Add parameter constraints from logpdfs
    constraints = evm.loss.get_log_probs(model)
    log_likelihood += evm.util.sum_leaves(constraints)
    return -jnp.sum(log_likelihood)


# setup model and data
hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])}
observation = jnp.array([15])
model = Model(mu=evm.FreeFloating(), syst=evm.NormalConstrained())

# negative log-likelihood
loss_val = loss(model, hists, observation)

# gradients of negative log-likelihood w.r.t. model parameters
grads = eqx.filter_grad(loss)(model, hists, observation)
print(f"{grads.mu.value=}, {grads.syst.value=}")
# -> grads.mu.value=Array([-0.46153846]), grads.syst.value=Array([-0.15436207])
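
The quoted gradient for mu can be cross-checked by hand. The sketch below uses only the Python standard library and does not call evermore. It rests on assumptions that are not stated in the example above but are consistent with the printed number: mu starts at 1, syst starts at 0, and the log-normal modifier leaves the background unchanged at syst = 0. The helper names (`poisson_nll`, `expectation`, `analytic_grad_mu`, `finite_difference`) are hypothetical and exist only for this illustration.

```python
import math


def poisson_nll(expected: float, observed: float) -> float:
    """Negative log of the Poisson probability of `observed` given `expected`."""
    return expected - observed * math.log(expected) + math.lgamma(observed + 1.0)


def expectation(mu: float, signal: float = 3.0, bkg: float = 10.0) -> float:
    # Assumption: syst = 0, so the background is not scaled at all.
    return mu * signal + bkg


def nll_of_mu(mu: float, observed: float = 15.0) -> float:
    return poisson_nll(expectation(mu), observed)


def analytic_grad_mu(mu: float, signal: float = 3.0, bkg: float = 10.0,
                     observed: float = 15.0) -> float:
    # d/dmu [lam - n * log(lam)] with lam = mu * signal + bkg
    lam = mu * signal + bkg
    return signal * (1.0 - observed / lam)


def finite_difference(f, x: float, eps: float = 1e-6) -> float:
    # Central difference, used here only as an independent numerical check.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


grad_analytic = analytic_grad_mu(1.0)
grad_numeric = finite_difference(nll_of_mu, 1.0)
print(f"analytic: {grad_analytic:.8f}")  # -> analytic: -0.46153846
print(f"numeric:  {grad_numeric:.6f}")
```

With 3 signal events, 10 background events and 15 observed, the derivative is 3 * (1 - 15/13) = -6/13, which agrees with the `grads.mu.value` printed above. The syst gradient additionally depends on how the log-normal modifier interpolates between `down` and `up`, so it is not reproduced here.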
Contributing
See CONTRIBUTING.md for instructions on how to contribute.
License
Distributed under the terms of the BSD license.