evermore

Differentiable (binned) likelihoods in JAX.

Installation

python -m pip install evermore

From source:

git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .
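
To confirm that either installation route succeeded, one option is to query the installed distribution's version from Python. This is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical and not part of evermore.

```python
from __future__ import annotations

from importlib import metadata


def installed_version(dist_name: str) -> str | None:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Prints the installed version string, or None when the package
# is not installed in the current environment.
print(installed_version("evermore"))
```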

Usage - Model definition and fitting

See the examples/ directory for more complete examples.

evermore in a nutshell:

import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(evm.Model):
    def __call__(self, processes: dict, parameters: dict) -> evm.Result:
        res = evm.Result()

        # signal
        mu_mod = evm.modifier(
            name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
        )
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = evm.modifier(
            name="sigma", parameter=parameters["sigma"], effect=evm.effect.gauss(0.2)
        )
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# set up the model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": evm.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = evm.likelihood.NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# set up the fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = evm.optimizer.JaxOptimizer.make(
    name="ScipyMinimize", settings={"method": "trust-constr"}
)

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# evaluate the model with the fitted values
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
print(eqx.filter_grad(nll)({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}


# gradients of "postfit" model:
@eqx.filter_grad
@eqx.filter_jit
def grad_postfit_nll(where: dict[str, jax.Array]) -> dict[str, jax.Array]:
    nll = evm.likelihood.NLL(
        model=model.update(values=values), observation=jnp.array([64.0])
    )
    return nll(where)


print(grad_postfit_nll({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
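
The numbers printed in the example can be cross-checked by hand. The sketch below does not use evermore; it is a plain-Python approximation of the same one-bin likelihood. It assumes that the unconstrained effect multiplies the signal by mu, that `evm.effect.gauss(0.2)` multiplies the background by `1 + 0.2 * sigma`, and that sigma carries a unit-Gaussian constraint, so that NLL(mu, sigma) = lam - n * log(lam) + sigma**2 / 2 with lam = 10 * mu + 50 * (1 + 0.2 * sigma), up to constant terms. These assumptions are not taken from the evermore source, but they reproduce the gradients shown above. All function names here are hypothetical.

```python
import math

SIGNAL, BACKGROUND, OBSERVED = 10.0, 50.0, 64.0
WIDTH = 0.2  # value passed to evm.effect.gauss in the example above


def expectation(mu: float, sigma: float) -> float:
    # Assumption: mu scales the signal, sigma scales the background by 1 + WIDTH * sigma.
    return mu * SIGNAL + BACKGROUND * (1.0 + WIDTH * sigma)


def nll(mu: float, sigma: float) -> float:
    # Poisson term (constants dropped) plus an assumed unit-Gaussian constraint on sigma.
    lam = expectation(mu, sigma)
    return lam - OBSERVED * math.log(lam) + 0.5 * sigma**2


def dnll_dmu(mu: float, sigma: float) -> float:
    # Analytic derivative of nll with respect to mu.
    lam = expectation(mu, sigma)
    return (1.0 - OBSERVED / lam) * SIGNAL


def dnll_dsigma(mu: float, sigma: float) -> float:
    # Analytic derivative of nll with respect to sigma.
    lam = expectation(mu, sigma)
    return (1.0 - OBSERVED / lam) * BACKGROUND * WIDTH + sigma


# "prefit": mu at its initial value 1.0, gradient evaluated at sigma = 0.2
print(dnll_dsigma(1.0, 0.2))  # approximately -0.1226

# "postfit": mu at its fitted value 1.4, gradient evaluated at sigma = 0.2
print(dnll_dsigma(1.4, 0.2))  # approximately 0.5030

# At the fitted point (mu = 1.4, sigma = 0) both derivatives vanish.
print(dnll_dmu(1.4, 0.0), dnll_dsigma(1.4, 0.0))
```

Under these assumptions the fit result is easy to read off: both derivatives are zero only when the expectation equals the observation (lam = 64) and sigma = 0, which gives mu = (64 - 50) / 10 = 1.4. The absolute value of `nll` here may differ from evermore's by constant terms, so only the gradients and the location of the minimum are compared.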

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.
