
Automatic differentiation for high-energy physics correction factor calculations.


correctionlib-gradients


A JAX-friendly, auto-differentiable, Python-only implementation of correctionlib correction evaluations.



Installation

pip install correctionlib-gradients

Usage

  1. construct a CorrectionWithGradient object from a correctionlib.schemav2.Correction
  2. there is no point 2: you can use CorrectionWithGradient.evaluate as a normal JAX-friendly, auto-differentiable function

Example

import jax
import jax.numpy as jnp

from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient

# given a correctionlib schema:
formula_schema = schemav2.Correction(
    name="x squared",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="x * x",
        parser="TFormula",
        variables=["x"],
    ),
)

# construct a CorrectionWithGradient
c = CorrectionWithGradient(formula_schema)

# use c.evaluate as a JAX-friendly, auto-differentiable function
value, grad = jax.value_and_grad(c.evaluate)(3.0)
assert jnp.isclose(value, 9.0)
assert jnp.isclose(grad, 6.0)

# for Formula corrections, jax.jit and jax.vmap work too
xs = jnp.array([3.0, 4.0])
values, grads = jax.vmap(jax.jit(jax.value_and_grad(c.evaluate)))(xs)
assert jnp.allclose(values, jnp.array([9.0, 16.0]))
assert jnp.allclose(grads, jnp.array([6.0, 8.0]))

Supported types of corrections

Currently the following corrections from correctionlib.schemav2 are supported:

  • Formula, including parametric formulas
  • Binning with uniform or non-uniform bin edges and flow="clamp"; bin contents can be either:
    • all scalar values
    • all Formula or FormulaRef
  • scalar constants
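The lookup rule behind a Binning with flow="clamp" can be sketched in a few lines of jax.numpy. This is a hypothetical illustration of the rule, not the implementation used by correctionlib-gradients. The helper name clamp_lookup is made up for this example, and bins are assumed to be half-open intervals [low, high), with inputs below the first edge or at/above the last edge mapped to the first or last bin.

```python
import jax.numpy as jnp


def clamp_lookup(x, edges, contents):
    """Hypothetical binned lookup with flow="clamp" semantics.

    Assumes half-open bins [edges[i], edges[i + 1]) and len(contents) == len(edges) - 1.
    """
    # side="right" places a value equal to edges[i] in bin i
    idx = jnp.searchsorted(edges, x, side="right") - 1
    # flow="clamp": out-of-range inputs fall back to the first or last bin
    idx = jnp.clip(idx, 0, contents.shape[0] - 1)
    return contents[idx]


edges = jnp.array([0.0, 1.0, 2.0, 3.0])
contents = jnp.array([10.0, 20.0, 30.0])
xs = jnp.array([-5.0, 0.5, 1.0, 2.5, 10.0])
values = clamp_lookup(xs, edges, contents)  # [10. 10. 20. 30. 30.]
```

Here -5.0 is clamped into the first bin and 10.0 into the last, while 1.0 sits on an edge and therefore belongs to the second bin.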

Known limitations

Only the evaluation of Formula corrections is fully JAX traceable.

For other corrections, e.g. Binning, gradients can still be computed (jax.grad works), but because JAX cannot trace the computation, utilities such as jax.jit and jax.vmap will not work. In these cases np.vectorize can be used as an alternative to jax.vmap.
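The same pattern can be reproduced without correctionlib using a hypothetical stand-in function that branches on the value of its input with a Python if. This is only an analogy for the behaviour described above, not how correctionlib-gradients evaluates Binning corrections: JAX documents that jax.grad tolerates Python control flow on concrete inputs while jax.jit does not, and np.vectorize sidesteps tracing by calling the gradient function once per element in a Python loop.

```python
import jax
import numpy as np


# Hypothetical stand-in for a non-traceable correction: the Python `if`
# needs a concrete value of x, which jax.grad provides but jax.jit does not.
def piecewise(x):
    if x < 3.0:
        return 3.0 * x**2
    return -4.0 * x


grad_fn = jax.grad(piecewise)

# jax.grad works on concrete inputs: d/dx 3x^2 = 6x, i.e. 12 at x = 2
g = grad_fn(2.0)

# jax.jit fails because the branch cannot be decided on an abstract tracer
try:
    jax.jit(grad_fn)(2.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# np.vectorize loops in Python, calling grad_fn once per element
xs = np.array([2.0, 4.0])
grads = np.vectorize(grad_fn)(xs)  # gradients 12.0 and -4.0

print(g, jit_failed, grads)
```

The trade-off is speed: np.vectorize is a convenience loop, so each element triggers a separate gradient evaluation instead of one batched computation.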

License

correctionlib-gradients is distributed under the terms of the BSD 3-Clause license.



Download files

Source distribution: correctionlib_gradients-0.2.2.tar.gz (14.8 kB)

Built distribution: correctionlib_gradients-0.2.2-py3-none-any.whl (11.6 kB, Python 3)
