A parallel ODE solver for PyTorch

torchode is a suite of single-step ODE solvers, such as dopri5 and tsit5, that are compatible with PyTorch's JIT compiler and parallelized across a batch. JIT compilation often gives a performance boost, especially for code with many small operations such as an ODE solver. Batch parallelization means that the solver can take a step of 0.1 for one sample and 0.33 for another, depending on each sample's difficulty. This can avoid performance traps for models of varying stiffness and ensures that the model's predictions are independent of the composition of the batch. See the paper for details.
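
The per-sample step sizes described above can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not torchode's implementation: it swaps in an embedded Heun/Euler pair for dopri5/tsit5 and a simple step-size rule, and the names `f`, `solve_batch`, and the per-sample `rate` are hypothetical. Each sample keeps its own time, state, and step size, so a stiff sample takes many small steps without forcing small steps on the others.

```python
def f(t, y, rate):
    # Linear decay dy/dt = -rate * y; a larger rate makes the sample stiffer
    return -rate * y


def solve_batch(samples, t_end, atol=1e-6, rtol=1e-3):
    """Integrate every (y0, rate) sample from t=0 to t_end with its own step size.

    Illustrative only: an embedded Heun/Euler pair with a simple per-sample
    step-size rule, not torchode's dopri5/tsit5 or its controller.
    """
    t = [0.0] * len(samples)
    y = [y0 for y0, _ in samples]
    dt = [0.1] * len(samples)  # hypothetical initial step size
    n_steps = [0] * len(samples)
    while any(ti < t_end for ti in t):
        # A batched solver would vectorize this loop; each sample reads only its own state
        for i, (_, rate) in enumerate(samples):
            if t[i] >= t_end:
                continue  # finished samples stop stepping
            last = dt[i] >= t_end - t[i]
            h = t_end - t[i] if last else dt[i]
            k1 = f(t[i], y[i], rate)
            k2 = f(t[i] + h, y[i] + h * k1, rate)
            y_high = y[i] + h * (k1 + k2) / 2  # Heun, order 2
            y_low = y[i] + h * k1  # Euler, order 1
            scale = atol + rtol * max(abs(y[i]), abs(y_high))
            err = abs(y_high - y_low) / scale
            if err <= 1.0:  # accept the step
                t[i] = t_end if last else t[i] + h
                y[i] = y_high
                n_steps[i] += 1
            # Grow or shrink this sample's step based only on its own error
            factor = 0.9 * max(err, 1e-10) ** -0.5
            dt[i] = h * min(5.0, max(0.2, factor))
    return y, n_steps


batch = [(1.2, 0.5), (5.0, 20.0)]  # (y0, rate): a mild and a stiff sample
ys, steps = solve_batch(batch, t_end=1.0)
alone, _ = solve_batch([batch[1]], t_end=1.0)
print(steps)              # the stiff second sample needs many more accepted steps
print(ys[1] == alone[0])  # True: its result does not depend on the rest of the batch
```

In a real batched solver the inner loop would typically be a single tensor operation over the whole batch; the sketch keeps it explicit to make the per-sample bookkeeping visible.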

If you get stuck, think the library should have an example for a particular use case, or want to suggest some other improvement, please open an issue on GitHub.

Installation

You can get the latest released version from PyPI with

pip install torchode

To install a development version, clone the repository and install in editable mode:

git clone https://github.com/martenlienen/torchode
cd torchode
pip install -e .

Usage

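import matplotlib.pyplot as pp
import torch
import torchode as to

# Right-hand side of the ODE dy/dt = -0.5 * y; y holds the whole batch
def f(t, y):
    return -0.5 * y

# A batch of two samples, each with a one-dimensional state
y0 = torch.tensor([[1.2], [5.0]])
n_steps = 10
# Each sample has its own evaluation times: 10 points in [0, 5] and 10 points in [3, 4]
t_eval = torch.stack((torch.linspace(0, 5, n_steps), torch.linspace(3, 4, n_steps)))

term = to.ODETerm(f)
step_method = to.Dopri5(term=term)
# Adaptive step sizes, controlled separately for every sample in the batch
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
# Gradients are computed by backpropagating through the solver's operations
solver = to.AutoDiffAdjoint(step_method, step_size_controller)
jit_solver = torch.compile(solver)

sol = jit_solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
# Statistics are reported per sample
print(sol.stats)
# => {'n_f_evals': tensor([26, 26]), 'n_steps': tensor([4, 2]),
# =>  'n_accepted': tensor([4, 2]), 'n_initialized': tensor([10, 10])}

# Plot each sample's solution at its own evaluation times
pp.plot(sol.ts[0], sol.ys[0])
pp.plot(sol.ts[1], sol.ys[1])
pp.show()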
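
The `IntegralController(atol=1e-6, rtol=1e-3, ...)` in the usage example decides, per sample, whether to accept a step and how large the next one should be. As a hedged illustration of what `atol` and `rtol` control, here is a minimal sketch of the standard textbook integral controller with a mixed absolute/relative RMS error norm. torchode's actual norm, safety factor, and clamping bounds may differ, and `error_ratio` and `integral_controller` are hypothetical names, not part of torchode's API.

```python
import math


def error_ratio(y_prev, y_new, err_est, atol, rtol):
    """Scaled RMS norm of an embedded error estimate for ONE sample.

    Each component is divided by atol + rtol * max(|y_prev|, |y_new|);
    a result <= 1 means the step meets the requested tolerance.
    """
    total = 0.0
    for a, b, e in zip(y_prev, y_new, err_est):
        scale = atol + rtol * max(abs(a), abs(b))
        total += (e / scale) ** 2
    return math.sqrt(total / len(err_est))


def integral_controller(dt, ratio, order, safety=0.9, min_factor=0.2, max_factor=10.0):
    """Textbook integral (I) controller.

    Accepts the step if ratio <= 1 and proposes the next step size
    dt * safety * ratio ** (-1 / (order + 1)), clamped to [min_factor, max_factor] * dt.
    """
    if ratio == 0.0:
        factor = max_factor
    else:
        factor = safety * ratio ** (-1.0 / (order + 1))
    factor = min(max_factor, max(min_factor, factor))
    return ratio <= 1.0, dt * factor


# Two samples of a batch: each gets its own accept/reject decision and next step
for dt, ratio in [(0.1, 0.05), (0.1, 3.0)]:
    accepted, next_dt = integral_controller(dt, ratio, order=4)
    print(accepted, next_dt)  # first: accepted, larger step; second: rejected, smaller step
```

For a 5(4) pair such as dopri5 the error estimate is of order 4, hence `order=4` and an exponent of 1/5 in this sketch.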

Citation

If you build upon this work, please cite the following paper.

@inproceedings{lienen2022torchode,
  title = {torchode: A Parallel {ODE} Solver for PyTorch},
  author = {Marten Lienen and Stephan G{\"u}nnemann},
  booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
  year = {2022},
  url = {https://openreview.net/forum?id=uiKVKTiUYB0}
}

Download files

Download the file for your platform.

Source Distribution

torchode-1.0.1.tar.gz (37.3 kB)

Uploaded Source

Built Distribution

torchode-1.0.1-py3-none-any.whl (30.4 kB)

Uploaded Python 3

File details

Details for the file torchode-1.0.1.tar.gz.

File metadata

  • Download URL: torchode-1.0.1.tar.gz
  • Size: 37.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for torchode-1.0.1.tar.gz
  • SHA256: c0faea451a416dc027fa5d93562d661b6f8dbd1da174cd2b38d3377e77a5acee
  • MD5: f9e7a9b080ddf7d07f2e51f4e8f0600b
  • BLAKE2b-256: 9eaa36cbb79022a17f0cf0127e22693201ed01a8bbfa5efb1f854fe90739ad22
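
The published digests let you check that a downloaded file is intact. A minimal sketch using Python's standard `hashlib` module; the helper names `sha256_of_file` and `verify_download` are hypothetical, and the script assumes the sdist was saved as `torchode-1.0.1.tar.gz` in the current directory.

```python
import hashlib
import os

# SHA256 digest published above for torchode-1.0.1.tar.gz
EXPECTED = "c0faea451a416dc027fa5d93562d661b6f8dbd1da174cd2b38d3377e77a5acee"


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_sha256):
    """Return True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of_file(path) == expected_sha256.lower()


if __name__ == "__main__":
    path = "torchode-1.0.1.tar.gz"  # assumed download location
    if os.path.exists(path):
        print(verify_download(path, EXPECTED))
```

pip can perform the same check itself in hash-checking mode: list `torchode==1.0.1 --hash=sha256:<digest>` in a requirements file and install with `pip install --require-hashes -r requirements.txt`.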

File details

Details for the file torchode-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: torchode-1.0.1-py3-none-any.whl
  • Size: 30.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for torchode-1.0.1-py3-none-any.whl
  • SHA256: edbd2f91997a348b85ead40b276e04cd04d6bf3ec1df3e8af4128560c772a953
  • MD5: af37e59ba1e65ff954cb4e8e603f7dd6
  • BLAKE2b-256: 6b3a1ec5df4e4301f50d05112a12f5c8f11ebc1b1da03b3250d8889de4ffd04a