
Project description

dilax


Differentiable (binned) likelihoods in JAX.

Installation

python -m pip install dilax

From source:

git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .

Usage - Model definition and fitting

See more in the examples/ directory.

dilax in a nutshell:

import equinox as eqx
import jax
import jax.numpy as jnp

import dilax as dlx

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(dlx.Model):
    def __call__(self, processes: dict, parameters: dict) -> dlx.Result:
        res = dlx.Result()

        # signal
        mu_mod = dlx.modifier(
            name="mu", parameter=parameters["mu"], effect=dlx.effect.unconstrained()
        )
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = dlx.modifier(
            name="sigma", parameter=parameters["sigma"], effect=dlx.effect.gauss(0.2)
        )
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": dlx.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": dlx.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = dlx.likelihood.NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = dlx.optimizer.JaxOptimizer.make(
    name="ScipyMinimize", settings={"method": "trust-constr"}
)

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
print(eqx.filter_grad(nll)({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}


# gradients of "postfit" model:
@eqx.filter_grad
@eqx.filter_jit
def grad_postfit_nll(where: dict[str, jax.Array]) -> jax.Array:
    nll = dlx.likelihood.NLL(
        model=model.update(values=values), observation=jnp.array([64.0])
    )
    return nll(where)


print(grad_postfit_nll({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
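The pieces above compose with ordinary JAX transformations. As a rough illustration of the underlying idea in plain JAX (this is not the dilax API): a binned Poisson negative log-likelihood for the same one-bin setup, whose gradient in the signal strength mu vanishes at the best-fit point:

```python
# A self-contained sketch of the idea behind a differentiable binned
# likelihood, written in plain JAX (NOT the dilax API): a Poisson NLL
# over bin counts whose gradient w.r.t. the signal strength `mu`
# JAX computes directly.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

signal = jnp.array([10.0])
background = jnp.array([50.0])
observation = jnp.array([64.0])


def nll(mu: jax.Array) -> jax.Array:
    # expected counts per bin: mu * signal + background
    lam = mu * signal + background
    # Poisson negative log-likelihood, dropping the constant log(n!) term
    return jnp.sum(lam - observation * jnp.log(lam))


grad_nll = jax.jit(jax.grad(nll))
print(grad_nll(jnp.array(1.4)))
# ≈ 0: at mu = 1.4 the expectation 1.4 * 10 + 50 matches the observation 64
```

This is the same best-fit mu ≈ 1.4 that the optimizer finds above; dilax packages this pattern (model, modifiers, constraints) behind its Model/NLL classes.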

Contributing

See CONTRIBUTING.md for instructions on how to contribute.

License

Distributed under the terms of the BSD license.


Download files

Download the file for your platform. If you're not sure which to choose, see the Python Packaging User Guide for more on installing packages.

Source Distribution

dilax-0.1.6.tar.gz (25.1 kB)

Built Distribution

dilax-0.1.6-py3-none-any.whl (19.7 kB)

File details

Details for the file dilax-0.1.6.tar.gz.

File metadata

  • Download URL: dilax-0.1.6.tar.gz
  • Size: 25.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for dilax-0.1.6.tar.gz:

  • SHA256: 8c1729fb1c111ec702434f5aaf3fbf5d5ea82bb62ef41b564511c600422fa6d4
  • MD5: c66e547464d8dea9d69fd184cb2f3510
  • BLAKE2b-256: 22c3f0a369af40405228bf090e24a5cc68015c826d41a523747218944aa757b1

See the PyPI documentation for more details on using file hashes.
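To check a downloaded file against the published SHA256 digest above, you can hash it locally. A minimal sketch using Python's standard hashlib; the helper name `sha256_of` and the file path are illustrative, not part of dilax:

```python
# Verify a downloaded file against a published SHA256 digest.
# `sha256_of` and the example path are illustrative helpers.
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Compare the result against the digest listed above, e.g.:
# sha256_of("dilax-0.1.6.tar.gz") == "8c1729fb1c111ec702434f5aaf3fbf5d5ea82bb62ef41b564511c600422fa6d4"
```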

File details

Details for the file dilax-0.1.6-py3-none-any.whl.

File metadata

  • Download URL: dilax-0.1.6-py3-none-any.whl
  • Size: 19.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for dilax-0.1.6-py3-none-any.whl:

  • SHA256: 875bb19edb52e8a91aa5645f887bbef679567dd676fd800d7f1e0de17d60a166
  • MD5: 2264dc526f1728c916a170ff6145012c
  • BLAKE2b-256: bed8687229535aa38a736f66d8a0e97f01580eab117456cae4bb929b278ad9c0

See the PyPI documentation for more details on using file hashes.
