Automatic differentiation for high-energy physics correction factor calculations.
correctionlib-gradients
A JAX-friendly, auto-differentiable, Python-only implementation of correctionlib correction evaluations.
Installation
pip install correctionlib-gradients
Usage
- construct a CorrectionWithGradient object from a correctionlib.schemav2.Correction
- there is no point 2: you can use CorrectionWithGradient.evaluate as a normal JAX-friendly, auto-differentiable function
Example
import jax
import jax.numpy as jnp
from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient
# given a correctionlib schema:
formula_schema = schemav2.Correction(
    name="x squared",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="x * x",
        parser="TFormula",
        variables=["x"],
    ),
)
# construct a CorrectionWithGradient
c = CorrectionWithGradient(formula_schema)
# use c.evaluate as a JAX-friendly, auto-differentiable function
value, grad = jax.value_and_grad(c.evaluate)(3.0)
assert jnp.isclose(value, 9.0)
assert jnp.isclose(grad, 6.0)
# for Formula corrections, jax.jit and jax.vmap work too
xs = jnp.array([3.0, 4.0])
values, grads = jax.vmap(jax.jit(jax.value_and_grad(c.evaluate)))(xs)
assert jnp.allclose(values, jnp.array([9.0, 16.0]))
assert jnp.allclose(grads, jnp.array([6.0, 8.0]))
Supported types of corrections
Currently the following corrections from correctionlib.schemav2 are supported:
- Formula, including parametrical formulas
- Binning with uniform or non-uniform bin edges and flow="clamp"; bin contents can be either:
  - all scalar values
  - all Formula or FormulaRef
- scalar constants
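As a rough illustration of what a Binning with non-uniform edges and flow="clamp" means for a lookup, here is a minimal NumPy sketch. It is not the library's implementation: the helper `clamped_bin_index`, the edges, and the contents are hypothetical, bins are assumed to include their lower edge, and the sketch ignores whatever the library does to make binned values differentiable.

```python
import numpy as np

def clamped_bin_index(x, edges):
    """Index of the bin containing x; out-of-range x maps to the nearest bin."""
    edges = np.asarray(edges)
    # side="right" puts a value sitting exactly on an edge into the bin above it
    idx = np.searchsorted(edges, x, side="right") - 1
    # clamp: underflow goes to the first bin, overflow to the last bin
    return int(np.clip(idx, 0, len(edges) - 2))

edges = [0.0, 1.0, 4.0]   # two bins with non-uniform widths (illustrative values)
contents = [10.0, 20.0]   # one scalar value per bin (illustrative values)

lookup = [contents[clamped_bin_index(x, edges)] for x in (-5.0, 0.5, 1.0, 99.0)]
```

The inputs -5.0 and 99.0 fall outside the edges and are assigned to the first and last bin respectively instead of raising an error.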
Known limitations
Only the evaluation of Formula corrections is fully JAX traceable.
For other corrections, e.g. Binning, gradients can be computed (jax.grad works), but because JAX cannot
trace the computation, utilities such as jax.jit and jax.vmap will not work.
np.vectorize can be used as an alternative to jax.vmap in these cases.
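A minimal sketch of the np.vectorize pattern, under stated assumptions: `value_and_grad_scalar` is a hypothetical stand-in for something like `jax.value_and_grad(c.evaluate)` on a non-traceable correction. It takes one scalar and returns a (value, gradient) pair, with the derivative written by hand so that the sketch only needs NumPy.

```python
import numpy as np

def value_and_grad_scalar(x):
    """Hypothetical scalar-only function returning (value, gradient)."""
    # Python branching on the input value is the kind of logic that cannot be
    # traced, so the function has to be applied one element at a time.
    if x < 1.0:
        return x * x, 2.0 * x   # first bin: formula x * x
    return 3.0 * x, 3.0         # second bin: formula 3 * x

# np.vectorize loops over the elements in Python and collects each output
# of the tuple into its own array.
vectorized = np.vectorize(value_and_grad_scalar)

xs = np.array([0.5, 2.0])
values, grads = vectorized(xs)
```

Unlike jax.vmap, np.vectorize is a convenience loop rather than a compiled batch operation, so it should not be expected to improve performance.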
License
correctionlib-gradients is distributed under the terms of the BSD 3-Clause license.
Download files
File details
Details for the file correctionlib_gradients-0.2.2.tar.gz.
File metadata
- Download URL: correctionlib_gradients-0.2.2.tar.gz
- Upload date:
- Size: 14.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.25.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fa994ee9706cacde516bafbdd2b6c09364611ee05d4912e4298bec5227d4b261 |
| MD5 | c3e7b470da229fd4a026a153ca7f65c3 |
| BLAKE2b-256 | 42478ebd2dcfb33ab4207a1414d00768dac42716d32377f07aa560ad448783c5 |
File details
Details for the file correctionlib_gradients-0.2.2-py3-none-any.whl.
File metadata
- Download URL: correctionlib_gradients-0.2.2-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.25.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0adbf6a9744a4f9e4437a6226f312007004fa33a6576447d2ae7e16153b93add |
| MD5 | f5252a4a172f8551a84464097e0d2718 |
| BLAKE2b-256 | 160b7d479fcfa893b6eff24044df689f5391a962a9edb9d02d10a7085f930d30 |