
dilax


Differentiable (binned) likelihoods in JAX.

Installation

python -m pip install dilax

From source:

git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .

Usage - Model definition and fitting

More examples can be found in the examples/ directory.

dilax in a nutshell:

import jax
import jax.numpy as jnp
import equinox as eqx

from dilax.likelihood import NLL
from dilax.model import Model, Result
from dilax.optimizer import JaxOptimizer
from dilax.parameter import Parameter, gauss, modifier, unconstrained


jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(Model):
    def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result:
        res = Result()

        # signal
        mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=gauss(0.2))
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values/parameters
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
fast_grad_nll_prefit = eqx.filter_grad(nll)
print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}

# gradients of "postfit" model:
postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
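
Because the NLL behaves as a pure function of the parameter-value dictionary (as the gradient calls above suggest), it also composes with other JAX/Equinox transformations. The following is a minimal sketch, not part of the example above, reusing the nll object and imports defined there; the grid of "sigma" values is made up purely for illustration:

# sketch: value and gradient of the NLL in a single pass
nll_val_and_grad = eqx.filter_jit(eqx.filter_value_and_grad(nll))
value, grads = nll_val_and_grad({"sigma": jnp.array([0.2])})
print(value, grads)

# sketch: scan the NLL over a (hypothetical) grid of "sigma" values with jax.vmap
sigmas = jnp.linspace(-1.0, 1.0, 11)[:, None]  # shape (11, 1): 11 scan points, 1-element parameter
nll_scan = jax.vmap(lambda s: nll({"sigma": s}))
print(nll_scan(sigmas))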

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.

