
neural differential equations


neural-diffeqs


A PyTorch-based library for instantiating neural differential equations.

Installation

To install with pip:

pip install neural_diffeqs

To install the development version from GitHub:

git clone https://github.com/mvinyard/neural-diffeqs.git; cd ./neural-diffeqs
pip install -e .

Examples

You can instantiate an SDE or ODE as follows:

from neural_diffeqs import neural_diffeq

SDE = neural_diffeq()
# this can be passed to `torchsde.sdeint`

ODE = neural_diffeq(sigma_hidden=False)
# this can be passed to `torchdiffeq.odeint`
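Both integrators call the object they receive in a fixed way. `torchdiffeq.odeint(func, y0, t)` evaluates `func(t, y)`. `torchsde.sdeint(sde, y0, ts)` evaluates `sde.f(t, y)` (the drift, i.e. the mu network) and `sde.g(t, y)` (the diffusion, i.e. the sigma network), and it reads the `sde.noise_type` and `sde.sde_type` attributes. The sketch below is a minimal pure-Python illustration of that contract, not the neural_diffeqs implementation. It assumes a hand-written linear toy model: `ToyODE`, `ToySDE`, `euler`, and `euler_maruyama` are hypothetical names. Fixed functions stand in for the neural networks, and fixed-step Euler schemes stand in for the library solvers, so the example runs without PyTorch.

```python
import math
import random


class ToyODE:
    """Hypothetical stand-in for an ODE: dy/dt = -theta * y (no sigma term)."""

    def __init__(self, theta=1.0):
        self.theta = theta

    def __call__(self, t, y):
        # torchdiffeq-style signature: func(t, y) -> dy/dt
        return [-self.theta * yi for yi in y]


class ToySDE:
    """Hypothetical stand-in for an SDE with diagonal, Ito noise.

    f plays the role of mu (drift); g plays the role of sigma (diffusion).
    """

    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, theta=1.0, sigma=0.1):
        self.theta = theta
        self.sigma = sigma

    def f(self, t, y):
        return [-self.theta * yi for yi in y]

    def g(self, t, y):
        # Diagonal noise: one diffusion coefficient per state dimension
        return [self.sigma for _ in y]


def euler(func, y0, ts):
    """Fixed-step explicit Euler for an ODE with signature func(t, y)."""
    ys = [list(y0)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        dy = func(t0, ys[-1])
        ys.append([yi + dt * dyi for yi, dyi in zip(ys[-1], dy)])
    return ys


def euler_maruyama(sde, y0, ts, rng=None):
    """Fixed-step Euler-Maruyama for a diagonal-noise Ito SDE exposing f and g."""
    if rng is None:
        rng = random.Random(0)
    ys = [list(y0)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        y = ys[-1]
        drift = sde.f(t0, y)
        diffusion = sde.g(t0, y)
        # Independent Brownian increment per dimension, dW ~ N(0, dt)
        dw = [rng.gauss(0.0, math.sqrt(dt)) for _ in y]
        ys.append([yi + fi * dt + gi * dwi
                   for yi, fi, gi, dwi in zip(y, drift, diffusion, dw)])
    return ys


ts = [i * 0.01 for i in range(101)]  # t from 0 to 1
ode_path = euler(ToyODE(theta=1.0), [1.0, 2.0], ts)
sde_path = euler_maruyama(ToySDE(theta=1.0, sigma=0.1), [1.0, 2.0], ts)
```

With torchsde and torchdiffeq installed, the equivalent calls on the objects created above would be `torchsde.sdeint(SDE, y0, ts)` and `torchdiffeq.odeint(ODE, y0, t)` with tensor inputs. Those library solvers replace the hand-rolled loops shown here.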

You can also define the SDE or ODE in terms of potential functions. These can be passed to torchsde.sdeint and torchdiffeq.odeint in the same way as above:

from neural_diffeqs import neural_diffeq

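# mu is defined as a potential; sigma remains a standard network
SDE = neural_diffeq(mu_potential=True, sigma_potential=False)

# no sigma network; mu is defined as a potential
ODE = neural_diffeq(mu_potential=True, sigma_hidden=False)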
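This README does not spell out how a scalar potential becomes a drift. A common construction is gradient flow. The network outputs a scalar ψ(y), which would explain an output dimension of 1, and the drift is μ(y) = -∇ψ(y). This construction is assumed here and is not confirmed for neural_diffeqs. In the sketch, a fixed quadratic `psi` stands in for the network, and central finite differences stand in for autograd; in PyTorch one would typically use `torch.autograd.grad`. `psi` and `potential_drift` are hypothetical names.

```python
def psi(y):
    """Hypothetical scalar potential: psi(y) = 0.5 * ||y||^2 (a single output value)."""
    return 0.5 * sum(yi * yi for yi in y)


def potential_drift(potential, y, eps=1e-5):
    """Drift as the negative gradient of a scalar potential, -grad psi(y).

    Uses central finite differences in place of automatic differentiation.
    """
    grad = []
    for i in range(len(y)):
        y_plus = list(y)
        y_minus = list(y)
        y_plus[i] += eps
        y_minus[i] -= eps
        grad.append((potential(y_plus) - potential(y_minus)) / (2 * eps))
    # Negative gradient: the drift points "downhill" on the potential surface
    return [-g for g in grad]


drift = potential_drift(psi, [1.0, -2.0, 0.5])
```

For this quadratic potential, the drift is approximately -y, which pulls every state toward the minimum at the origin.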

Several other parameters, including the composition of the neural network(s), can be adjusted with the following arguments:

To adjust the parameters of the mu (drift) neural network:

  • mu_hidden - a dict (e.g., {1:[400,400], 2:[400,400]})
  • mu_in_dim
  • mu_out_dim
  • mu_potential - if True, the dimension of the output layer is set to 1.
  • mu_init_potential - when mu_potential=True, this argument sets the initial output value of mu. By default, this returns a torch.zeros([]) tensor.
  • mu_activation_function
  • mu_dropout
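The README gives the format of mu_hidden only by example. One plausible reading, which is a hypothetical interpretation and not confirmed by the source, treats each key as a hidden layer's position and each value as `[in_features, out_features]`. Input and output layers sized by mu_in_dim and mu_out_dim are then added around the hidden layers, and the output is collapsed to 1 when mu_potential is True. The hypothetical helper `layer_shapes` below makes that reading concrete and checks that adjacent hidden layers fit together.

```python
def layer_shapes(in_dim, hidden, out_dim, potential=False):
    """Hypothetical helper: list (in_features, out_features) for each linear layer.

    Assumes each dict value is [in_features, out_features] for one hidden
    layer, keyed by layer order, with input and output layers added around them.
    """
    keys = sorted(hidden)
    # A potential network has a single (scalar) output
    final_out = 1 if potential else out_dim
    if not keys:
        # No hidden layers: a single linear map from input to output
        return [(in_dim, final_out)]
    # Adjacent hidden layers must agree on their shared width
    for prev, curr in zip(keys, keys[1:]):
        if hidden[prev][1] != hidden[curr][0]:
            raise ValueError(
                f"layer {prev} outputs {hidden[prev][1]} "
                f"but layer {curr} expects {hidden[curr][0]}"
            )
    shapes = [(in_dim, hidden[keys[0]][0])]  # input layer
    for k in keys:
        shapes.append(tuple(hidden[k]))  # hidden layers
    shapes.append((hidden[keys[-1]][1], final_out))  # output layer
    return shapes
```

Under this reading, a PyTorch model would turn each `(in, out)` pair into a `torch.nn.Linear(in, out)`, with the chosen activation function and dropout placed between layers.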

Similarly, the sigma (diffusion) neural network can be controlled with these parameters:

  • sigma_hidden - a dict (e.g., {1:[400,400], 2:[400,400]})
  • sigma_in_dim
  • sigma_out_dim
  • sigma_potential - if True, the dimension of the output layer is set to 1.
  • sigma_init_potential - when sigma_potential=True, this argument sets the initial output value of sigma. By default, this returns a torch.zeros([]) tensor.
  • sigma_activation_function
  • sigma_dropout

There are also general parameters that torchsde requires of the SDE when it is used with the torchsde interface:

  • brownian_size
  • noise_type
  • sde_type
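These three settings constrain each other. According to the torchsde documentation, `sde_type` is either "ito" or "stratonovich", and the shape torchsde expects from the diffusion function g(t, y) depends on `noise_type`:

- "diagonal" expects (batch_size, state_size), with the Brownian size equal to the state size.
- "scalar" expects (batch_size, state_size, 1).
- "additive" and "general" expect (batch_size, state_size, brownian_size).

The hypothetical helpers below, which are not part of neural_diffeqs or torchsde, encode those rules as a configuration check one might run before building the sigma network.

```python
VALID_SDE_TYPES = ("ito", "stratonovich")


def expected_g_shape(noise_type, batch_size, state_size, brownian_size):
    """Return the diffusion (g) output shape torchsde expects for a noise type."""
    if noise_type == "diagonal":
        # Diagonal noise pairs each state dimension with its own Brownian motion
        if brownian_size != state_size:
            raise ValueError("diagonal noise requires brownian_size == state_size")
        return (batch_size, state_size)
    if noise_type == "scalar":
        # A single Brownian motion shared across all state dimensions
        if brownian_size != 1:
            raise ValueError("scalar noise requires brownian_size == 1")
        return (batch_size, state_size, 1)
    if noise_type in ("additive", "general"):
        return (batch_size, state_size, brownian_size)
    raise ValueError(f"unknown noise_type: {noise_type!r}")


def check_sde_config(noise_type, sde_type, batch_size, state_size, brownian_size):
    """Validate sde_type, then return the expected g output shape."""
    if sde_type not in VALID_SDE_TYPES:
        raise ValueError(f"unknown sde_type: {sde_type!r}")
    return expected_g_shape(noise_type, batch_size, state_size, brownian_size)
```

For example, a 2-dimensional state with diagonal noise and a batch of 32 needs a sigma network whose output reshapes to (32, 2).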

For more examples, please see the notebooks in ./examples/. For documentation related to neural ODEs and torchdiffeq, see the torchdiffeq repository. For documentation related to neural SDEs and torchsde, see the torchsde repository.

To-do and/or potential directions:

  • Support for neural controlled differential equations (neural CDEs)
  • SDE-GANs
  • Neural PDEs

Questions or suggestions? Open an issue or send an email to Michael Vinyard.



Download files


Source Distribution

neural-diffeqs-0.2.1rc0.tar.gz (19.3 kB)

Uploaded Source

Built Distribution

neural_diffeqs-0.2.1rc0-py3-none-any.whl (20.3 kB)

Uploaded Python 3

File details

Details for the file neural-diffeqs-0.2.1rc0.tar.gz.

File metadata

  • Download URL: neural-diffeqs-0.2.1rc0.tar.gz
  • Upload date:
  • Size: 19.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for neural-diffeqs-0.2.1rc0.tar.gz
Algorithm Hash digest
SHA256 8a3a077e5df34898064de4f95c81d23eb772138094a5e9b96a7d603c30cc42ba
MD5 96a2089e3dfbeca95f0f3d810b791b4b
BLAKE2b-256 4a01c95f915e6aa32170522d41ee676edba0d25aa7991bad3358861f4dcb2bb5


File details

Details for the file neural_diffeqs-0.2.1rc0-py3-none-any.whl.


File hashes

Hashes for neural_diffeqs-0.2.1rc0-py3-none-any.whl
Algorithm Hash digest
SHA256 364ae0e8432eed15efdabc674d047336aa0d98a4af52bac089db959c830953eb
MD5 05fc91888ece32bc1efeda2c9d9126f9
BLAKE2b-256 008e7cda983ef25d01fd08e5227d38eb9dc5ebea399891111fdcee6abdf6ba6a

