
Automatic differentiation for high-energy physics correction factor calculations.


correctionlib-gradients


A JAX-friendly, auto-differentiable, Python-only implementation of correctionlib correction evaluations.



Installation

pip install correctionlib-gradients

Usage

  1. construct a CorrectionWithGradient object from a correctionlib.schemav2.Correction
  2. there is no point 2: you can use CorrectionWithGradient.evaluate as a normal JAX-friendly, auto-differentiable function

Example

import jax
import jax.numpy as jnp

from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient

# given a correctionlib schema:
formula_schema = schemav2.Correction(
    name="x squared",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="x * x",
        parser="TFormula",
        variables=["x"],
    ),
)

# construct a CorrectionWithGradient
c = CorrectionWithGradient(formula_schema)

# use c.evaluate as a JAX-friendly, auto-differentiable function
value, grad = jax.value_and_grad(c.evaluate)(3.0)
assert jnp.isclose(value, 9.0)
assert jnp.isclose(grad, 6.0)

# for Formula corrections, jax.jit and jax.vmap work too
xs = jnp.array([3.0, 4.0])
values, grads = jax.vmap(jax.jit(jax.value_and_grad(c.evaluate)))(xs)
assert jnp.allclose(values, jnp.array([9.0, 16.0]))
assert jnp.allclose(grads, jnp.array([6.0, 8.0]))

Supported types of corrections

Currently, the following corrections from correctionlib.schemav2 are supported:

  • Formula, including parametric formulas
  • Binning with uniform or non-uniform bin edges and flow="clamp"; bin contents can be either:
    • all scalar values
    • all Formula or FormulaRef
  • scalar constants

Known limitations

Only the evaluation of Formula corrections is fully JAX traceable.

For other corrections, e.g. Binning, gradients can be computed (jax.grad works), but since JAX cannot trace the computation, utilities such as jax.jit and jax.vmap will not work. np.vectorize can be used as an alternative to jax.vmap in these cases.

License

correctionlib-gradients is distributed under the terms of the BSD 3-Clause license.

