Differentiable (binned) likelihoods in JAX.
dilax
Installation
python -m pip install dilax
From source:
git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .
Usage - Model definition and fitting
See more in examples/
dilax in a nutshell:
import jax
import jax.numpy as jnp
import equinox as eqx
from dilax.likelihood import NLL
from dilax.model import Model, Result
from dilax.optimizer import JaxOptimizer
from dilax.parameter import Parameter, gauss, modifier, unconstrained
jax.config.update("jax_enable_x64", True)
# define a simple model with two processes and two parameters
class MyModel(Model):
    def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result:
        res = Result()
        # signal
        mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
        res.add(process="signal", expectation=mu_mod(processes["signal"]))
        # background
        bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=gauss(0.2))
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res
# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)
# define negative log-likelihood with data (observation)
nll = NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)
# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})
# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)
print(values)
# -> {'mu': Array([1.4], dtype=float64),
# 'sigma': Array([4.04723836e-14], dtype=float64)}
# eval model with fitted values/parameters
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)
# gradients of "prefit" model:
fast_grad_nll_prefit = eqx.filter_grad(nll)
print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}
# gradients of "postfit" model:
postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
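The numbers printed above can be cross-checked by hand. The sketch below is not dilax code. It assumes the NLL is a single-bin Poisson term plus a unit-Gaussian constraint on sigma, and that gauss(0.2) scales the background by 1 + 0.2 * sigma. Both are assumptions made for illustration, and the helper names (expectation, nll, grad_mu, grad_sigma) are hypothetical. Under those assumptions the derivatives can be written out with the standard library alone:

```python
import math

# Inputs copied from the example above.
SIGNAL = 10.0
BACKGROUND = 50.0
OBSERVED = 64.0
WIDTH = 0.2  # assumed meaning of gauss(0.2): 20% relative shift per unit of sigma


def expectation(mu, sigma):
    # Assumption: mu scales the signal, sigma scales the background by 1 + WIDTH * sigma.
    return mu * SIGNAL + BACKGROUND * (1.0 + WIDTH * sigma)


def nll(mu, sigma):
    # Poisson term (constant log(n!) dropped) plus an assumed unit-Gaussian constraint.
    lam = expectation(mu, sigma)
    return lam - OBSERVED * math.log(lam) + 0.5 * sigma**2


def grad_mu(mu, sigma):
    # d(nll)/d(mu) = (1 - n / lambda) * d(lambda)/d(mu)
    lam = expectation(mu, sigma)
    return (1.0 - OBSERVED / lam) * SIGNAL


def grad_sigma(mu, sigma):
    # d(nll)/d(sigma) = (1 - n / lambda) * d(lambda)/d(sigma) + sigma
    lam = expectation(mu, sigma)
    return (1.0 - OBSERVED / lam) * BACKGROUND * WIDTH + sigma


# "prefit": mu at its initial value 1.0, sigma overridden to 0.2
prefit = grad_sigma(mu=1.0, sigma=0.2)
# "postfit": mu at its fitted value 1.4, sigma overridden to 0.2
postfit = grad_sigma(mu=1.4, sigma=0.2)
print(round(prefit, 8), round(postfit, 8))
```

Under these assumptions the two values come out at roughly -0.1226 and 0.5030, in line with the gradients printed above. Both grad_mu and grad_sigma vanish at mu = 1.4, sigma = 0, where the expectation equals the observed 64 events, which is consistent with the fit result.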
Contributing
See CONTRIBUTING.md for instructions on how to contribute.
License
Distributed under the terms of the BSD license.