Differentiable (binned) likelihoods in JAX.

Project description

dilax

Installation

python -m pip install dilax

From source:

git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .

Usage - Model definition and fitting

More examples can be found in the examples/ directory.

dilax in a nutshell:

import jax
import jax.numpy as jnp
import equinox as eqx

from dilax.likelihood import NLL
from dilax.model import Model, Result
from dilax.optimizer import JaxOptimizer
from dilax.parameter import Parameter, gauss, modifier, unconstrained
from dilax.util import HistDB


jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(Model):
    def __call__(self, processes: HistDB, parameters: dict[str, Parameter]) -> Result:
        res = Result()

        # signal
        mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=gauss(0.2))
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = HistDB({"signal": jnp.array([10.0]), "background": jnp.array([50.0])})
parameters = {
    "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values/parameters
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
fast_grad_nll_prefit = eqx.filter_jit(eqx.filter_grad(nll))
print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}

# gradients of "postfit" model:
postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
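For intuition, the prefit and postfit gradients printed above can be reproduced with plain JAX. The sketch below is not dilax's implementation. It assumes that `gauss(0.2)` scales the background linearly as `(1 + 0.2 * sigma)`, that `unconstrained()` scales the signal by `mu`, and that the NLL is the Poisson term plus a standard-normal penalty `0.5 * sigma**2`, with constant terms dropped. It also swaps the `trust-constr` fit for `jax.scipy.optimize.minimize` with BFGS, which does not enforce the lower bound on `mu`. The names `expectation`, `nll`, `SIGNAL`, `BACKGROUND`, `OBSERVED` and `WIDTH` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# use float64, as in the dilax example
jax.config.update("jax_enable_x64", True)

# same toy inputs as the dilax example (a single bin)
SIGNAL = jnp.array([10.0])
BACKGROUND = jnp.array([50.0])
OBSERVED = jnp.array([64.0])
WIDTH = 0.2  # assumption: relative width used by gauss(0.2)


def expectation(mu, sigma):
    """Per-bin expected yield (assumed linear modifier forms)."""
    return mu * SIGNAL + (1.0 + WIDTH * sigma) * BACKGROUND


def nll(mu, sigma):
    """Poisson NLL plus a standard-normal penalty on sigma (constants dropped)."""
    lam = expectation(mu, sigma)
    poisson = jnp.sum(lam - OBSERVED * jnp.log(lam))
    constraint = 0.5 * sigma**2
    return poisson + constraint


# gradient w.r.t. sigma (argument index 1), evaluated at sigma = 0.2
grad_sigma = jax.grad(nll, argnums=1)
prefit_grad = grad_sigma(1.0, 0.2)   # mu at its initial value
postfit_grad = grad_sigma(1.4, 0.2)  # mu at its fitted value

# fit both parameters with BFGS (the bound mu >= 0 is NOT enforced here)
result = minimize(lambda x: nll(x[0], x[1]), jnp.array([1.0, 0.0]), method="BFGS")
mu_hat, sigma_hat = result.x

print(prefit_grad, postfit_grad)
print(mu_hat, sigma_hat, expectation(mu_hat, sigma_hat))
```

Under these assumptions the fit lands at mu ≈ 1.4 and sigma ≈ 0 (expectation ≈ 64), and the gradients with respect to sigma at sigma = 0.2 come out at about -0.1226 (mu = 1) and 0.5030 (mu = 1.4), consistent with the values shown in the example. The postfit gradient is larger because the higher signal yield pushes the expectation (66) above the observation (64).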

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.


Download files

Source Distribution

dilax-0.1.3.tar.gz (21.8 kB)

Uploaded Source

Built Distribution

dilax-0.1.3-py3-none-any.whl (15.2 kB)

Uploaded Python 3

File details

Details for the file dilax-0.1.3.tar.gz.

File metadata

  • Download URL: dilax-0.1.3.tar.gz
  • Upload date:
  • Size: 21.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for dilax-0.1.3.tar.gz
Algorithm Hash digest
SHA256 1240e535949f42dc1d988fb6b3559683539267ac466d2e6c429f9f2d18f47d00
MD5 55e4fac442de307850b4ef355b122eb7
BLAKE2b-256 e78e021636139d308aef68f85ec00e9d570dda2a1e9d2c31334a493712088fb1

File details

Details for the file dilax-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: dilax-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 15.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for dilax-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 36f5750f0c0a644a6d511aa4a93d9c4225b9dc657d4b43d8043129a1110f5a83
MD5 31893a567217d266c8797c4063ae6aa4
BLAKE2b-256 fd8829908992c47c99db41700ba99087356530ede94db54dea494c900d5a98f4
