evermore
Differentiable (binned) likelihoods in JAX.
Installation

```shell
python -m pip install evermore
```

From source:

```shell
git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .
```
Usage - Model definition and fitting
See more in examples/
evermore in a nutshell:
```python
import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(evm.Model):
    def __call__(self, processes: dict, parameters: dict) -> evm.Result:
        res = evm.Result()

        # signal
        mu_mod = evm.modifier(
            name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
        )
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = evm.modifier(
            name="sigma", parameter=parameters["sigma"], effect=evm.effect.gauss(0.2)
        )
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res
```
```python
# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": evm.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = evm.likelihood.NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = evm.optimizer.JaxOptimizer.make(
    name="ScipyMinimize", settings={"method": "trust-constr"}
)

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)
print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)
```
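Because the NLL is an ordinary differentiable JAX function, the same minimum can also be recovered with plain gradient descent instead of the SciPy minimizer. A self-contained sketch without evermore, using the same numbers as above (the linear `1 + 0.2 * sigma` background scaling stands in for the `gauss(0.2)` effect and is an assumption for illustration):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

signal = jnp.array([10.0])
background = jnp.array([50.0])
observed = jnp.array([64.0])


def nll(params):
    mu, sigma = params["mu"], params["sigma"]
    # signal scaled by the free strength mu; background scaled by the
    # Gaussian-constrained factor (1 + 0.2 * sigma)
    expectation = mu * signal + (1.0 + 0.2 * sigma) * background
    # per-bin Poisson term, dropping the parameter-independent log(n!) piece
    poisson = jnp.sum(expectation - observed * jnp.log(expectation))
    # standard-normal constraint on sigma
    return poisson + 0.5 * sigma**2


@jax.jit
def step(params):
    grads = jax.grad(nll)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)


params = {"mu": jnp.array(1.0), "sigma": jnp.array(0.0)}
for _ in range(500):
    params = step(params)
print(params["mu"])  # converges toward ~1.4, matching the fit above
```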
```python
# gradients of "prefit" model:
print(eqx.filter_grad(nll)({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}


# gradients of "postfit" model:
@eqx.filter_grad
@eqx.filter_jit
def grad_postfit_nll(where: dict[str, jax.Array]) -> dict[str, jax.Array]:
    nll = evm.likelihood.NLL(
        model=model.update(values=values), observation=jnp.array([64.0])
    )
    return nll(where)


print(grad_postfit_nll({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
```
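Autodiff goes beyond first derivatives: `jax.hessian` gives the curvature of the NLL at the minimum, and for a well-behaved fit the square roots of the inverse-Hessian diagonal approximate the (Gaussian) parameter uncertainties. A sketch with a simplified one-parameter Poisson model using the same numbers (not evermore's API; just plain JAX on a hand-written NLL):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

signal = jnp.array([10.0])
background = jnp.array([50.0])
observed = jnp.array([64.0])


def nll(mu):
    # one-parameter Poisson NLL in the signal strength mu
    expectation = mu * signal + background
    return jnp.sum(expectation - observed * jnp.log(expectation))


mu_hat = 1.4  # best-fit value from the example above
hessian = jax.hessian(nll)(mu_hat)
# 1-parameter case: variance is the inverse Hessian
uncertainty = jnp.sqrt(1.0 / hessian)
print(uncertainty)  # ~0.8, i.e. sqrt(64) / 10 at the expectation of 64
```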
Contributing
See CONTRIBUTING.md for instructions on how to contribute.
License
Distributed under the terms of the BSD license.