Differentiable (binned) likelihoods in JAX.

evermore

Installation

python -m pip install evermore

From source:

git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .

Example - Model and Loss Definition

See more in examples/

evermore in a nutshell:

from typing import NamedTuple

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PyTree

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
def model(params: PyTree, hists: dict[str, Array]) -> Array:
    mu_modifier = params.mu.scale()
    syst_modifier = params.syst.scale_log(up=1.1, down=0.9)
    return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


def loss(
    diffable: PyTree,
    static: PyTree,
    hists: dict[str, Array],
    observation: Array,
) -> Array:
    params = eqx.combine(diffable, static)
    expectation = model(params, hists)
    # Poisson log-likelihood of the observation given the expectation
    log_likelihood = evm.loss.PoissonLogLikelihood()(expectation, observation)
    # Add parameter constraints from logpdfs
    constraints = evm.loss.get_log_probs(params)
    log_likelihood += evm.util.sum_over_leaves(constraints)
    # negate the summed log-likelihood to obtain the loss (NLL)
    return -jnp.sum(log_likelihood)


# setup data
hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])}
observation = jnp.array([15])


# define parameters, can be any PyTree of evm.Parameters
class Params(NamedTuple):
    mu: evm.Parameter
    syst: evm.NormalParameter


params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))
diffable, static = evm.parameter.partition(params)

# Calculate negative log-likelihood/loss
loss_val = loss(diffable, static, hists, observation)
# gradients of negative log-likelihood w.r.t. diffable parameters
grads = eqx.filter_grad(loss)(diffable, static, hists, observation)
print(f"{grads.mu.value=}, {grads.syst.value=}")
# -> grads.mu.value=Array([-0.46153846]), grads.syst.value=Array([-0.15436207])
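
To make the numbers above easier to check by hand, here is a minimal sketch of the same one-bin likelihood written in plain JAX, without evermore. It is not evermore's implementation. It assumes a Poisson term for the bin, a unit-Gaussian constraint on the systematic parameter, and a simplified symmetric log-normal factor `exp(kappa * theta)` as a stand-in for `scale_log(up=1.1, down=0.9)`. The names `nll`, `kappa` and `theta` are hypothetical and chosen only for this illustration.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)

# same toy data as in the example above
signal = jnp.array([3.0])
bkg = jnp.array([10.0])
observation = jnp.array([15.0])

# ASSUMPTION: symmetric stand-in for the asymmetric up/down interpolation;
# kappa is the average of ln(up) and -ln(down)
kappa = 0.5 * (math.log(1.1) - math.log(0.9))


def nll(mu, theta):
    # expected yield: scaled signal plus background with a log-normal factor
    expectation = mu * signal + jnp.exp(kappa * theta) * bkg
    # Poisson log-probability of the observation per bin
    poisson = (
        observation * jnp.log(expectation)
        - expectation
        - gammaln(observation + 1.0)
    )
    # unit-Gaussian constraint on theta, up to an additive constant
    constraint = -0.5 * theta**2
    return -(jnp.sum(poisson) + constraint)


# gradients w.r.t. both parameters at mu=1, theta=0
grad_mu, grad_theta = jax.grad(nll, argnums=(0, 1))(1.0, 0.0)
print(f"{float(grad_mu):.6f}, {float(grad_theta):.6f}")
# -> -0.461538, -0.154362
```

At mu=1 and theta=0 the expectation is 13 for an observation of 15, so the derivative of the loss with respect to the expectation is 1 - 15/13, and the chain rule gives the two values printed above. They agree with the gradients from the evermore example at this point only; away from theta=0 the stand-in factor is not expected to follow evermore's interpolation.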

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.

Download files

Source Distribution

evermore-0.2.7.tar.gz (141.5 kB)

  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3
  • SHA256: a56b6321f0ed22c1eaeb2679baf2d7b12d0c9c8374486e46b11b97cfd4f6c2c8
  • MD5: 7a2716eb3d9b0481ed66681d28e89a31
  • BLAKE2b-256: 4f62936ebf2dfab63321efd6a479b917d1df6f93d5d3287b926aea479b929dd5

Built Distribution

evermore-0.2.7-py3-none-any.whl (19.7 kB)

  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3
  • SHA256: 4d9101136dcf7c01049b1dfbdff39877d044ad8e8f230901ea40a1214c0c2bbd
  • MD5: e767898f06e9e5209cacb2065e0eb394
  • BLAKE2b-256: 9571dc4fb1dca53550397f277faa82105f869f8b7dac76e90619e6a23b89cb4d