diffeqtorch
Bridges `DifferentialEquations.jl` with PyTorch. Besides benefiting from the huge range of solvers available in `DifferentialEquations.jl`, this allows taking gradients through solvers using local sensitivity analysis/autodiff. The package has only been tested with ODE problems; in particular, automatic differentiation is currently supported only for ODEs, via `ForwardDiff.jl`. This can be extended in the future, and contributions are welcome.
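The idea behind local sensitivity analysis can be illustrated without Julia or PyTorch. The plain-Python sketch below uses forward sensitivity equations for the scalar ODE du/dt = p·u (the same problem as in the Usage section): differentiating with respect to p gives ds/dt = p·s + u for s = ∂u/∂p, and both are integrated together, then compared with the analytical gradient. It is only an illustration under stated assumptions: the fixed-step RK4 integrator and the names `rk4_step` and `solve_with_sensitivity` are hypothetical and are not how `diffeqtorch` or `DifferentialEquations.jl` implement this.

```python
import math


def rk4_step(f, t, y, h):
    """One classical fourth-order Runge-Kutta step for y' = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)])
    k4 = f(t + h, [yi + h * k for yi, k in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def solve_with_sensitivity(u0, p, t_end, n_steps=1000):
    """Integrate du/dt = p*u jointly with its sensitivity s = du/dp.

    Differentiating the ODE with respect to p gives ds/dt = p*s + u,
    with s(0) = 0 because u0 does not depend on p.
    """
    def augmented(t, y):
        u, s = y
        return [p * u, p * s + u]

    h = t_end / n_steps
    t, y = 0.0, [u0, 0.0]
    for _ in range(n_steps):
        y = rk4_step(augmented, t, y, h)
        t += h
    return y[0], y[1]  # u(t_end) and du/dp at t_end


# Same problem as in the Usage section: u0 = 1, p = 1.01, t in [0, 3]
u, dudp = solve_with_sensitivity(1.0, 1.01, 3.0)
# Closed form: u(t) = u0*exp(p*t), hence du/dp = t*u0*exp(p*t)
print(u, math.exp(1.01 * 3.0))
print(dudp, 3.0 * math.exp(1.01 * 3.0))
```

Augmenting the state with s keeps the gradient computation inside an ordinary ODE solve, which is the forward flavour of local sensitivity analysis.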
Examples
- Simple ODE problem to demonstrate the interface and confirm gradients against the analytical solution
- SIR model as a slightly more complicated example, with numerical gradient checking
- Hodgkin-Huxley model for a realistic example from neuroscience
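Numerical gradient checking, as done in the SIR example, compares derivatives from automatic differentiation against finite differences. A minimal plain-Python sketch of that check follows: a small dual-number class provides forward-mode derivatives (the same idea `ForwardDiff.jl` is built on, though this is not its implementation), propagated through a fixed-step RK4 solve of an SIR model and compared with central finite differences. The parameter values (beta = 0.3, gamma = 0.1, N = 1000, 10 days) and the names `Dual`, `sir_rhs` and `infected_at` are hypothetical, not taken from the notebooks.

```python
class Dual:
    """Minimal dual number val + der*eps with eps**2 = 0 (forward-mode AD)."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __neg__(self):
        return Dual(-self.val, -self.der)

    def __sub__(self, other):
        return self + (-Dual.lift(other))

    def __rsub__(self, other):
        return Dual.lift(other) + (-self)

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule for the derivative part
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = Dual.lift(other)
        # Quotient rule for the derivative part
        return Dual(self.val / other.val,
                    (self.der * other.val - self.val * other.der) / other.val ** 2)


def rk4_step(f, t, y, h):
    """One classical RK4 step; works for floats and Dual numbers alike."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)])
    k4 = f(t + h, [yi + h * k for yi, k in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def sir_rhs(beta, gamma, n):
    """Right-hand side of the SIR model for susceptible, infected, recovered."""
    def f(t, y):
        s, i, r = y
        infection = beta * s * i / n
        recovery = gamma * i
        return [-infection, infection - recovery, recovery]
    return f


def infected_at(beta, gamma=0.1, n=1000.0, y0=(999.0, 1.0, 0.0),
                t_end=10.0, n_steps=1000):
    """Number of infected individuals at t_end (a Dual if beta is a Dual)."""
    f = sir_rhs(beta, gamma, n)
    h = t_end / n_steps
    t, y = 0.0, list(y0)
    for _ in range(n_steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y[1]


beta0 = 0.3
dual_result = infected_at(Dual(beta0, 1.0))  # seed d(beta)/d(beta) = 1
eps = 1e-5
finite_diff = (infected_at(beta0 + eps) - infected_at(beta0 - eps)) / (2 * eps)
print("forward mode:", dual_result.der, "finite differences:", finite_diff)
```

The dual number differentiates the discretized solver exactly, so the two values should agree up to finite-difference truncation and rounding error; a large mismatch would point to a gradient bug.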
Installation
Using `diffeqtorch` requires Julia and Python to be installed. Note that the directory containing the `julia` binary needs to be in your `PATH`.
We recommend using a custom Julia system image containing the dependencies. If the environment variable `JULIA_SYSIMAGE_DIFFEQTORCH` is set, the installation script will automatically build the image at the path it specifies. Building may take a while, but it improves speed afterwards.
Install `diffeqtorch`:

```shell
$ export JULIA_SYSIMAGE_DIFFEQTORCH="$HOME/.julia_sysimage_diffeqtorch.so"
$ pip install diffeqtorch -v
```
Usage
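```python
import torch
from diffeqtorch import DiffEq

# Right-hand side of du/dt = p*u as an in-place Julia function
f = """
function f(du,u,p,t)
    du[1] = p[1] * u[1]
end
"""
de = DiffEq(f)

u0 = torch.tensor([1.])         # initial condition
tspan = torch.tensor([0., 3.])  # start and end time
p = torch.tensor([1.01])        # parameters
u, t = de(u0, tspan, p)         # solve; returns solution u and time points t
```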
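As a reference point, the ODE in this example has a closed-form solution, which can be used to sanity-check the solver output and any gradients taken through it:

```latex
\frac{du}{dt} = p\,u,\quad u(0) = u_0
\;\Longrightarrow\;
u(t) = u_0\, e^{p t},\qquad
\frac{\partial u(t)}{\partial p} = t\, u_0\, e^{p t},\qquad
\frac{\partial u(t)}{\partial u_0} = e^{p t}
```

With u0 = 1 and p = 1.01, this gives u(3) = e^{3.03} ≈ 20.70 and ∂u(3)/∂p = 3e^{3.03} ≈ 62.09.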
See also `help(DiffEq)` and the examples provided in `notebooks/`.
License
MIT