Fit time-series data with a Neural Differential Equation
NODEFit
This repository provides tools for fitting time-series data with both Neural Ordinary Differential Equations (Neural ODEs) and Neural Stochastic Differential Equations (Neural SDEs).
GPU support is provided through PyTorch.
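For reference, a Neural ODE and a Neural SDE are usually written as below, with $f_\theta$ the drift network and $g_\phi$ the diffusion network. How NODEFit passes time to the networks and which loss it minimizes are assumptions here: the example program gives both networks `ndim+1` inputs, which is consistent with feeding $(t, \mathbf{y})$, and the squared-error loss is a common choice that this page does not confirm.

```latex
\begin{aligned}
\text{Neural ODE:}\quad & \frac{d\mathbf{y}}{dt} = f_\theta(t, \mathbf{y}), \qquad \mathbf{y}(t_0) = \mathbf{y}_0 \\
\text{Neural SDE:}\quad & d\mathbf{Y}_t = f_\theta(t, \mathbf{Y}_t)\,dt + g_\phi(t, \mathbf{Y}_t) \odot d\mathbf{W}_t, \qquad \mathbf{Y}_{t_0} = \mathbf{y}_0 \\
\text{Fit (assumed loss):}\quad & \mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \lVert \mathbf{y}(t_i) - \mathbf{d}_i \rVert^2
\end{aligned}
```

Here $\mathbf{W}_t$ is a standard Wiener process, $\mathbf{d}_i$ is the observed data at sample time $t_i$, and the elementwise product $\odot$ denotes diagonal noise, which matches a diffusion network with `ndim` outputs (again an assumption about the package).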
How to install and execute?
Tested on Python 3.9. To install, run
pip install nodefit
The following program illustrates basic usage:
import numpy as np
import torch.nn as nn
from nodefit.constants import DEVICE
from nodefit.neural_ode import NeuralODE
from nodefit.neural_sde import NeuralSDE
###
# DEFINE NETWORKS
###
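# Network sizes: state dimension and hidden-layer widths
ndim, drift_nhidden, diffusion_nhidden = 2, 10, 2
# Drift network: ndim+1 inputs (presumably the state plus time) -> ndim outputs
drift_nn = nn.Sequential(
    nn.Linear(ndim+1, drift_nhidden),
    nn.Sigmoid(),
    nn.Linear(drift_nhidden, ndim)
).double().to(DEVICE)
# Diffusion network (used only by the Neural SDE), same input/output sizes
diffusion_nn = nn.Sequential(
    nn.Linear(ndim+1, diffusion_nhidden),
    nn.Sigmoid(),
    nn.Linear(diffusion_nhidden, ndim)
).double().to(DEVICE)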
###
# PROVIDE DATA
###
# Training data spans 0 to 5 seconds, sampled at 10 times
t = np.linspace(0, 5, 10)
# Provide data as a list of lists, starting with the initial condition
data = np.array([[...]])
###
# FIT USING NEURALODE
###
print('Performing fit using Neural ODE...')
neural_ode = NeuralODE(drift_nn, t, data)
neural_ode.train(2000)
# Extrapolate the fitted model beyond the training data, out to 10 seconds
extra_data = neural_ode.extrapolate(10)
neural_ode.plot(extra_data)
###
# FIT USING NEURALSDE
###
print('Performing fit using Neural SDE...')
neural_sde = NeuralSDE(drift_nn, diffusion_nn, t, data)
neural_sde.train(1)
# Extrapolate the fitted model beyond the training data, out to 10 seconds
extra_data = neural_sde.extrapolate(10)
neural_sde.plot(extra_data)
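To make concrete what `train` and `extrapolate` involve, here is a dependency-free sketch. It integrates a drift function with fixed-step RK4 (the ODE case), samples one Euler-Maruyama path with a diffusion term (the SDE case), and scores a trajectory against data with a squared-error loss. Everything in it is hypothetical: `drift` and `diffusion` are hand-written stand-ins for `drift_nn` and `diffusion_nn`, the solvers are not necessarily the ones NODEFit uses, and no parameters are trained. In the package, gradient-based fitting is presumably handled by PyTorch together with torchdiffeq and torchsde.

```python
import math
import random

# Illustrative sketch only; this is NOT NODEFit's implementation.
# Hypothetical stand-ins for drift_nn and diffusion_nn: each maps (t, y) to a list of
# length ndim = 2. The drift is a damped oscillator: y0' = y1, y1' = -y0 - 0.1*y1.
def drift(t, y):
    return [y[1], -y[0] - 0.1 * y[1]]

def diffusion(t, y):
    # Constant diagonal noise amplitude for each state component
    return [0.05, 0.05]

def linspace(start, stop, num):
    """Evenly spaced times, like np.linspace; the last entry equals stop exactly."""
    return [start + (stop - start) * i / (num - 1) for i in range(num)]

def rk4_step(f, t, y, h):
    """One classical fourth-order Runge-Kutta step for dy/dt = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)])
    k4 = f(t + h, [yi + h * k for yi, k in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]

def solve_ode(f, y0, ts, substeps=20):
    """Trajectory of dy/dt = f(t, y) at each time in ts, using fixed RK4 substeps."""
    y = list(y0)
    traj = [list(y)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = (t1 - t0) / substeps
        for k in range(substeps):
            y = rk4_step(f, t0 + k * h, y, h)
        traj.append(list(y))
    return traj

def solve_sde(f, g, y0, ts, substeps=20, seed=0):
    """One Euler-Maruyama sample path of dY = f dt + g * dW (diagonal noise)."""
    rng = random.Random(seed)
    y = list(y0)
    traj = [list(y)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = (t1 - t0) / substeps
        for k in range(substeps):
            t = t0 + k * h
            fy, gy = f(t, y), g(t, y)
            # Brownian increment dW ~ Normal(0, h) for each component
            y = [yi + fi * h + gi * rng.gauss(0.0, math.sqrt(h))
                 for yi, fi, gi in zip(y, fy, gy)]
        traj.append(list(y))
    return traj

def mse(pred, data):
    """Mean squared error over all times and state components."""
    n = sum(len(row) for row in data)
    return sum((p - d) ** 2 for pr, dr in zip(pred, data) for p, d in zip(pr, dr)) / n

# Usage: "train" on 0-5 s, then extrapolate to 10 s with the same step size.
t_train = linspace(0, 5, 10)
y0 = [1.0, 0.0]
data = solve_ode(drift, y0, t_train)  # synthetic data generated by the drift itself
print(mse(solve_ode(drift, y0, t_train), data))  # 0.0, since the model matches exactly
t_extra = linspace(0, 10, 19)
ode_path = solve_ode(drift, y0, t_extra)
sde_path = solve_sde(drift, diffusion, y0, t_extra, seed=0)
print(len(ode_path), len(sde_path))  # 19 19
```

Fixed-step RK4 and Euler-Maruyama keep the sketch short. torchdiffeq and torchsde offer more capable solvers (torchdiffeq includes adaptive-step methods) and support backpropagation through the solve, which is what lets the network weights be trained.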
Sample Output
Whom to contact?
Please direct any questions to gpavanb1.
Acknowledgements
This package would not be possible without the supporting packages torchdiffeq and torchsde.
Download files
Source Distribution
NODEFit-0.1.tar.gz (5.4 kB)
File details
Details for the file NODEFit-0.1.tar.gz.
File metadata
- Download URL: NODEFit-0.1.tar.gz
- Upload date:
- Size: 5.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | c36790b0d374fd156d178e93b72a6d76f5c8f98f56fa0820a43cd1c09f6446d9
MD5 | 81333145dadd1ab532e00a80f675e242
BLAKE2b-256 | 247d1cf1a2cfbcd48378ea6505c38f48a6d5adbab04e29ceb86fd85dc24d3ecc