diffrax extras: OOP and vectorization
diffraxtra
Extras for diffrax.
- DiffEqSolver: an object-oriented interface to diffrax.diffeqsolve.
- VectorizedDenseInterpolation: a vectorized form of diffrax.DenseInterpolation that works on batched results from diffrax.diffeqsolve.
For example,
import jax
jax.config.update("jax_enable_x64", True)  # the outputs below are float64
import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver
# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))
# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)
# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)
print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
# interpolation=VectorizedDenseInterpolation(
# scalar_interpolation=DenseInterpolation( ... ),
# batch_shape=(),
# y0_shape=()
# ),
# ...
# )
soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
# [0.74080871, 0.67031456]], dtype=float64)
Installation
pip install diffraxtra
Documentation
DiffEqSolver
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver
Construct a solver object.
>>> solver = DiffEqSolver(dfx.Dopri5(),
... stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))
And a differential equation to solve.
>>> term = dfx.ODETerm(lambda t, y, args: -y)
Then solve the differential equation.
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
ys=f64[1], ... )
The solution can be saved at specific times.
>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
ys=f64[4], ... )
The solution can be densely interpolated.
>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)
Using the VectorizedDenseInterpolation class, the interpolation can be
vectorized, enabling evaluation of batched solutions over batches of times.
>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)
This can be more conveniently done using the vectorize_interpolation argument.
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)
There are many ways to construct a DiffEqSolver object. For example, passing an
existing DiffEqSolver object to DiffEqSolver.from_ returns that same object:
>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True
From a diffrax.AbstractSolver object.
>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
)
From a collections.abc.Mapping.
>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
... "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=PIDController( ... ),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None)
)
For a full enumeration of the ways to construct a DiffEqSolver object, see
diffraxtra.DiffEqSolver.from_.
VectorizedDenseInterpolation
A vectorized wrapper around a diffrax.DenseInterpolation.
It works on non-batched interpolations as well as batched ones.
>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation
We'll start with a non-batched interpolation:
>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)
>>> sol = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation(
ts=f64[1,4097],
ts_size=weak_i64[1],
infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
direction=weak_i64[1],
t0_if_trivial=f64[1],
y0_if_trivial=f64[1]
),
batch_shape=(),
y0_shape=()
)
It can be evaluated in the usual way:
>>> interp.evaluate(ts[-1]) # scalar evaluation
Array(0.04978961, dtype=float64)
It also works on arrays, without needing to apply jax.vmap manually:
>>> interp.evaluate(ts) # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0]) # y(t1) - y(t0), with array t0 and scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
Better yet, the time array may be arbitrarily shaped:
>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1. , 0.368],
[0.135, 0.05 ]], dtype=float64)
As a convenience, we can also apply the VectorizedDenseInterpolation to the
solution to modify the interpolation "in-place" (when in a jitted context,
otherwise out-of-place, returning a copy):
>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True
Now we'll batch the interpolation:
>>> @jax.vmap
... def solve(y0):
... sol = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
... stepsize_controller=stepsize_controller)
... return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3) # scalar eval of batched interp
Array([0.05 , 0.1 , 0.149], dtype=float64)
>>> interp.evaluate(ts).astype(jnp.float64).round(3) # array eval of batched interp
Array([[1. , 0.368, 0.135, 0.05 ],
[2. , 0.736, 0.271, 0.1 ],
[3. , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3) # mixed scalar and array eval
Array([[0. , 0.632, 0.865, 0.95 ],
[0. , 1.264, 1.729, 1.9 ],
[0. , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3) # arbitrary shape eval
>>> ys
Array([[[1. , 0.368],
[0.135, 0.05 ]],
[[2. , 0.736],
[0.271, 0.1 ]],
[[3. , 1.104],
[0.406, 0.149]]], dtype=float64)
>>> ys.shape # (batch, *times)
(3, 2, 2)
Citation
If you enjoyed using this library and would like to cite the software you use, then click the link above.
Development
We welcome contributions!
File details
Details for the file diffraxtra-1.0.0.tar.gz.
File metadata
- Download URL: diffraxtra-1.0.0.tar.gz
- Upload date:
- Size: 71.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cf579c49bf30743a0a7edf9988f552319593f202b226849d85c18eb3b8df9987 |
| MD5 | 0d7024db59da0fc1842262aa8da2ba50 |
| BLAKE2b-256 | bfe0ee2f9860793460eebea1d47e2080d4bb46892e99eab568bc2102166346fb |
Provenance
The following attestation bundles were made for diffraxtra-1.0.0.tar.gz:
- Publisher: cd.yml on GalacticDynamics/diffraxtra
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: diffraxtra-1.0.0.tar.gz
- Subject digest: cf579c49bf30743a0a7edf9988f552319593f202b226849d85c18eb3b8df9987
- Sigstore transparency entry: 168733872
- Sigstore integration time:
- Permalink: GalacticDynamics/diffraxtra@0dc0560e0b4304c9841d76bc38ee9beffe048432
- Branch / Tag: refs/tags/v1.0.0
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd.yml@0dc0560e0b4304c9841d76bc38ee9beffe048432
- Trigger Event: release
File details
Details for the file diffraxtra-1.0.0-py3-none-any.whl.
File metadata
- Download URL: diffraxtra-1.0.0-py3-none-any.whl
- Upload date:
- Size: 13.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9d60b5e3306237b46f6061f90fa2d4ecb5c50d3c52eb0636b1189546b6dd19fc |
| MD5 | 016c435ee0599b414122fa02607444b5 |
| BLAKE2b-256 | cd1867b7b68b0612c782a930e711d7d103ac7680d2c9e6f86df74af0a2111c0e |
Provenance
The following attestation bundles were made for diffraxtra-1.0.0-py3-none-any.whl:
- Publisher: cd.yml on GalacticDynamics/diffraxtra
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: diffraxtra-1.0.0-py3-none-any.whl
- Subject digest: 9d60b5e3306237b46f6061f90fa2d4ecb5c50d3c52eb0636b1189546b6dd19fc
- Sigstore transparency entry: 168733885
- Sigstore integration time:
- Permalink: GalacticDynamics/diffraxtra@0dc0560e0b4304c9841d76bc38ee9beffe048432
- Branch / Tag: refs/tags/v1.0.0
- Owner: https://github.com/GalacticDynamics
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: cd.yml@0dc0560e0b4304c9841d76bc38ee9beffe048432
- Trigger Event: release