Numerical quadrature with JAX
quadax is a library for numerical quadrature and integration using JAX.
vmap-able, jit-able, differentiable.
Scalar or vector-valued integrands.
Finite or infinite domains with discontinuities or singularities within the domain of integration.
Globally adaptive Gauss-Kronrod and Clenshaw-Curtis quadrature for smooth integrands (similar to scipy.integrate.quad).
Adaptive tanh-sinh quadrature for singular or near-singular integrands.
Quadrature from sampled values using the trapezoidal and Simpson's rules.
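The first feature is possible because an integrand is just a JAX-traceable function, so a quadrature rule written in jax.numpy composes with jax.jit, jax.vmap and jax.grad. The sketch below illustrates this with a fixed 21-point Gauss-Legendre rule. `gauss_legendre` is a hypothetical helper, not part of quadax's API, and unlike quadgk it is neither adaptive nor error-estimating.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical helper (NOT quadax API): fixed 21-point Gauss-Legendre rule.
# Nodes/weights on [-1, 1] come from NumPy and are constants at trace time.
_NODES, _WEIGHTS = np.polynomial.legendre.leggauss(21)


def gauss_legendre(fun, a, b, *args):
    """Approximate the integral of fun(t, *args) over [a, b]."""
    half = 0.5 * (b - a)
    mid = 0.5 * (b + a)
    t = mid + half * jnp.asarray(_NODES)  # map nodes from [-1, 1] to [a, b]
    return half * jnp.sum(jnp.asarray(_WEIGHTS) * fun(t, *args))


f = lambda t, k: jnp.sin(k * t)

# jit: compile the integral as a function of the upper limit b and parameter k.
integrate = jax.jit(lambda b, k: gauss_legendre(f, 0.0, b, k))

# vmap: integrate over several upper limits in a single call (k fixed at 1).
bs = jnp.linspace(0.5, 2.0, 4)
vals = jax.vmap(integrate, in_axes=(0, None))(bs, 1.0)

# grad: derivative of the integral over [0, 1] with respect to k, at k = 2.
dI_dk = jax.grad(integrate, argnums=1)(1.0, 2.0)
```

Here the exact values are 1 - cos(b) for `vals` and (k b sin(kb) - (1 - cos(kb))) / k^2 for `dI_dk`, so the transformed rule can be checked directly.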
Coming soon:
Custom JVP/VJP rules (currently AD works by differentiating through the loop, which isn't the most efficient approach).
N-D quadrature (cubature)
QMC methods
Integration with weight functions
Sparse grids (maybe, need to play with data structures and JAX)
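The custom JVP/VJP item concerns avoiding differentiation through the integration loop. For derivatives with respect to the integration limits, the Leibniz rule gives the answer in closed form: d/db of the integral of f over [a, b] is f(b), and d/da is -f(a). A minimal sketch of that idea with jax.custom_jvp follows. It assumes a hypothetical `integral` wrapper around a fixed Gauss-Legendre rule, so it is not quadax code and not how quadax will necessarily implement it. Derivatives with respect to parameters inside the integrand, which need the integral of the partial derivative of f, are not handled.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

# Fixed 21-point Gauss-Legendre rule, used only as a stand-in integrator.
_NODES, _WEIGHTS = np.polynomial.legendre.leggauss(21)


def _fixed_rule(fun, a, b):
    half, mid = 0.5 * (b - a), 0.5 * (b + a)
    t = mid + half * jnp.asarray(_NODES)
    return half * jnp.sum(jnp.asarray(_WEIGHTS) * fun(t))


# Hypothetical wrapper (NOT quadax API): the integrand is a non-differentiable
# argument; only the limits a and b receive derivatives.
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def integral(fun, a, b):
    return _fixed_rule(fun, a, b)


@integral.defjvp
def _integral_jvp(fun, primals, tangents):
    a, b = primals
    da, db = tangents
    value = _fixed_rule(fun, a, b)
    # Leibniz rule: dI/db = f(b), dI/da = -f(a); no quadrature loop is differentiated.
    return value, fun(b) * db - fun(a) * da


d_db = jax.grad(lambda b: integral(jnp.cos, 0.0, b))(1.0)  # close to cos(1)
d_da = jax.grad(lambda a: integral(jnp.cos, a, 1.0))(0.0)  # close to -cos(0) = -1
```

Because the rule returns the derivative of the exact integral rather than of the discretized sum, the two can differ by roughly the quadrature error, which is usually the desired behavior.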
Installation
quadax is installable with pip:
pip install quadax
Usage
import jax.numpy as jnp
import numpy as np
from quadax import quadgk
fun = lambda t: t * jnp.log(1 + t)
epsabs = epsrel = 1e-5  # JAX uses 32-bit floats by default; higher accuracy requires enabling 64-bit mode
a, b = 0, 1
y, info = quadgk(fun, [a, b], epsabs=epsabs, epsrel=epsrel)
assert info.err < max(epsabs, epsrel*abs(y))
np.testing.assert_allclose(y, 1/4, rtol=epsrel, atol=epsabs)
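The tolerance comment refers to JAX's default single precision (float32 machine epsilon is about 1.2e-7). The following minimal sketch switches to double precision with JAX's `jax_enable_x64` flag, which must be set before any arrays are created. So that it runs without quadax, a hypothetical `fixed_rule` (a 21-point Gauss-Legendre rule, not quadgk) integrates the same t*log(1+t) on [0, 1]; integration by parts gives the exact value 1/4. With x64 enabled, quadgk should presumably also be able to meet tolerances well below 1e-5.

```python
import jax
import jax.numpy as jnp
import numpy as np

# JAX defaults to float32; enable float64 before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

# Hypothetical stand-in (NOT quadgk): 21-point Gauss-Legendre mapped to [a, b].
nodes, weights = np.polynomial.legendre.leggauss(21)


def fixed_rule(fun, a, b):
    half, mid = 0.5 * (b - a), 0.5 * (b + a)
    t = mid + half * jnp.asarray(nodes)
    return half * jnp.sum(jnp.asarray(weights) * fun(t))


fun = lambda t: t * jnp.log(1 + t)
y = fixed_rule(fun, 0.0, 1.0)
err = abs(float(y) - 0.25)  # far below the float32 limit once x64 is on
```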
For full details of the various options, see the API documentation.