Differentiable (binned) likelihoods in JAX.

evermore

Installation

python -m pip install evermore

From source:

git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .
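
After either install route, it can be useful to confirm which version ended up in the active environment. The following is a minimal sketch that uses only the standard-library `importlib.metadata`; the helper name `installed_version` is hypothetical and not part of evermore.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


# "evermore" is the distribution name used in the pip command above;
# this prints None if the package is not installed in this environment
print(installed_version("evermore"))
```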

Example - Model and Loss Definition

See more in examples/

evermore in a nutshell:

from typing import NamedTuple, TypeAlias

from flax import nnx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree, Scalar

import evermore as evm

jax.config.update("jax_enable_x64", True)

# type defs
Hist1D: TypeAlias = Float[Array, "..."]
Hists1D: TypeAlias = dict[str, Hist1D]
Args: TypeAlias = tuple[
    nnx.GraphDef,
    nnx.State,
    Hists1D,
    Hist1D,
]


# define a simple model with two processes and two parameters
def model(params: PyTree, hists: Hists1D) -> Array:
    mu_modifier = params.mu.scale()
    syst_modifier = params.syst.scale_log_asymmetric(up=1.1, down=0.9)
    return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


def loss(
    dynamic: nnx.State,
    args: Args,
) -> Float[Scalar, ""]:
    graphdef, static, hists, observation = args
    params = nnx.merge(graphdef, dynamic, static)
    expectation = model(params, hists)
    # Poisson log-likelihood of the observation given the expectation
    log_likelihood = (
        evm.pdf.PoissonContinuous(lamb=expectation).log_prob(observation).sum()
    )
    # add the log-probabilities of the parameter constraints
    constraints = evm.loss.get_log_probs(params)
    log_likelihood += evm.util.sum_over_leaves(constraints)
    return -jnp.sum(log_likelihood)


# setup data
hists: Hists1D = {"signal": jnp.array([3.0]), "bkg": jnp.array([10.0])}
observation: Hist1D = jnp.array([15.0])


# define parameters, can be any PyTree of evm.Parameters
class Params(NamedTuple):
    mu: evm.Parameter[Float[Scalar, ""]]
    syst: evm.NormalParameter[Float[Scalar, ""]]


params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))

# split the tree of parameters into a differentiable part and a static part
graphdef, dynamic, static = nnx.split(params, evm.filter.is_dynamic_parameter, ...)
args = (graphdef, static, hists, observation)

# Calculate negative log-likelihood/loss
loss_val = loss(dynamic, args)
# gradients of negative log-likelihood w.r.t. dynamic parameters
grads = nnx.grad(loss)(dynamic, args)
nnx.display(nnx.pure(grads))
# State({
#   'mu': Array(-0.46153846, dtype=float64, weak_type=True),
#   'syst': Array(-0.15436207, dtype=float64, weak_type=True)
# })
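
To make explicit what the loss above computes, here is a rough plain-JAX sketch of the same kind of likelihood without evermore: a Poisson term for the bin count plus a standard-normal constraint on the nuisance parameter. Everything in it is an assumption made for illustration. The helpers `poisson_log_prob`, `standard_normal_log_prob` and `nll` are hypothetical names, and the background modifier is a simple symmetric log-normal factor `exp(theta * log(1.1))`, which stands in for (and is not the same as) the asymmetric interpolation behind `scale_log_asymmetric`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_log_prob(observed, expected):
    # Continuous Poisson log-density: k * log(lam) - lam - log(Gamma(k + 1))
    return observed * jnp.log(expected) - expected - gammaln(observed + 1.0)


def standard_normal_log_prob(x):
    # Log-density of a unit Gaussian, used as the constraint term
    return -0.5 * x**2 - 0.5 * jnp.log(2.0 * jnp.pi)


def nll(params, hists, observation):
    mu, theta = params["mu"], params["theta"]
    # ASSUMPTION: symmetric log-normal stand-in for the systematic modifier
    bkg_scale = jnp.exp(theta * jnp.log(1.1))
    expectation = mu * hists["signal"] + bkg_scale * hists["bkg"]
    # Sum the per-bin Poisson log-probabilities, then add the constraint
    log_likelihood = poisson_log_prob(observation, expectation).sum()
    log_likelihood += standard_normal_log_prob(theta)
    return -log_likelihood


# Same toy inputs as in the example above
hists = {"signal": jnp.array([3.0]), "bkg": jnp.array([10.0])}
observation = jnp.array([15.0])
params = {"mu": jnp.array(1.0), "theta": jnp.array(0.0)}

loss_val = nll(params, hists, observation)
# jax.grad differentiates with respect to the first argument (the params dict)
grads = jax.grad(nll)(params, hists, observation)
print(loss_val, grads)
```

With these inputs the expectation is 13 for an observation of 15, so the gradient with respect to `mu` is -(15/13 - 1) * 3 = -6/13, which agrees with the 'mu' entry printed in the example above. The `theta` gradient differs from the 'syst' entry because the stand-in modifier is not the one evermore applies.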

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD 3-Clause license.
