
Differentiable (binned) likelihoods in JAX.


evermore


Installation

python -m pip install evermore

From source:

git clone https://github.com/pfackeldey/evermore
cd evermore
python -m pip install .

Usage - Model definition and fitting

See examples/ in the repository for more.

evermore in a nutshell:

import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(evm.Model):
    def __call__(self, processes: dict, parameters: dict) -> evm.Result:
        res = evm.Result()

        # signal
        mu_mod = evm.modifier(
            name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
        )
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = evm.modifier(
            name="sigma", parameter=parameters["sigma"], effect=evm.effect.gauss(0.2)
        )
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": evm.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = evm.likelihood.NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = evm.optimizer.JaxOptimizer.make(
    name="ScipyMinimize", settings={"method": "trust-constr"}
)

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of the "prefit" model w.r.t. sigma at sigma=0.2 (mu keeps its initial value):
print(eqx.filter_grad(nll)({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}


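# gradients of the "postfit" model (updated to the fitted values) w.r.t. sigma at sigma=0.2:
@eqx.filter_grad
@eqx.filter_jit
def grad_postfit_nll(where: dict[str, jax.Array]) -> jax.Array:
    # the undecorated function returns the scalar NLL; filter_grad turns it into a dict of gradients
    postfit_nll = evm.likelihood.NLL(
        model=model.update(values=values), observation=jnp.array([64.0])
    )
    return postfit_nll(where)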


print(grad_postfit_nll({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
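To see where these numbers come from, here is a plain-JAX sketch of the same one-bin likelihood. It is not evermore's implementation and uses no evermore API. It assumes that evm.effect.unconstrained() multiplies the signal by mu, that evm.effect.gauss(0.2) multiplies the background by (1 + 0.2 * sigma) and adds a unit Gaussian constraint 0.5 * sigma**2, and that the Poisson term drops its constant log(k!). The fit uses jax.scipy.optimize.minimize (BFGS) in place of evm.optimizer and does not enforce the mu >= 0 bound.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Use float64, as the evermore example above does.
jax.config.update("jax_enable_x64", True)

SIGNAL = 10.0      # signal template (one bin)
BACKGROUND = 50.0  # background template (one bin)
OBSERVED = 64.0    # observed count
BKG_WIDTH = 0.2    # assumed reading of evm.effect.gauss(0.2)


def expectation(params: jax.Array) -> jax.Array:
    """Expected count for params = [mu, sigma] (assumed parameterization)."""
    mu, sigma = params[0], params[1]
    # mu scales the signal; the background is scaled by (1 + BKG_WIDTH * sigma).
    return mu * SIGNAL + BACKGROUND * (1.0 + BKG_WIDTH * sigma)


def nll(params: jax.Array) -> jax.Array:
    """Poisson NLL (constant log(k!) dropped) plus a unit Gaussian constraint on sigma."""
    lam = expectation(params)
    poisson = lam - OBSERVED * jnp.log(lam)
    constraint = 0.5 * params[1] ** 2
    return poisson + constraint


# Fit with BFGS; unlike the evermore example, the mu >= 0 bound is not enforced.
result = minimize(nll, jnp.array([1.0, 0.0]), method="BFGS")
mu_hat, sigma_hat = result.x[0], result.x[1]

# d(NLL)/d(params) via autodiff; index 1 is the sigma component.
grad_nll = jax.grad(nll)
prefit_grad = grad_nll(jnp.array([1.0, 0.2]))      # mu at its initial value
postfit_grad = grad_nll(jnp.array([mu_hat, 0.2]))  # mu at its fitted value

print("fit:", mu_hat, sigma_hat)
print("d NLL / d sigma, prefit:", prefit_grad[1])
print("d NLL / d sigma, postfit:", postfit_grad[1])
```

Under these assumptions the fit lands near mu = 1.4, sigma = 0 (14 signal + 50 background = 64 observed), and the sigma gradients at sigma = 0.2 come out at about -0.1226 (mu = 1.0) and 0.5030 (mu = 1.4), in line with the values printed above. That agreement suggests, but does not prove, that this is how the gauss effect is parameterized.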

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.



Download files


Source Distribution

evermore-0.2.0.tar.gz (41.5 kB)

Uploaded Source

Built Distribution

evermore-0.2.0-py3-none-any.whl (19.8 kB)

Uploaded Python 3

File details

Details for the file evermore-0.2.0.tar.gz.

File metadata

  • Download URL: evermore-0.2.0.tar.gz
  • Size: 41.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.0

File hashes

Hashes for evermore-0.2.0.tar.gz
  • SHA256: b8136c557581b5ef6e44f11e75f51cc18d7ed03ed459e65c300eeadf1b7b1a4c
  • MD5: 1adb78382f9b71de116369f5ad1466dc
  • BLAKE2b-256: aabbafdb3fa525f721440e2226112d38871ec177034c39916f9ee0deebf8b800


File details

Details for the file evermore-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: evermore-0.2.0-py3-none-any.whl
  • Size: 19.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.0

File hashes

Hashes for evermore-0.2.0-py3-none-any.whl
  • SHA256: 62bba15c38f708cf8d27dfe294ae887119d62e89a51cd1f5c8a799e4c0df18ab
  • MD5: 11723e0515ed465849809fd3eaae4ab7
  • BLAKE2b-256: 2255ed2a00e605e8a9128ead456a81b9cf577888ed3830a3a96315cbdfd384b2
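The SHA256 digests listed for evermore 0.2.0 can be checked locally before installing a downloaded file. Below is a minimal sketch using only Python's standard-library hashlib; the helper names sha256_of and verify are illustrative and not part of evermore or pip.

```python
import hashlib
from pathlib import Path

# SHA256 digests as listed on this page for evermore 0.2.0
EXPECTED_SHA256 = {
    "evermore-0.2.0.tar.gz": "b8136c557581b5ef6e44f11e75f51cc18d7ed03ed459e65c300eeadf1b7b1a4c",
    "evermore-0.2.0-py3-none-any.whl": "62bba15c38f708cf8d27dfe294ae887119d62e89a51cd1f5c8a799e4c0df18ab",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    """Compare a downloaded file against the digest listed for its filename."""
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no listed digest for {path.name}")
    return sha256_of(path) == expected
```

For example, verify(Path("evermore-0.2.0.tar.gz")) returns True only if the downloaded sdist matches the SHA256 listed above. pip can enforce the same check itself in hash-checking mode (--require-hashes, with --hash=sha256:... entries in a requirements file).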

