
Structure-preserving neural networks


STRUPNET: structure-preserving neural networks

This package implements structure-preserving neural networks for learning the dynamics of differential systems from data.

Installing

Install it using pip: pip install strupnet

Symplectic neural networks (SympNets)

Basic example

import torch
from strupnet import SympNet

dim=2 # degrees of freedom for the Hamiltonian system. x = (p, q) \in R^{2*dim}

# Define a symplectic neural network with random parameters:
symp_net = SympNet(dim=dim, layers=12, width=8)

x0 = torch.randn(2 * dim) # phase space coordinate x = (p, q) 
h = torch.tensor([0.1]) # time-step 

x1 = symp_net(x0, h) # defines a random but symplectic transformation from x0 to x1
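For reference, "symplectic" is the standard condition on the Jacobian of the map $\Phi: x_0 \mapsto x_1$. With $x = (p, q)$, $p, q \in \mathbb{R}^{d}$ ($d$ = dim) and the canonical matrix $J$ below (written for this $(p, q)$ ordering), $\Phi$ is symplectic when

```latex
J = \begin{pmatrix} 0 & -I_d \\ I_d & 0 \end{pmatrix}, \qquad
\left(\frac{\partial \Phi}{\partial x}\right)^{\top} J \left(\frac{\partial \Phi}{\partial x}\right) = J .
```

For $d = 1$ this reduces to $\det(\partial\Phi/\partial x) = 1$: the map preserves area in the $(p, q)$ plane.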

Training a SympNet

SympNet inherits from torch.nn.Module, so it can be trained like any other PyTorch module. Below is a minimal working example of training a SympNet that uses quadratic ridge polynomials, which are best suited to quadratic Hamiltonians.
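A ridge polynomial is a function of a single linear combination of the coordinates, $V(x) = c\,(w^\top x)^k$. One way to see why such functions give exactly symplectic layers (a sketch under our own assumptions, not necessarily how strupnet's method="P" is implemented): along the Hamiltonian flow of $V$, $\frac{d}{dt}(w^\top x) \propto w^\top J w = 0$, so $w^\top x$ and hence $\nabla V$ stay constant, and the exact time-$h$ flow is the single explicit step $x \mapsto x + h\,J\nabla V(x)$. The plain-Python code below (the names ridge_layer, compose and jacobian_det, and the layer parameters, are hypothetical) composes a few such layers in one degree of freedom and checks numerically that the Jacobian determinant is 1.

```python
def ridge_layer(x, h, w, c, k=2):
    """Exact time-h flow of the ridge Hamiltonian V(x) = c * (w . x)**k, x = (p, q).

    Assumes J = [[0, -1], [1, 0]], so J @ w = (-w[1], w[0]). Because w . (J w) = 0,
    the ridge coordinate w . x is unchanged and the flow is one explicit step.
    """
    s = w[0] * x[0] + w[1] * x[1]      # ridge coordinate w . x
    g = c * k * s ** (k - 1)           # dV/ds, so grad V = g * w
    return (x[0] - h * g * w[1], x[1] + h * g * w[0])

def compose(x, h, layers):
    # Apply ridge layers in sequence; a composition of symplectic maps is symplectic
    for w, c in layers:
        x = ridge_layer(x, h, w, c)
    return x

def jacobian_det(f, x, eps=1e-6):
    # 2x2 Jacobian of f at x by central finite differences, then its determinant
    jac = [[0.0, 0.0], [0.0, 0.0]]
    for j in range(2):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        fp, fm = f(xp), f(xm)
        for i in range(2):
            jac[i][j] = (fp[i] - fm[i]) / (2 * eps)
    return jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]

# Illustrative, made-up layer parameters: pairs (w, c) of quadratic ridge layers
layers = [((1.0, 0.0), 0.5), ((0.6, 0.8), -1.2), ((0.0, 1.0), 0.5)]
h = 0.1
x0 = (0.3, -0.7)
print(jacobian_det(lambda x: compose(x, h, layers), x0))  # approximately 1.0
```

With k = 2 each layer is linear in x, which is consistent with such layers fitting the linear flow of a quadratic Hamiltonian well.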

Generating data

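We will generate data of the form $\{x(ih)\}_{i=0}^{n+1}=\{(p(ih), q(ih))\}_{i=0}^{n+1}$, where $x(t)$ is the solution of the Hamiltonian ODE $\dot{x} = J\nabla H$ with the simple harmonic oscillator Hamiltonian $H = \frac{1}{2}(p^2 + q^2)$. Here $J$ is the canonical symplectic matrix; for the ordering $x = (p, q)$ the ODE reads $\dot{p} = -q$, $\dot{q} = p$, whose solution with $x(0) = (1, 0)$ is $p(t) = \cos t$, $q(t) = \sin t$. The data is arranged as $x_0 = \{x(ih)\}_{i=0}^{n}$ and $x_1 = \{x((i+1)h)\}_{i=0}^{n}$, and likewise for $t$.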

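import torch

# Generate training and testing data from the exact simple harmonic oscillator
# solution p(t) = cos(t), q(t) = sin(t)
def simple_harmonic_oscillator_solution(t_start, t_end, timestep):
    # round() rather than int() so floating-point error in the division
    # cannot truncate the number of grid points
    n_points = round((t_end - t_start) / timestep) + 1
    time_grid = torch.linspace(t_start, t_end, n_points)
    p_sol = torch.cos(time_grid)
    q_sol = torch.sin(time_grid)
    pq_sol = torch.stack([p_sol, q_sol], dim=-1)  # shape (n_points, 2)
    return pq_sol, time_grid.unsqueeze(dim=1)      # times, shape (n_points, 1)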

timestep=0.05

x_train, t_train = simple_harmonic_oscillator_solution(t_start=0, t_end=1, timestep=timestep)
x_test, t_test = simple_harmonic_oscillator_solution(t_start=1, t_end=4, timestep=timestep)

x0_train, x1_train, t0_train, t1_train = x_train[:-1, :], x_train[1:, :], t_train[:-1, :], t_train[1:, :]
x0_test, x1_test, t0_test, t1_test = x_test[:-1, :], x_test[1:, :], t_test[:-1, :], t_test[1:, :]
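Because $H$ is quadratic, the exact time-$h$ flow of this system is linear: with $\dot{p} = -q$ and $\dot{q} = p$ it is a rotation of the $(p, q)$ plane by angle $h$, which is a symplectic linear map. The plain-Python sketch below (helper names sho_flow, make_pairs and hamiltonian are our own, not part of strupnet) rebuilds the same training grid and $(x_0, x_1)$ pairs, and checks that each $x_1$ is the rotation of its $x_0$ and that $H$ is constant along the data.

```python
import math

def sho_flow(p, q, h):
    # Exact time-h flow of H = (p**2 + q**2) / 2 with dp/dt = -q, dq/dt = p
    c, s = math.cos(h), math.sin(h)
    return c * p - s * q, s * p + c * q

def make_pairs(t_start, t_end, timestep):
    # Uniform grid from t_start to t_end, then the one-step shift into (x0, x1) pairs
    n = round((t_end - t_start) / timestep)
    ts = [t_start + i * (t_end - t_start) / n for i in range(n + 1)]
    xs = [(math.cos(t), math.sin(t)) for t in ts]    # x(t) = (p, q) = (cos t, sin t)
    return xs[:-1], xs[1:], (t_end - t_start) / n    # x0 list, x1 list, step size

def hamiltonian(p, q):
    return 0.5 * (p ** 2 + q ** 2)

x0s, x1s, h = make_pairs(0.0, 1.0, 0.05)
max_err = max(math.dist(sho_flow(*a, h), b) for a, b in zip(x0s, x1s))
print(len(x0s), max_err)  # 20 pairs; error at floating-point rounding level
```

This is why a SympNet built from quadratic pieces can, in principle, reproduce the training pairs almost exactly.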

Training

A SympNet can be trained like any PyTorch module on the following loss. Let $\Phi_h^{\theta}(x)$ denote the SympNet, where $\theta$ is its set of trainable parameters. We seek the $\theta$ that minimises

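$\qquad \mathrm{loss}(\theta)=\sum_{i=0}^{n}\left\|\Phi_h^{\theta}\big(x(ih)\big)-x\big((i+1)h\big)\right\|^2$

The code below uses torch.nn.MSELoss, which averages the squared errors instead of summing them; this rescales the loss by a constant and does not change the minimiser.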
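To make the loss concrete, the plain-Python sketch below evaluates it on the training pairs for two candidate one-step maps: the exact rotation flow, and the explicit Euler step $x \mapsto x + hJ\nabla H(x)$, which is not symplectic (its Jacobian determinant is $1 + h^2$). It also shows the relation to torch.nn.MSELoss with its default reduction="mean", which divides the summed squared error by the number of elements, here 2 times the number of pairs. The function names are illustrative only.

```python
import math

def exact_step(x, h):
    # Exact flow of H = (p**2 + q**2) / 2: rotation by angle h
    p, q = x
    return (math.cos(h) * p - math.sin(h) * q, math.sin(h) * p + math.cos(h) * q)

def euler_step(x, h):
    # Explicit Euler for dp/dt = -q, dq/dt = p (not symplectic)
    p, q = x
    return (p - h * q, q + h * p)

def summed_loss(step, x0s, x1s, h):
    # loss = sum_i |step(x(ih)) - x((i+1)h)|^2
    return sum(
        sum((u - v) ** 2 for u, v in zip(step(a, h), b))
        for a, b in zip(x0s, x1s)
    )

h = 0.05
ts = [i * h for i in range(21)]                   # t = 0, 0.05, ..., 1.0
xs = [(math.cos(t), math.sin(t)) for t in ts]
x0s, x1s = xs[:-1], xs[1:]

for name, step in [("exact", exact_step), ("euler", euler_step)]:
    loss = summed_loss(step, x0s, x1s, h)
    mse = loss / (2 * len(x0s))                   # value MSELoss(reduction="mean") would report
    print(name, loss, mse)
```

The exact flow gives a loss at rounding level, while Euler leaves an O(h^2) error in every step, so the summed loss stays clearly above zero.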

import torch
from strupnet import SympNet

# Initialize a symplectic neural network with one degree of freedom, x = (p, q)
symp_net = SympNet(dim=1, layers=2, max_degree=2, method="P")

# Train it like any other PyTorch model
optimizer = torch.optim.Adam(symp_net.parameters(), lr=0.01)
mse = torch.nn.MSELoss()
for epoch in range(1000):
    optimizer.zero_grad()
    # Predict one step ahead from every training point x0
    x1_pred = symp_net(x=x0_train, dt=t1_train - t0_train)
    loss = mse(x1_pred, x1_train)
    loss.backward()
    optimizer.step()
print("Final loss value: ", loss.item())

# Evaluate on the test interval without tracking gradients
with torch.no_grad():
    x1_test_pred = symp_net(x=x0_test, dt=t1_test - t0_test)
print("test set error", torch.norm(x1_test_pred - x1_test).item())

Outputs:

Final loss value:  2.1763008371575767e-33
test set error 5.992433957888383e-16
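The test interval [1, 4] lies outside the training interval [0, 1]. One reason to prefer symplectic one-step maps for such extrapolation is their long-time behaviour: iterating a symplectic map keeps the energy of this linear oscillator bounded, whereas a non-symplectic map can drift. The plain-Python sketch below illustrates this with two classical integrators standing in for a learned map (neither is a SympNet): the symplectic leapfrog (Stormer-Verlet) step and explicit Euler, each applied 1000 times with h = 0.05. The helper names are our own.

```python
def leapfrog(p, q, h):
    # Kick-drift-kick for H = (p**2 + q**2) / 2; each sub-step is a shear,
    # so the composed map is symplectic
    p = p - 0.5 * h * q
    q = q + h * p
    p = p - 0.5 * h * q
    return p, q

def euler(p, q, h):
    # Explicit Euler; multiplies p**2 + q**2 by (1 + h**2) every step
    return p - h * q, q + h * p

def energy_ratio(step, n_steps, h, p=1.0, q=0.0):
    # H after n_steps divided by the initial H
    h0 = 0.5 * (p * p + q * q)
    for _ in range(n_steps):
        p, q = step(p, q, h)
    return 0.5 * (p * p + q * q) / h0

h, n = 0.05, 1000
print("leapfrog:", energy_ratio(leapfrog, n, h))  # stays close to 1
print("euler:   ", energy_ratio(euler, n, h))     # about (1 + h**2) ** n
```

For leapfrog on this oscillator, the quantity $p^2 + (1 - h^2/4)\,q^2$ is exactly conserved, so $H$ only oscillates within a band of relative width about $h^2/4$.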



Download files


Source Distribution

strupnet-0.0.2.tar.gz (13.8 kB)

Built Distribution

strupnet-0.0.2-py3-none-any.whl (16.7 kB)

File details

Details for the file strupnet-0.0.2.tar.gz.

File metadata

  • Download URL: strupnet-0.0.2.tar.gz
  • Size: 13.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for strupnet-0.0.2.tar.gz
Algorithm Hash digest
SHA256 a7517e2cb7c14b16068815db282f6d5d48192ec9d477d2b568e8d89442cfc057
MD5 5fd70b35da00113ac8b73f6e9a223324
BLAKE2b-256 731b0375178e24950587f93322c351234162ed8ddb73f713d59817780222f7cb


File details

Details for the file strupnet-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: strupnet-0.0.2-py3-none-any.whl
  • Size: 16.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for strupnet-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 d6d3b5e199d4deeb717750481e493ae92fa91b767ae4f6e25acd3a8fa6688249
MD5 5577523cbd2cc6e99bb1701a8530798b
BLAKE2b-256 cfad68424f28a14d1acd405b1063afdf24ac8e8c4bba6d4c72994eacaec99d76

