diffrax extras: OOP and vectorization

diffraxtra

diffrax extras

Extras for diffrax.

  • DiffEqSolver: an object-oriented interface to diffrax.diffeqsolve.
  • VectorizedDenseInterpolation: a vectorized form of diffrax.DenseInterpolation that works on batched results from diffrax.diffeqsolve.

For example,

import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f32[], t1=f32[], ts=f32[1],
#   ys=f32[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#         [0.74080871, 0.67031456]], dtype=float64)

Installation

pip install diffraxtra

Documentation

DiffEqSolver

>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver

Construct a solver object.

>>> solver = DiffEqSolver(dfx.Dopri5(),
...                stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

And a differential equation to solve.

>>> term = dfx.ODETerm(lambda t, y, args: -y)

Then solve the differential equation.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )

The solution can be saved at specific times.

>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
          ys=f64[4], ... )

The solution can be densely interpolated.

>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)
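As a sanity check on the interpolated value above: the equation dy/dt = -y with y(0) = 1 has the exact solution y(t) = exp(-t), so the dense interpolation at t = 0.5 should agree with exp(-0.5) to the solver tolerance. The short snippet below uses only the standard library and is independent of diffraxtra; the helper name exact_solution is made up for this illustration.

```python
import math


def exact_solution(t: float, y0: float = 1.0) -> float:
    """Closed-form solution of dy/dt = -y with y(0) = y0."""
    return y0 * math.exp(-t)


# Compare against the interpolated value shown in the doctest above.
value = round(exact_solution(0.5), 3)
print(value)  # 0.607
```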

Using the VectorizedDenseInterpolation class, the interpolation can be vectorized, enabling evaluation of batched solutions over batches of times.

>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

This can be more conveniently done using the vectorize_interpolation argument.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...               vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)
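To illustrate what evaluating over an arbitrarily shaped array of times involves, here is a minimal sketch of one common approach: flatten the times, map a scalar-only function over them with jax.vmap, and restore the original shape. This is an assumption about how such a wrapper could be written, not a description of how diffraxtra implements it. The functions scalar_eval and evaluate_any_shape are hypothetical, and the exact solution exp(-t) stands in for a real scalar interpolant.

```python
import jax
import jax.numpy as jnp


def scalar_eval(t):
    """Stand-in for a scalar-only interpolant (exact solution of dy/dt = -y)."""
    return jnp.exp(-t)


def evaluate_any_shape(ts):
    """Evaluate scalar_eval over a time array of any shape."""
    ts = jnp.asarray(ts)
    flat = ts.reshape(-1)             # collapse to one dimension
    ys = jax.vmap(scalar_eval)(flat)  # map the scalar function over times
    return ys.reshape(ts.shape)       # restore the caller's shape


times = jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2)
values = evaluate_any_shape(times)
print(values.shape)  # (2, 2)
```

The values agree with the doctest output above up to floating-point precision, since both approximate exp(-t).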

There are many ways to construct a DiffEqSolver object. For example, we can make a new one from an existing DiffEqSolver object:

>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True

From a diffrax.AbstractSolver object:

>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(solver=Dopri5())

(All other arguments take their default values and are printed only if changed.)

From a collections.abc.Mapping:

>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
...       "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
  solver=Dopri5(), stepsize_controller=PIDController(rtol=1e-05, atol=1e-05)
)

For a full enumeration of the ways to construct a DiffEqSolver object, see diffraxtra.DiffEqSolver.from_.
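The construction behaviour shown above (return the same object when given an existing instance, unpack a mapping into keyword arguments, otherwise treat the input as the solver) can be sketched with a plain classmethod. This is a toy illustration only: ToySolver is a hypothetical class, and the real DiffEqSolver.from_ may be implemented quite differently.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToySolver:
    """Hypothetical stand-in for a solver wrapper with a from_ constructor."""

    solver: Any
    stepsize_controller: Any = None

    @classmethod
    def from_(cls, obj: Any) -> "ToySolver":
        if isinstance(obj, cls):
            return obj            # already the right type: no copy is made
        if isinstance(obj, Mapping):
            return cls(**obj)     # mapping keys become keyword arguments
        return cls(solver=obj)    # otherwise treat obj as the solver itself


base = ToySolver("dopri5")
same = ToySolver.from_(base)
from_mapping = ToySolver.from_({"solver": "dopri5", "stepsize_controller": "pid"})
from_solver = ToySolver.from_("dopri5")
print(same is base)  # True
```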

VectorizedDenseInterpolation

A vectorized wrapper around a diffrax.DenseInterpolation.

It also works on non-batched interpolations.

>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation

We'll start with a non-batched interpolation:

>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> sol = dfx.diffeqsolve(
...     term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...     stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
  scalar_interpolation=DenseInterpolation(
    ts=f64[1,4097],
    ts_size=weak_i64[1],
    infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
    interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
    direction=weak_i64[1],
    t0_if_trivial=f64[1],
    y0_if_trivial=f64[1]
  ),
  batch_shape=(),
  y0_shape=()
)

This can be evaluated by the normal means:

>>> interp.evaluate(ts[-1])  # scalar evaluation
Array(0.04978961, dtype=float64)

It also works on arrays, without needing to manually apply jax.vmap:

>>> interp.evaluate(ts)  # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0])  # t1 - t0 mixed scalar and array
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
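The two-argument call above follows the diffrax convention that evaluate(t0, t1) returns the increment y(t1) - y(t0). With t0 = ts and t1 = ts[0] = 0, and the exact solution y(t) = exp(-t), the expected values are 1 - exp(-t). The standard-library check below reproduces them to three decimals; the helper y is defined only for this illustration.

```python
import math

ts = [0.0, 1.0, 2.0, 3.0]


def y(t: float) -> float:
    """Exact solution of dy/dt = -y with y(0) = 1."""
    return math.exp(-t)


# Increment from each time in ts to the fixed time ts[0]: y(ts[0]) - y(t).
increments = [round(y(ts[0]) - y(t), 3) for t in ts]
print(increments)  # [0.0, 0.632, 0.865, 0.95]
```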

Better yet, the time array may be arbitrarily shaped:

>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1.   , 0.368],
       [0.135, 0.05 ]], dtype=float64)

As a convenience, we can also apply the VectorizedDenseInterpolation to the solution to modify the interpolation "in-place" (when in a jitted context, otherwise out-of-place, returning a copy):

>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True

Now we'll batch the interpolation:

>>> @jax.vmap
... def solve(y0):
...     sol = dfx.diffeqsolve(
...         term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
...         stepsize_controller=stepsize_controller)
...     return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3)  # scalar eval of batched interp
Array([0.05 , 0.1  , 0.149], dtype=float64)
>>> interp.evaluate(ts).astype(jnp.float64).round(3)  # array eval of batched interp
Array([[1.   , 0.368, 0.135, 0.05 ],
       [2.   , 0.736, 0.271, 0.1  ],
       [3.   , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3)  # mixed scalar and array eval
Array([[0.   , 0.632, 0.865, 0.95 ],
       [0.   , 1.264, 1.729, 1.9  ],
       [0.   , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3)  # arbitrary shape eval
>>> ys
Array([[[1.   , 0.368],
        [0.135, 0.05 ]],
       [[2.   , 0.736],
        [0.271, 0.1  ]],
       [[3.   , 1.104],
        [0.406, 0.149]]], dtype=float64)
>>> ys.shape  # (batch, *times)
(3, 2, 2)
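The (batch, *times) layout above can be reproduced with two nested jax.vmap calls: an inner map over times and an outer map over the batch of initial conditions. The sketch below is an illustration under stated assumptions rather than the library's implementation. The functions scalar_solution and evaluate_batched are hypothetical, and the closed form y0 * exp(-t) replaces a real interpolant.

```python
import jax
import jax.numpy as jnp


def scalar_solution(y0, t):
    """Exact solution of dy/dt = -y for one initial value and one time."""
    return y0 * jnp.exp(-t)


def evaluate_batched(y0s, ts):
    """Return an array of shape (batch, *ts.shape)."""
    ts = jnp.asarray(ts)
    flat_ts = ts.reshape(-1)
    over_times = jax.vmap(scalar_solution, in_axes=(None, 0))  # map over times
    over_batch = jax.vmap(over_times, in_axes=(0, None))       # map over batch
    flat = over_batch(y0s, flat_ts)                            # (batch, n_times)
    return flat.reshape((y0s.shape[0], *ts.shape))


y0s = jnp.array([1.0, 2.0, 3.0])
grid = jnp.array([0.0, 1.0, 2.0, 3.0]).reshape(2, 2)
batched = evaluate_batched(y0s, grid)
print(batched.shape)  # (3, 2, 2)
```

The leading axis indexes the initial condition and the trailing axes mirror the shape of the time array, matching the doctest output above.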

Citation

DOI

If you enjoyed using this library and would like to cite the software you use, click the DOI link above.

Development

We welcome contributions!
