
diffrax extras: OOP and vectorization


diffraxtra

diffrax extras



Extras for diffrax.

  • DiffEqSolver: an object-oriented interface to diffrax.diffeqsolve.
  • VectorizedDenseInterpolation: a vectorized form of diffrax.DenseInterpolation that works on batched results from diffrax.diffeqsolve.

For example,

import jax
import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# Enable 64-bit floats so the results match the float64 output shown below.
jax.config.update("jax_enable_x64", True)

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#         [0.74080871, 0.67031456]], dtype=float64)

Installation


pip install diffraxtra

Documentation

DiffEqSolver

>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver

Construct a solver object.

>>> solver = DiffEqSolver(dfx.Dopri5(),
...                       stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

And a differential equation to solve.

>>> term = dfx.ODETerm(lambda t, y, args: -y)

Then solve the differential equation.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )

The solution can be saved at specific times.

>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
          ys=f64[4], ... )

The solution can be densely interpolated.

>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)

Using the VectorizedDenseInterpolation class, the interpolation can be vectorized, enabling evaluation of batched solutions over batches of times.

>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

The same result is more conveniently obtained with the vectorize_interpolation argument:

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...               vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

There are many ways to construct a DiffEqSolver object. For example, passing an existing DiffEqSolver object to DiffEqSolver.from_ returns that same object:

>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True

From a diffrax.AbstractSolver object, with the other fields set to their defaults:

>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
  solver=Dopri5(scan_kind=None),
  stepsize_controller=ConstantStepSize(),
  adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
)

From a collections.abc.Mapping of field names to values:

>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
...       "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
  solver=Dopri5(scan_kind=None),
  stepsize_controller=PIDController( ... ),
  adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
)

For a full enumeration of the ways to construct a DiffEqSolver object, see the documentation of DiffEqSolver.from_.

VectorizedDenseInterpolation

A vectorized wrapper around a diffrax.DenseInterpolation.

It works on both non-batched and batched interpolations.

>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation

We'll start with a non-batched interpolation:

>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> sol = dfx.diffeqsolve(
...     term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...     stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
  scalar_interpolation=DenseInterpolation(
    ts=f64[1,4097],
    ts_size=weak_i64[1],
    infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
    interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
    direction=weak_i64[1],
    t0_if_trivial=f64[1],
    y0_if_trivial=f64[1]
  ),
  batch_shape=(),
  y0_shape=()
)

It can be evaluated in the usual way:

>>> interp.evaluate(ts[-1])  # scalar evaluation
Array(0.04978961, dtype=float64)

It also works on arrays, without needing to apply jax.vmap manually:

>>> interp.evaluate(ts)  # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0])  # increment y(t1) - y(t0): array t0, scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)

Better yet, the time array may be arbitrarily shaped:

>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1.   , 0.368],
       [0.135, 0.05 ]], dtype=float64)
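A common way to support arbitrarily shaped time arrays is to flatten the times, vmap a scalar-time function over them, and restore the original time shape in front of the output's own shape. The helper evaluate_any_shape and the toy interpolant below are hypothetical; this sketches the general technique for an unbatched interpolation and is not taken from diffraxtra's source.

```python
import jax
import jax.numpy as jnp


def evaluate_any_shape(scalar_fn, t):
    """Hypothetical helper: apply a scalar-time function to times of any shape."""
    t = jnp.asarray(t)
    flat = t.reshape(-1)                          # (N,) with N = t.size
    out = jax.vmap(scalar_fn)(flat)               # (N, *y_shape)
    return out.reshape(t.shape + out.shape[1:])   # (*t.shape, *y_shape)


# Toy stand-in for an interpolation: returns a 2-vector for each scalar time.
def toy_interp(t):
    return jnp.stack([jnp.exp(-t), 2.0 * jnp.exp(-t)])


ts = jnp.array([0.0, 1.0, 2.0, 3.0]).reshape(2, 2)
ys = evaluate_any_shape(toy_interp, ts)           # shape (2, 2, 2)
y_scalar = evaluate_any_shape(toy_interp, 1.0)    # shape (2,)
```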

As a convenience, VectorizedDenseInterpolation.apply_to_solution swaps the vectorized interpolation into an existing solution. Inside a jitted function this update can happen "in-place"; outside of one it is out-of-place and returns a modified copy:

>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True

Now we'll batch the interpolation by vmapping the solve over several initial conditions:

>>> @jax.vmap
... def solve(y0):
...     sol = dfx.diffeqsolve(
...         term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
...         stepsize_controller=stepsize_controller)
...     return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3)  # scalar eval of batched interp
Array([0.05 , 0.1  , 0.149], dtype=float64)
>>> interp.evaluate(ts).astype(jnp.float64).round(3)  # array eval of batched interp
Array([[1.   , 0.368, 0.135, 0.05 ],
       [2.   , 0.736, 0.271, 0.1  ],
       [3.   , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3)  # mixed scalar and array eval
Array([[0.   , 0.632, 0.865, 0.95 ],
       [0.   , 1.264, 1.729, 1.9  ],
       [0.   , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3)  # arbitrary shape eval
>>> ys
Array([[[1.   , 0.368],
        [0.135, 0.05 ]],
       [[2.   , 0.736],
        [0.271, 0.1  ]],
       [[3.   , 1.104],
        [0.406, 0.149]]], dtype=float64)
>>> ys.shape  # (batch, *times)
(3, 2, 2)

Citation

DOI

If you use this library and would like to cite it, follow the DOI link above.

Development


We welcome contributions!


Download files


Source Distribution

diffraxtra-1.0.2.tar.gz (71.5 kB)

Uploaded Source

Built Distribution


diffraxtra-1.0.2-py3-none-any.whl (13.5 kB)

Uploaded Python 3

File details

Details for the file diffraxtra-1.0.2.tar.gz.

File metadata

  • Download URL: diffraxtra-1.0.2.tar.gz
  • Upload date:
  • Size: 71.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for diffraxtra-1.0.2.tar.gz
Algorithm Hash digest
SHA256 0478fca1268109dcaf8c988f89945ed31fe2bdbe5accb3bbbef9e3f3e1f78f08
MD5 2199335fff596116a3ba97b3b9e188ea
BLAKE2b-256 1920e8931816281e6c10adac6341d5e373bb3db778b82c038da8227787f12cee


Provenance

The following attestation bundles were made for diffraxtra-1.0.2.tar.gz:

Publisher: cd.yml on GalacticDynamics/diffraxtra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file diffraxtra-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: diffraxtra-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 13.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for diffraxtra-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 f4966c25b960dc361ce5141fac489d2cb52fa618eb1e3e3cd85d68f472e3f7a4
MD5 e11041f32827b0963445ad2098570958
BLAKE2b-256 5faa93cec938d3a5a6315b88cf6791ea1655585c755dcb0742040362b3b55622


Provenance

The following attestation bundles were made for diffraxtra-1.0.2-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/diffraxtra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
