evermore

Differentiable (binned) likelihoods in JAX.

Installation

python -m pip install evermore

From source:

git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .

Example - Model and Loss Definition

See more in examples/

evermore in a nutshell:

from typing import NamedTuple

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PyTree

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
def model(params: PyTree, hists: dict[str, Array]) -> Array:
    # mu scales the signal yield
    mu_modifier = params.mu.scale()
    # syst scales the background, with up/down variations of 1.1 and 0.9
    syst_modifier = params.syst.scale_log(up=1.1, down=0.9)
    return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


def loss(
    diffable: PyTree,
    static: PyTree,
    hists: dict[str, Array],
    observation: Array,
) -> Array:
    params = eqx.combine(diffable, static)
    expectation = model(params, hists)
    # Poisson NLL of the expectation and observation
    log_likelihood = evm.pdf.Poisson(lamb=expectation).log_prob(observation)
    # Add parameter constraints from logpdfs
    constraints = evm.loss.get_log_probs(params)
    log_likelihood += evm.util.sum_over_leaves(constraints)
    return -jnp.sum(log_likelihood)


# setup data
hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])}
observation = jnp.array([15])


# define parameters, can be any PyTree of evm.Parameters
class Params(NamedTuple):
    mu: evm.Parameter
    syst: evm.NormalParameter


params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))
diffable, static = evm.parameter.partition(params)

# Calculate negative log-likelihood/loss
loss_val = loss(diffable, static, hists, observation)
# gradients of negative log-likelihood w.r.t. diffable parameters
grads = eqx.filter_grad(loss)(diffable, static, hists, observation)
print(f"{grads.mu.value=}, {grads.syst.value=}")
# -> grads.mu.value=Array([-0.46153846]), grads.syst.value=Array([-0.15436207])
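
The printed gradients can be cross-checked by hand with the chain rule. The sketch below uses only the Python standard library and does not call evermore. It rests on two assumptions that are not stated in this README: that the constraint on syst is a normal centred at 0 (so it contributes no gradient at syst = 0), and that the slope of the scale_log modifier at syst = 0 is the mean of the up and down log-slopes. The second assumption is inferred from the printed number, not from evermore's implementation. All variable names are illustrative.

```python
import math

# Inputs copied from the example above.
signal, bkg, observed = 3.0, 10.0, 15.0
up, down = 1.1, 0.9
mu, theta = 1.0, 0.0


def poisson_nll(lam: float, n: float) -> float:
    """Negative Poisson log-probability of n given rate lam."""
    return lam - n * math.log(lam) + math.lgamma(n + 1.0)


# At theta = 0 the systematic modifier equals 1, so the expectation is 3 + 10.
expectation = mu * signal + 1.0 * bkg

# d(NLL)/d(lambda) = 1 - n / lambda
dnll_dlam = 1.0 - observed / expectation

# Chain rule for mu: d(lambda)/d(mu) = signal.
grad_mu = dnll_dlam * signal

# ASSUMPTION: modifier slope at theta = 0 is the mean of the up/down log-slopes.
slope_at_zero = 0.5 * (math.log(up) - math.log(down))
# ASSUMPTION: a normal constraint centred at 0 adds a term proportional to
# theta, which vanishes here.
grad_syst = dnll_dlam * bkg * slope_at_zero + theta

# Central finite difference on mu as an independent cross-check.
eps = 1e-6
fd_mu = (
    poisson_nll((mu + eps) * signal + bkg, observed)
    - poisson_nll((mu - eps) * signal + bkg, observed)
) / (2.0 * eps)

print(f"{grad_mu=:.8f}, {grad_syst=:.8f}, {fd_mu=:.8f}")
```

Under these assumptions both hand-computed values agree with the gradients printed above to the digits shown, and the finite-difference estimate for mu matches the analytic one.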

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.
