
dilax


Differentiable (binned) likelihoods in JAX.

Installation

python -m pip install dilax

From source:

git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .

Usage - Model definition and fitting

See more examples in the examples/ directory.

dilax in a nutshell:

import jax
import jax.numpy as jnp
import equinox as eqx

from dilax.likelihood import NLL
from dilax.model import Model, Result
from dilax.optimizer import JaxOptimizer
from dilax.parameter import Parameter, gauss, modifier, unconstrained


jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(Model):
    def __call__(self, processes: dict[str, jax.Array], parameters: dict[str, Parameter]) -> Result:
        res = Result()

        # signal
        mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=gauss(0.2))
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values/parameters
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
fast_grad_nll_prefit = eqx.filter_grad(nll)
print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}

# gradients of "postfit" model:
postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
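The NLL minimized above is, at its core, a binned Poisson likelihood whose expectation depends differentiably on the parameters. A minimal, self-contained sketch of that idea in pure JAX (no dilax; all names here are illustrative, not the dilax API):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Same toy setup as the example above: one bin, two processes.
signal = jnp.array([10.0])
background = jnp.array([50.0])
observation = jnp.array([64.0])


def nll(mu: jax.Array) -> jax.Array:
    # Expected bin content as a differentiable function of the signal strength.
    expectation = mu * signal + background
    # Negative log Poisson likelihood, dropping the mu-independent log(n!) term.
    return jnp.sum(expectation - observation * jnp.log(expectation))


# JAX gives the gradient of the likelihood for free; jit compiles it.
grad_nll = jax.jit(jax.grad(nll))

# At the best-fit value mu = 1.4 the expectation (64) matches the observation,
# so the gradient vanishes.
print(grad_nll(jnp.array(1.4)))
```

dilax wraps this pattern in `Model`/`NLL` objects so that parameters, constraints (e.g. the `gauss(0.2)` effect above), and modifiers compose, while everything stays compatible with `jax.grad` and `jax.jit`.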

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.

