
diffrax extras: OOP and vectorization


diffraxtra

diffrax extras




Extras for diffrax.

  • DiffEqSolver: an object-oriented interface to diffrax.diffeqsolve.
  • VectorizedDenseInterpolation: a vectorized form of diffrax.DenseInterpolation that works on batched results from diffrax.diffeqsolve.

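For example (the outputs shown throughout assume JAX's 64-bit mode is enabled, e.g. with jax.config.update("jax_enable_x64", True); by default JAX uses 32-bit floats):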

import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#         [0.74080871, 0.67031456]], dtype=float64)
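Since dy/dt = -y with y(0) = 1 has the closed-form solution y(t) = e^(-t), the values printed above can be sanity-checked by hand. A minimal standard-library sketch; the numbers are copied from the output above and `exact` is a hypothetical helper, not part of diffraxtra:

```python
import math

# Interpolated values printed by soln.evaluate(...) above, in the same 2x2 layout.
interpolated = [[0.90483742, 0.81872516],
                [0.74080871, 0.67031456]]
times = [[0.1, 0.2],
         [0.3, 0.4]]


def exact(t: float, y0: float = 1.0) -> float:
    """Closed-form solution of dy/dt = -y with y(0) = y0."""
    return y0 * math.exp(-t)


# Largest absolute deviation between the interpolation and the exact solution.
max_err = max(
    abs(y - exact(t))
    for row_y, row_t in zip(interpolated, times)
    for y, t in zip(row_y, row_t)
)
print(f"max abs error: {max_err:.1e}")
```

The deviation is on the order of 1e-5, in line with the rtol = atol = 1e-5 requested from the PIDController.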

Installation


pip install diffraxtra

Documentation

DiffEqSolver

>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver

Construct a solver object.

>>> solver = DiffEqSolver(dfx.Dopri5(),
...                stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

And a differential equation to solve.

>>> term = dfx.ODETerm(lambda t, y, args: -y)

Then solve the differential equation.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )

The solution can be saved at specific times.

>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
          ys=f64[4], ... )

The solution can be densely interpolated.

>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)
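Dense output means the solver keeps enough per-step information to evaluate the solution between its steps. Dopri5 uses its own higher-order interpolant built from its stage values; as a simpler stand-in (a swapped-in technique, for illustration only), the sketch below uses a cubic Hermite interpolant over a single step, built from the endpoint values and slopes. `cubic_hermite` is a hypothetical helper:

```python
import math


def cubic_hermite(t, t0, t1, y0, y1, f0, f1):
    """Cubic Hermite interpolant on [t0, t1] from endpoint values and slopes."""
    h = t1 - t0
    s = (t - t0) / h  # normalized position in [0, 1]
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * y0 + h10 * h * f0 + h01 * y1 + h11 * h * f1


# One coarse "step" of dy/dt = -y from t=0 to t=1, using exact endpoint data.
vector_field = lambda t, y: -y
t0, t1 = 0.0, 1.0
y0, y1 = 1.0, math.exp(-1.0)
f0, f1 = vector_field(t0, y0), vector_field(t1, y1)

approx = cubic_hermite(0.5, t0, t1, y0, y1, f0, f1)
print(approx, math.exp(-0.5))
```

With a single unit step the error at t = 0.5 is roughly 1.6e-3; diffrax's adaptive step sizes and Dopri5's higher-order interpolant do considerably better.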

Using the VectorizedDenseInterpolation class, the interpolation can be vectorized, enabling evaluation of batched solutions over batches of times.

>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

This is more conveniently done with the vectorize_interpolation argument:

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...               vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

There are many ways to construct a DiffEqSolver object. For example, from an existing DiffEqSolver object, in which case the same object is returned:

>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True

From a diffrax.AbstractSolver object, with defaults for the remaining fields:

>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
  solver=Dopri5(scan_kind=None),
  stepsize_controller=ConstantStepSize(),
  adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
  max_steps=4096
)

From a collections.abc.Mapping of field names to values:

>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
...       "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
  solver=Dopri5(scan_kind=None),
  stepsize_controller=PIDController( ... ),
  adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
  max_steps=4096
)

For a full enumeration of the ways to construct a DiffEqSolver object, see diffraxtra.DiffEqSolver.from_.
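The from_ constructor picks its behaviour from the type of its argument. A minimal sketch of that dispatch pattern using the standard library's functools.singledispatch, mirroring the three cases shown above. `ToySolver`, `ToyDiffEqSolver` and this `from_` are hypothetical stand-ins; this is not diffraxtra's implementation:

```python
from collections.abc import Mapping
from dataclasses import dataclass
from functools import singledispatch
from typing import Any


class ToySolver:
    """Hypothetical stand-in for diffrax.AbstractSolver."""


@dataclass(frozen=True)
class ToyDiffEqSolver:
    """Hypothetical stand-in for DiffEqSolver with a few defaulted fields."""
    solver: ToySolver
    stepsize_controller: Any = "ConstantStepSize"  # placeholder default
    max_steps: int = 4096


@singledispatch
def from_(obj: Any) -> ToyDiffEqSolver:
    # Fallback for unsupported argument types.
    raise TypeError(f"cannot construct ToyDiffEqSolver from {type(obj).__name__}")


@from_.register
def _(obj: ToyDiffEqSolver) -> ToyDiffEqSolver:
    # Already the right type: return it unchanged, so `from_(x) is x`.
    return obj


@from_.register
def _(obj: ToySolver) -> ToyDiffEqSolver:
    # Wrap a bare solver, using defaults for everything else.
    return ToyDiffEqSolver(solver=obj)


@from_.register
def _(obj: Mapping) -> ToyDiffEqSolver:
    # Treat the mapping's items as keyword arguments (dicts match the ABC).
    return ToyDiffEqSolver(**obj)
```

Registering on the collections.abc.Mapping ABC means any mapping type, not only dict, takes the keyword-argument path.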

VectorizedDenseInterpolation

Vectorized wrapper around a diffrax.DenseInterpolation.

It works on batched as well as non-batched interpolations.

>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation

We'll start with a non-batched interpolation:

>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> sol = dfx.diffeqsolve(
...     term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...     stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
  scalar_interpolation=DenseInterpolation(
    ts=f64[1,4097],
    ts_size=weak_i64[1],
    infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
    interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
    direction=weak_i64[1],
    t0_if_trivial=f64[1],
    y0_if_trivial=f64[1]
  ),
  batch_shape=(),
  y0_shape=()
)

It can be evaluated in the usual way:

>>> interp.evaluate(ts[-1])  # scalar evaluation
Array(0.04978961, dtype=float64)

It also works on arrays, without needing to apply jax.vmap manually:

>>> interp.evaluate(ts)  # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0])  # increment y(t1) - y(t0): array t0, scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
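Following diffrax's evaluate(t0, t1) convention, passing two times returns the increment y(t1) - y(t0), with a scalar broadcast against an array. A standard-library sketch that uses the exact solution y(t) = e^(-t) in place of the interpolation reproduces the (rounded) numbers above; `y` and `increment` are hypothetical helpers:

```python
import math


def y(t: float) -> float:
    """Exact solution of dy/dt = -y with y(0) = 1 (stands in for the interpolation)."""
    return math.exp(-t)


def increment(t0, t1):
    """Return y(t1) - y(t0); either argument may be a flat list, the other a scalar."""
    if isinstance(t0, list) and isinstance(t1, list):
        return [y(b) - y(a) for a, b in zip(t0, t1)]  # elementwise, equal lengths
    if isinstance(t0, list):
        return [y(t1) - y(a) for a in t0]  # broadcast scalar t1
    if isinstance(t1, list):
        return [y(b) - y(t0) for b in t1]  # broadcast scalar t0
    return y(t1) - y(t0)


ts = [0.0, 1.0, 2.0, 3.0]
deltas = increment(ts, ts[0])  # mirrors interp.evaluate(ts, ts[0])
print([round(d, 3) for d in deltas])  # [0.0, 0.632, 0.865, 0.95]
```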

Better yet, the time array may be arbitrarily shaped:

>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1.   , 0.368],
       [0.135, 0.05 ]], dtype=float64)
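Conceptually, vectorizing a scalar interpolation means applying it elementwise while keeping the shape of the time array. The standard-library sketch below shows only those shape semantics, with nested lists standing in for JAX arrays and the exact solution standing in for the interpolation; `evaluate_elementwise` is a hypothetical helper, and diffraxtra itself works on JAX arrays:

```python
import math
from typing import Callable, Union

Nested = Union[float, list]


def evaluate_elementwise(f: Callable[[float], float], ts: Nested) -> Nested:
    """Apply the scalar function f to every element of a nested list, keeping its shape."""
    if isinstance(ts, list):
        return [evaluate_elementwise(f, t) for t in ts]
    return f(ts)


# Stand-in for a scalar interpolation: the exact solution of dy/dt = -y.
scalar_interp = lambda t: math.exp(-t)

ys = evaluate_elementwise(scalar_interp, [[0.1, 0.2], [0.3, 0.4]])
print(ys)  # a 2x2 nested list, same layout as the input times
```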

For convenience, VectorizedDenseInterpolation.apply_to_solution wraps the interpolation of an existing solution. Inside a jitted context this is effectively "in-place"; outside one it is out-of-place and returns a modified copy:

>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True
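Out-of-place here means a new solution object is returned that differs only in its interpolation field, while the original is left untouched. A sketch of that pattern with the standard library's dataclasses.replace; `Solution`, `Vectorized` and `apply_to_solution` below are hypothetical stand-ins, not diffrax or diffraxtra classes:

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Solution:
    """Hypothetical stand-in for diffrax.Solution (immutable, like a PyTree)."""
    ys: tuple
    interpolation: Any


@dataclass(frozen=True)
class Vectorized:
    """Hypothetical stand-in for VectorizedDenseInterpolation."""
    scalar_interpolation: Any


def apply_to_solution(sol: Solution) -> Solution:
    # Build a new Solution that differs only in its interpolation field;
    # all other fields are shared with the original, not copied.
    return replace(sol, interpolation=Vectorized(sol.interpolation))


old = Solution(ys=(1.0,), interpolation="dense")
new = apply_to_solution(old)
```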

Now we'll batch the interpolation:

>>> @jax.vmap
... def solve(y0):
...     sol = dfx.diffeqsolve(
...         term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
...         stepsize_controller=stepsize_controller)
...     return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3)  # scalar eval of batched interp
Array([0.05 , 0.1  , 0.149], dtype=float64)
>>> interp.evaluate(ts).astype(jnp.float64).round(3)  # array eval of batched interp
Array([[1.   , 0.368, 0.135, 0.05 ],
       [2.   , 0.736, 0.271, 0.1  ],
       [3.   , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3)  # mixed scalar and array eval
Array([[0.   , 0.632, 0.865, 0.95 ],
       [0.   , 1.264, 1.729, 1.9  ],
       [0.   , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3)  # arbitrary shape eval
>>> ys
Array([[[1.   , 0.368],
        [0.135, 0.05 ]],
       [[2.   , 0.736],
        [0.271, 0.1  ]],
       [[3.   , 1.104],
        [0.406, 0.149]]], dtype=float64)
>>> ys.shape  # (batch, *times)
(3, 2, 2)
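Because the ODE is linear, each batch member is simply y0 * e^(-t), which makes the (batch, *times) layout easy to reproduce by hand. A standard-library sketch with nested lists standing in for JAX arrays; `shape` is a hypothetical helper, not part of diffraxtra:

```python
import math


def shape(nested):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


y0s = [1.0, 2.0, 3.0]            # batch of initial conditions, as in solve(...)
ts = [[0.0, 1.0], [2.0, 3.0]]    # ts.reshape(2, 2) from above

# Exact solution y0 * exp(-t): the batch axis comes first, then the time axes.
ys = [[[y0 * math.exp(-t) for t in row] for row in ts] for y0 in y0s]

print(shape(ys))  # (3, 2, 2)
```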

Citation

DOI

If you enjoyed using this library and would like to cite it, please use the DOI link above.

Development


We welcome contributions!



Download files


Source Distribution

diffraxtra-1.2.0.tar.gz (72.3 kB)

Uploaded Source

Built Distribution


diffraxtra-1.2.0-py3-none-any.whl (14.2 kB)

Uploaded Python 3

File details

Details for the file diffraxtra-1.2.0.tar.gz.

File metadata

  • Download URL: diffraxtra-1.2.0.tar.gz
  • Upload date:
  • Size: 72.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for diffraxtra-1.2.0.tar.gz
Algorithm Hash digest
SHA256 9e50d4099236e5b794f42b1469a79521de929aa9d7344f7ecca95e4f5041cf30
MD5 2196b9f83d90ec840e3cfd7de829aa15
BLAKE2b-256 e0613d7aa0ff9d4a35cdfb0c030418e2463b4e447f29a16049df0207a4d82c1e


Provenance

The following attestation bundles were made for diffraxtra-1.2.0.tar.gz:

Publisher: cd.yml on GalacticDynamics/diffraxtra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file diffraxtra-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: diffraxtra-1.2.0-py3-none-any.whl
  • Upload date:
  • Size: 14.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for diffraxtra-1.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e7af08480e1f7159bedbf15323fed6bcf892d415d5bdbbd215470a23cbfffaa4
MD5 7ec56b45ed7ae9cbed80a2b841fbc5e5
BLAKE2b-256 db653d8e734ce5be0d4451e7a4c1dae0d0d973f8197e59a327183ace3589d203


Provenance

The following attestation bundles were made for diffraxtra-1.2.0-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/diffraxtra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
