
diffrax extras: OOP and vectorization


diffraxtra

diffrax extras



Extras for diffrax.

  • DiffEqSolver: an object-oriented interface to diffrax.diffeqsolve.
  • VectorizedDenseInterpolation: a vectorized form of diffrax.DenseInterpolation that works on batched results from diffrax.diffeqsolve.

For example,

import jax
import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# The outputs below are double precision, which requires JAX's 64-bit mode.
jax.config.update("jax_enable_x64", True)

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#         [0.74080871, 0.67031456]], dtype=float64)

Installation


pip install diffraxtra

Documentation

DiffEqSolver

>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver

Construct a solver object.

>>> solver = DiffEqSolver(dfx.Dopri5(),
...                stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

And a differential equation to solve.

>>> term = dfx.ODETerm(lambda t, y, args: -y)

Then solve the differential equation.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )

The solution can be saved at specific times.

>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
          ys=f64[4], ... )

The solution can be densely interpolated.

>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)
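The test problem dy/dt = -y with y(0) = 1 has the closed-form solution y(t) = exp(-t), so you can check the printed values by hand. The snippet below is a plain-Python sanity check written for this README, not part of diffraxtra. The helper `exact_solution` is hypothetical.

```python
import math


def exact_solution(t: float, y0: float = 1.0) -> float:
    """Closed-form solution of dy/dt = -y with y(0) = y0 (hypothetical helper)."""
    return y0 * math.exp(-t)


# Values at the save times used with SaveAt(ts=[0., 1., 2., 3.]).
saved = [round(exact_solution(t), 3) for t in (0.0, 1.0, 2.0, 3.0)]
print(saved)  # [1.0, 0.368, 0.135, 0.05]

# Dense evaluation at t = 0.5, for comparison with soln.evaluate(0.5).round(3).
print(round(exact_solution(0.5), 3))  # 0.607
```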

The VectorizedDenseInterpolation class vectorizes the interpolation. Batched solutions can then be evaluated over arbitrarily shaped arrays of times.

>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)
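Conceptually, a vectorized evaluation applies the scalar evaluation to every element of the time array and keeps that array's shape. The sketch below shows these semantics in plain Python, with nested lists standing in for arrays and exp(-t) standing in for the interpolant. This is an assumption made for illustration. diffraxtra itself works on JAX arrays, and this is not its implementation. The name `evaluate_elementwise` is hypothetical.

```python
import math
from typing import Callable, Union

# A "time array" here is a float or an arbitrarily nested list of floats.
Nested = Union[float, list]


def evaluate_elementwise(f: Callable[[float], float], ts: Nested) -> Nested:
    """Apply a scalar-time function to every element of ts, preserving its nesting."""
    if isinstance(ts, list):
        return [evaluate_elementwise(f, t) for t in ts]
    return f(ts)


# Times with shape (2, 2), like jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2).
ts = [[0.1, 0.2], [0.3, 0.4]]
# exp(-t), the exact solution of dy/dt = -y with y(0) = 1, stands in for the interpolant.
ys = evaluate_elementwise(lambda t: math.exp(-t), ts)
print([[round(y, 4) for y in row] for row in ys])
# [[0.9048, 0.8187], [0.7408, 0.6703]]
```

The output keeps the (2, 2) layout of the input times, which matches the shape of the array printed above.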

The same result can be obtained more conveniently by passing vectorize_interpolation=True to the solver.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...               vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

There are many ways to construct a DiffEqSolver object. For example, one can be made from an existing DiffEqSolver object, in which case the same object is returned:

>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True

From a diffrax.AbstractSolver object, with the remaining fields set to their defaults:

>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
  solver=Dopri5(scan_kind=None),
  stepsize_controller=ConstantStepSize(),
  adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
  event=None,
  max_steps=4096
)

From a collections.abc.Mapping of field names to values:

>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
...       "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
  solver=Dopri5(scan_kind=None),
  stepsize_controller=PIDController( ... ),
  adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
  event=None,
  max_steps=4096
)

For a full enumeration of the ways to construct a DiffEqSolver object, see diffraxtra.DiffEqSolver.from_.
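The `from_` constructor accepts several input types. One common way to build such a constructor in plain Python is `functools.singledispatch`. The sketch below shows that pattern using hypothetical stand-in classes (`Solver`, `Dopri5Like`). It is not how diffraxtra implements `from_`.

```python
import dataclasses
import functools
from collections.abc import Mapping
from typing import Any


@dataclasses.dataclass(frozen=True)
class Dopri5Like:
    """Hypothetical stand-in for a diffrax.AbstractSolver such as dfx.Dopri5()."""


@dataclasses.dataclass(frozen=True)
class Solver:
    """Hypothetical stand-in for DiffEqSolver with a few defaulted fields."""

    solver: Any
    stepsize_controller: str = "ConstantStepSize()"
    max_steps: int = 4096

    @classmethod
    def from_(cls, obj: Any) -> "Solver":
        # Dispatch on the type of `obj` to one of the handlers registered below.
        return _from(obj, cls)


@functools.singledispatch
def _from(obj: Any, cls: type) -> Solver:
    # Fallback for unsupported input types.
    raise TypeError(f"cannot build {cls.__name__} from {type(obj).__name__}")


@_from.register
def _(obj: Solver, cls: type) -> Solver:
    return obj  # already a Solver: return the very same object


@_from.register
def _(obj: Dopri5Like, cls: type) -> Solver:
    return cls(solver=obj)  # wrap a bare solver; other fields keep their defaults


@_from.register
def _(obj: Mapping, cls: type) -> Solver:
    return cls(**obj)  # treat the mapping as keyword arguments


s = Solver(Dopri5Like())
print(Solver.from_(s) is s)  # True
print(Solver.from_({"solver": Dopri5Like(), "max_steps": 100}).max_steps)  # 100
```

Dispatching on the argument's type keeps each construction path in its own small function. It also lets abstract base classes such as `Mapping` match any dict-like input.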

VectorizedDenseInterpolation

A vectorized wrapper around diffrax.DenseInterpolation.

It works on both batched and non-batched interpolations.

>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation

We'll start with a non-batched interpolation:

>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> sol = dfx.diffeqsolve(
...     term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...     stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
  scalar_interpolation=DenseInterpolation(
    ts=f64[1,4097],
    ts_size=weak_i64[1],
    infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
    interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
    direction=weak_i64[1],
    t0_if_trivial=f64[1],
    y0_if_trivial=f64[1]
  ),
  batch_shape=(),
  y0_shape=()
)

It can be evaluated in the usual way:

>>> interp.evaluate(ts[-1])  # scalar evaluation
Array(0.04978961, dtype=float64)

It also works on arrays, without needing to apply jax.vmap manually:

>>> interp.evaluate(ts)  # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0])  # increment y(t1) - y(t0): array t0, scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
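With two arguments, `evaluate(t0, t1)` follows `diffrax.DenseInterpolation.evaluate` and returns the increment y(t1) - y(t0). In the call above, t0 is the array `ts` and t1 is the scalar `ts[0] = 0`, so each entry is y(0) - y(t) = 1 - exp(-t). The plain-Python check below uses the exact solution instead of the interpolation. The helper `increment` is hypothetical.

```python
import math

ts = [0.0, 1.0, 2.0, 3.0]


def increment(t0: float, t1: float) -> float:
    """y(t1) - y(t0) for the exact solution y(t) = exp(-t) (hypothetical helper)."""
    return math.exp(-t1) - math.exp(-t0)


# t0 ranges over ts while t1 stays fixed at ts[0] = 0.0, mirroring interp.evaluate(ts, ts[0]).
incs = [round(increment(t, ts[0]), 3) for t in ts]
print(incs)  # [0.0, 0.632, 0.865, 0.95]
```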

Better yet, the time array may be arbitrarily shaped:

>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1.   , 0.368],
       [0.135, 0.05 ]], dtype=float64)

For convenience, VectorizedDenseInterpolation.apply_to_solution wraps the interpolation of an existing solution and returns the solution. Inside a jitted context this is effectively "in-place"; outside one, a modified copy is returned:

>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True
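diffrax solutions are immutable objects, so "modifying" the interpolation means building a new solution that reuses every other field. The sketch below shows this out-of-place update pattern with `dataclasses.replace`. The stand-in classes `ToySolution` and `ToyVectorizedInterp` are hypothetical. This illustrates the idea only; it is not diffraxtra's code.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToySolution:
    """Hypothetical stand-in for diffrax.Solution."""

    ts: tuple
    ys: tuple
    interpolation: Any


@dataclasses.dataclass(frozen=True)
class ToyVectorizedInterp:
    """Hypothetical wrapper around a scalar interpolation."""

    scalar_interpolation: Any

    @classmethod
    def apply_to_solution(cls, soln: ToySolution) -> ToySolution:
        # Build a new solution with a wrapped interpolation; `soln` itself is unchanged.
        return dataclasses.replace(soln, interpolation=cls(soln.interpolation))


sol = ToySolution(ts=(3.0,), ys=(0.0498,), interpolation="dense")
new_sol = ToyVectorizedInterp.apply_to_solution(sol)
print(sol.interpolation)      # dense
print(new_sol.interpolation)  # ToyVectorizedInterp(scalar_interpolation='dense')
```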

Now we'll batch the solve with jax.vmap, which produces a batched interpolation:

>>> @jax.vmap
... def solve(y0):
...     sol = dfx.diffeqsolve(
...         term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
...         stepsize_controller=stepsize_controller)
...     return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3)  # scalar eval of batched interp
Array([0.05 , 0.1  , 0.149], dtype=float64)
>>> interp.evaluate(ts).round(3)  # array eval of batched interp
Array([[1.   , 0.368, 0.135, 0.05 ],
       [2.   , 0.736, 0.271, 0.1  ],
       [3.   , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3)  # mixed scalar and array eval
Array([[0.   , 0.632, 0.865, 0.95 ],
       [0.   , 1.264, 1.729, 1.9  ],
       [0.   , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3)  # arbitrary shape eval
>>> ys
Array([[[1.   , 0.368],
        [0.135, 0.05 ]],
       [[2.   , 0.736],
        [0.271, 0.1  ]],
       [[3.   , 1.104],
        [0.406, 0.149]]], dtype=float64)
>>> ys.shape  # (batch, *times)
(3, 2, 2)
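The batched result has shape `batch_shape + times.shape`: the batch axes come first, followed by the shape of the time array. Each batch member of this problem is y0 * exp(-t), so the layout can be reproduced with nested lists in plain Python. The sketch only illustrates the output layout, and it uses the exact solution rather than the interpolation.

```python
import math

y0s = [1.0, 2.0, 3.0]          # batch of initial conditions, batch_shape = (3,)
ts = [[0.0, 1.0], [2.0, 3.0]]  # times with shape (2, 2)

# ys[b][i][j] = y0s[b] * exp(-ts[i][j]), giving shape (3, 2, 2).
ys = [[[round(y0 * math.exp(-t), 3) for t in row] for row in ts] for y0 in y0s]

shape = (len(ys), len(ys[0]), len(ys[0][0]))
print(shape)  # (3, 2, 2)
print(ys[2])  # [[3.0, 1.104], [0.406, 0.149]]
```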

Citation

DOI

If you use this library and would like to cite it, please use the DOI linked above.

Development


We welcome contributions!



Download files


Source Distribution

diffraxtra-1.4.0.tar.gz (72.6 kB)

Uploaded Source

Built Distribution


diffraxtra-1.4.0-py3-none-any.whl (15.9 kB)

Uploaded Python 3

File details

Details for the file diffraxtra-1.4.0.tar.gz.

File metadata

  • Download URL: diffraxtra-1.4.0.tar.gz
  • Upload date:
  • Size: 72.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for diffraxtra-1.4.0.tar.gz
Algorithm Hash digest
SHA256 61e12aa42179b4dee76ade4ff11e5ab5c142d0a3a65a7a19da292c8c6205d38c
MD5 39cc6bb3243470e0b1e687beda2d70aa
BLAKE2b-256 b911888cc09e01fab184aaff164bf88bd631069abc20741dfb18f7f941eec09c


Provenance

The following attestation bundles were made for diffraxtra-1.4.0.tar.gz:

Publisher: cd.yml on GalacticDynamics/diffraxtra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file diffraxtra-1.4.0-py3-none-any.whl.

File metadata

  • Download URL: diffraxtra-1.4.0-py3-none-any.whl
  • Upload date:
  • Size: 15.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for diffraxtra-1.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6b8fcda07eb3cb3c69f513765c85a8c02acf9b187d28085b8ab25306eca36a56
MD5 5b28ad17502db2ac7761167b23af51dd
BLAKE2b-256 88d32ab56beeb5172bcf9de7bdbcf2233ae196435a10eb1c8909ff1a61ec88f1


Provenance

The following attestation bundles were made for diffraxtra-1.4.0-py3-none-any.whl:

Publisher: cd.yml on GalacticDynamics/diffraxtra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
