Structure-preserving neural networks

Project description

strupnet

This package implements symplectic neural networks for learning dynamics of Hamiltonian systems from data.

Examples

Install it using pip: pip install strupnet

Basic example

import torch
from strupnet import SympNet

dim=2 # degrees of freedom for the Hamiltonian system. x = (p, q) \in R^{2*dim}

# Define a symplectic neural network with random parameters:
symp_net = SympNet(dim=dim, layers=12, width=8)

x0 = torch.randn(2 * dim) # phase space coordinate x = (p, q) 
h = torch.tensor([0.1]) # time-step 

x1 = symp_net(x0, h) # a random but symplectic transformation mapping x0 to x1
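To make "symplectic" concrete: a map is symplectic when its Jacobian $M$ satisfies $M^T J M = J$, with $J$ the canonical matrix for $x = (p, q)$. The following is a minimal, dependency-free sketch of a numerical check of that condition. It does not call strupnet: `shear` and `scaling` are hypothetical stand-in maps chosen for illustration, and the Jacobian is approximated by central finite differences.

```python
import math

def canonical_J(d):
    """Canonical 2d x 2d matrix J = [[0, -I], [I, 0]] for x = (p, q)."""
    n = 2 * d
    J = [[0.0] * n for _ in range(n)]
    for i in range(d):
        J[i][d + i] = -1.0
        J[d + i][i] = 1.0
    return J

def jacobian_fd(f, x, eps=1e-6):
    """Central finite-difference Jacobian, M[i][j] = d f_i / d x_j."""
    n = len(x)
    M = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        fp, fm = f(xp), f(xm)
        for i in range(n):
            M[i][j] = (fp[i] - fm[i]) / (2 * eps)
    return M

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def symplectic_defect(f, x, d):
    """Largest entry of |M^T J M - J|; close to zero for a symplectic map."""
    J = canonical_J(d)
    M = jacobian_fd(f, x)
    R = matmul(matmul(transpose(M), J), M)
    return max(abs(R[i][j] - J[i][j]) for i in range(2 * d) for j in range(2 * d))

def shear(x, h=0.1):
    """Hypothetical stand-in: p <- p + h*sin(q), q unchanged (a symplectic shear)."""
    d = len(x) // 2
    p, q = x[:d], x[d:]
    return [pi + h * math.sin(qi) for pi, qi in zip(p, q)] + list(q)

def scaling(x):
    """Hypothetical counter-example: uniform scaling is not symplectic."""
    return [1.1 * xi for xi in x]

x0 = [0.3, -1.2, 0.7, 2.0]  # x = (p, q) with dim = 2
defect_shear = symplectic_defect(shear, x0, d=2)
defect_scaling = symplectic_defect(scaling, x0, d=2)
print("shear defect:  ", defect_shear)
print("scaling defect:", defect_scaling)
```

The shear's defect sits at the level of finite-difference noise, while the scaling map's defect is of order one. The same check could in principle be applied to a trained network by wrapping it in a function of `x` alone; that wrapper is not shown here.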

Training a SympNet

SympNet inherits from torch.nn.Module and can therefore be trained like any other PyTorch module. Below is a minimal working example of training a SympNet built from quadratic ridge polynomials, a choice that is well suited to quadratic Hamiltonians.

Generating data

We will generate data of the form $\{x(ih)\}_{i=0}^{n+1}=\{(p(ih), q(ih))\}_{i=0}^{n+1}$, where $x(t)$ is the solution to the Hamiltonian ODE $\dot{x}=J\nabla H$, with the simple harmonic oscillator Hamiltonian $H=\frac{1}{2}(p^2+q^2)$. The data is arranged into input-target pairs $x_0 = \{x(ih)\}_{i=0}^{n}$ and $x_1 = \{x(ih)\}_{i=1}^{n+1}$, and the times $t_0$, $t_1$ are arranged the same way.
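For this Hamiltonian the exact one-step map is a rotation of $(p, q)$ through the angle $h$, which is the map the network is asked to learn. The sketch below, using only the standard library, builds the pairs $(x_0, x_1)$ from the solution $p=\cos t$, $q=\sin t$ and checks that a rotation by $h$ carries each $x_0$ sample to its $x_1$ partner. The names `exact_flow`, `traj`, and `n` are illustrative and not part of strupnet.

```python
import math

def exact_flow(x, h):
    """Exact flow of H = (p^2 + q^2)/2 over time h: a rotation of (p, q)."""
    p, q = x
    c, s = math.cos(h), math.sin(h)
    return (c * p - s * q, s * p + c * q)

h = 0.05  # time-step
n = 20    # the pairs are indexed i = 0..n

# Samples x(ih) = (cos(ih), sin(ih)) for i = 0..n+1
traj = [(math.cos(i * h), math.sin(i * h)) for i in range(n + 2)]
x0 = traj[:-1]  # x(ih),     i = 0..n
x1 = traj[1:]   # x((i+1)h), i = 0..n

# Largest componentwise mismatch between exact_flow(x0) and x1
max_err = max(
    abs(a - b)
    for xa, xb in zip(x0, x1)
    for a, b in zip(exact_flow(xa, h), xb)
)

# The Hamiltonian is conserved along the trajectory
energy = [0.5 * (p * p + q * q) for p, q in traj]

print("number of pairs:", len(x0))
print("max one-step mismatch:", max_err)
```

Since the target map is linear and symplectic, a model restricted to symplectic maps can in principle represent it exactly, which is consistent with the very small losses reported further down.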

import torch 

# Generate training and testing data using simple harmonic oscillator solution
def simple_harmonic_oscillator_solution(t_start, t_end, timestep):
    # round() rather than int() so floating-point error cannot drop a grid point
    time_grid = torch.linspace(t_start, t_end, round((t_end - t_start) / timestep) + 1)
    p_sol = torch.cos(time_grid)
    q_sol = torch.sin(time_grid)
    pq_sol = torch.stack([p_sol, q_sol], dim=-1)
    return pq_sol, time_grid.unsqueeze(dim=1)

timestep=0.05

x_train, t_train = simple_harmonic_oscillator_solution(t_start=0, t_end=1, timestep=timestep)
x_test, t_test = simple_harmonic_oscillator_solution(t_start=1, t_end=4, timestep=timestep)

x0_train, x1_train, t0_train, t1_train = x_train[:-1, :], x_train[1:, :], t_train[:-1, :], t_train[1:, :]
x0_test, x1_test, t0_test, t1_test = x_test[:-1, :], x_test[1:, :], t_test[:-1, :], t_test[1:, :]

Training

A SympNet can be trained like any PyTorch module by minimising the loss defined as follows. Let $\Phi_h^{\theta}(x)$ denote the SympNet, where $\theta$ is its set of trainable parameters. We want to find the $\theta$ that minimises

$\qquad \text{loss}=\sum_{i=0}^{n}\left\|\Phi_h^{\theta}(x(ih))-x\left((i+1)h\right)\right\|^2$

from strupnet import SympNet

# Initialize Symplectic Neural Network
symp_net = SympNet(dim=1, layers=2, max_degree=2, method="P")

# Train it like any other PyTorch model
optimizer = torch.optim.Adam(symp_net.parameters(), lr=0.01)
mse = torch.nn.MSELoss()
for epoch in range(1000):
    optimizer.zero_grad()    
    x1_pred = symp_net(x=x0_train, dt=t1_train - t0_train)
    loss = mse(x1_train, x1_pred)
    loss.backward()
    optimizer.step()
print("final loss value: ", loss.item())

x1_test_pred = symp_net(x=x0_test, dt=t1_test - t0_test)
print("test set error", torch.norm(x1_test_pred - x1_test).item())

Outputs:

final loss value:  2.1763008371575767e-33
test set error 5.992433957888383e-16
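A note on the loss: the formula above sums squared norms over the samples, whereas torch.nn.MSELoss with its default reduction averages over every element of the tensor. The two differ only by a constant factor, so they share the same minimiser. The sketch below illustrates this without PyTorch, using a hypothetical one-parameter model `rotate` (a rotation by a trial angle) as a stand-in for the network. `summed_loss` and `mean_loss` are illustrative names.

```python
import math

def rotate(x, angle):
    """Hypothetical one-parameter model: rotate (p, q) by `angle`."""
    p, q = x
    c, s = math.cos(angle), math.sin(angle)
    return (c * p - s * q, s * p + c * q)

h = 0.05
n = 20
x0 = [(math.cos(i * h), math.sin(i * h)) for i in range(n + 1)]
x1 = [(math.cos((i + 1) * h), math.sin((i + 1) * h)) for i in range(n + 1)]

def summed_loss(angle):
    """Sum over samples of the squared Euclidean norm, as in the formula."""
    return sum(
        (a - b) ** 2
        for xa, xb in zip(x0, x1)
        for a, b in zip(rotate(xa, angle), xb)
    )

def mean_loss(angle):
    """Mean over all elements (2 per sample), mimicking a mean-reduced MSE."""
    return summed_loss(angle) / (2 * len(x0))

angles = [0.03, 0.04, 0.05, 0.06]
best_sum = min(angles, key=summed_loss)
best_mean = min(angles, key=mean_loss)
print("best angle (summed loss):", best_sum)
print("best angle (mean loss):  ", best_mean)
```

Both losses select the trial angle equal to the time-step, since rescaling a loss by a positive constant does not move its minimum. The rescaling does change the gradient magnitude, so it interacts with the choice of learning rate.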
