A parallel ODE solver for PyTorch

torchode is a suite of single-step ODE solvers, such as dopri5 and tsit5, that are compatible with PyTorch's JIT compiler and parallelized across a batch. JIT compilation often gives a performance boost, especially for code with many small operations such as an ODE solver. Batch parallelization means that the solver can take a step of 0.1 for one sample and 0.33 for another, depending on each sample's difficulty. This can avoid performance traps for models of varying stiffness and ensures that the model's predictions are independent of the composition of the batch. See the paper for details.
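The effect of per-sample step sizes can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not torchode's implementation: it integrates the hypothetical test problem dy/dt = -k * y for a batch of decay rates k, swaps dopri5/tsit5 for the much simpler Heun-Euler embedded pair, and uses an integral step-size controller whose safety factor (0.9) and clamping bounds (0.2 to 5) are assumed typical values. The inner loop over samples stands in for the vectorized tensor operations a real batched solver would use. Because every sample keeps its own time, state and step size, and accepts or rejects its own steps, sample 0 gets exactly the same result whether it is solved alone or next to a much stiffer sample.

```python
import math


def solve_batch(rates, y0s, t0, t1, atol=1e-6, rtol=1e-3, dt0=0.1):
    """Integrate dy/dt = -k * y for every (k, y0) pair from t0 to t1.

    Each sample keeps its own time, state, step size and step counter, so a
    stiff sample that needs tiny steps never slows down an easy one.
    """
    n = len(rates)
    t, y = [t0] * n, list(y0s)
    dt, n_steps = [dt0] * n, [0] * n
    while any(ti < t1 for ti in t):
        # A tensor library would run this loop as one vectorized operation.
        for i in range(n):
            if t[i] >= t1:
                continue  # finished samples are masked out
            last = dt[i] >= t1 - t[i]
            h = t1 - t[i] if last else dt[i]
            # Heun-Euler embedded pair: order-2 solution and order-1 estimate.
            k1 = -rates[i] * y[i]
            k2 = -rates[i] * (y[i] + h * k1)
            y_high = y[i] + h * (k1 + k2) / 2
            y_low = y[i] + h * k1
            scale = atol + rtol * max(abs(y[i]), abs(y_high))
            err = abs(y_high - y_low) / scale
            if err <= 1.0:  # accept or reject per sample, not per batch
                t[i] = t1 if last else t[i] + h  # snap exactly onto t1
                y[i] = y_high
                n_steps[i] += 1
            # Integral controller with exponent 1/(q+1), q = 1 (the lower order).
            factor = 0.9 * max(err, 1e-10) ** -0.5
            dt[i] = h * min(5.0, max(0.2, factor))
    return y, n_steps


rates = [0.5, 50.0]  # a slow and a fast (stiffer) decay
y_batch, steps_batch = solve_batch(rates, [1.0, 1.0], 0.0, 1.0)
y_alone, steps_alone = solve_batch(rates[:1], [1.0], 0.0, 1.0)

# Sample 0 gets bit-identical results whether or not sample 1 is in the batch.
print(y_batch[0] == y_alone[0], steps_batch[0] == steps_alone[0])  # True True
print(steps_batch)  # the fast-decaying sample needs far more steps
print(abs(y_batch[0] - math.exp(-0.5)))  # global error against the exact value
```

With a single step size shared by the whole batch, the slow sample would instead be forced onto the tiny steps that the fast sample requires.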

If you get stuck, think the library should have an example on a particular topic, or want to suggest another improvement, please open an issue on GitHub.

Installation

You can get the latest released version from PyPI with

pip install torchode

To install a development version, clone the repository and install in editable mode:

git clone https://github.com/martenlienen/torchode
cd torchode
pip install -e .

Usage

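import matplotlib.pyplot as pp
import torch
import torchode as to

# Right-hand side of the ODE dy/dt = -0.5 * y
def f(t, y):
    return -0.5 * y

# A batch of two samples, each with a one-dimensional state
y0 = torch.tensor([[1.2], [5.0]])
n_steps = 10
# Each sample has its own evaluation grid: [0, 5] and [3, 4]
t_eval = torch.stack((torch.linspace(0, 5, n_steps), torch.linspace(3, 4, n_steps)))

term = to.ODETerm(f)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
solver = to.AutoDiffAdjoint(step_method, step_size_controller)
# Compiling is optional; calling solver.solve(...) directly works as well
jit_solver = torch.compile(solver)

sol = jit_solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
# Statistics are reported per sample; the second sample needed fewer steps
print(sol.stats)
# => {'n_f_evals': tensor([26, 26]), 'n_steps': tensor([4, 2]),
# =>  'n_accepted': tensor([4, 2]), 'n_initialized': tensor([10, 10])}

# Plot the trajectory of each sample at its evaluation points
pp.plot(sol.ts[0], sol.ys[0])
pp.plot(sol.ts[1], sol.ys[1])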
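Since f is linear, the example above has the closed-form solution y(t) = y0 * exp(-0.5 * (t - t_start)), which makes a handy reference when checking the plotted curves. The plain-Python sketch below computes it on the same two grids. It assumes that each sample starts at the first time point of its own t_eval row, which appears to be the default when InitialValueProblem is given no explicit start time; the helpers exact and linspace are hypothetical names introduced here for illustration.

```python
import math


def exact(y0, t0, t):
    """Closed-form solution of dy/dt = -0.5 * y with y(t0) = y0."""
    return y0 * math.exp(-0.5 * (t - t0))


def linspace(start, end, n):
    """Plain-Python stand-in for torch.linspace."""
    return [start + (end - start) * i / (n - 1) for i in range(n)]


n_steps = 10
t_eval = [linspace(0.0, 5.0, n_steps), linspace(3.0, 4.0, n_steps)]
y0 = [1.2, 5.0]

# Assumption: sample i is integrated starting from t_eval[i][0].
reference = [[exact(y, ts[0], t) for t in ts] for y, ts in zip(y0, t_eval)]
print(round(reference[0][-1], 4), round(reference[1][-1], 4))  # 0.0985 3.0327
```

These values can be compared by eye with the end points of the plotted curves, or programmatically against sol.ys.squeeze(-1).tolist().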

Citation

If you build upon this work, please cite the following paper.

@inproceedings{lienen2022torchode,
  title = {torchode: A Parallel {ODE} Solver for PyTorch},
  author = {Marten Lienen and Stephan G{\"u}nnemann},
  booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
  year = {2022},
  url = {https://openreview.net/forum?id=uiKVKTiUYB0}
}

Download files

Source Distribution

torchode-1.0.0.tar.gz (37.0 kB)

Built Distribution

torchode-1.0.0-py3-none-any.whl (30.3 kB, Python 3)

File details

Details for the file torchode-1.0.0.tar.gz.

File metadata

  • Download URL: torchode-1.0.0.tar.gz
  • Size: 37.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for torchode-1.0.0.tar.gz
Algorithm Hash digest
SHA256 fd308ef0424ec4d82669135ec60658c67f6b8a9feea8992464c04ef1e936040e
MD5 3602152a2c1f7aa2257efa68dae55d48
BLAKE2b-256 b5775643ce903cb0e316155ac2b685c58108789090aa258d6dddca4e94f568fa

See more details on using hashes here.

File details

Details for the file torchode-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: torchode-1.0.0-py3-none-any.whl
  • Size: 30.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for torchode-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 df0883aa0e98ba5118a3d768ec2d543c27ff6fc205b866caa2a98a8f2545122f
MD5 97fa0d6d24abc83e07791445cde2ccce
BLAKE2b-256 14d50cdc30970404572991cd1ff3d485d86b1a6aa293c4ad39f5d07da135a4f8