Separable operator models for extreme-scale learning of parametric PDEs

Separable Operator Networks (SepONet)

This is the official repository for Separable Operator Networks (SepONet), originally introduced in this preprint [1].

Installation

This package depends on JAX. It is recommended to install JAX with GPU/TPU support before installing this library; the CPU version of JAX is installed by default.

Please install with pip:

pip install separable-operator-networks

Alternatively, you may specify the [cuda12] extra to install jax[cuda12] automatically:

pip install separable-operator-networks[cuda12]

Description

Operator learning has become a powerful tool in machine learning for modeling complex physical systems governed by partial differential equations (PDEs). Although Deep Operator Networks (DeepONet) show promise, they require extensive data acquisition. Physics-informed DeepONets (PI-DeepONet) mitigate data scarcity but suffer from inefficient training processes. We introduce Separable Operator Networks (SepONet), a novel framework that significantly enhances the efficiency of physics-informed operator learning. SepONet uses independent trunk networks to learn basis functions separately for different coordinate axes, enabling faster and more memory-efficient training via forward-mode automatic differentiation. The SepONet architecture for a $d=2$ dimensional coordinate grid is depicted below. The architecture is inspired by the method of separation of variables and recent exploration of separable physics-informed neural networks [2] for single instance PDE solutions.
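The separation-of-variables idea behind the architecture can be illustrated with a small NumPy sketch (this is not the library's API; the basis functions below are hypothetical stand-ins for trunk-network outputs). Each coordinate axis gets its own set of $r$ basis functions, and a contraction over the rank index reconstructs values on the full grid:

```python
import numpy as np

r = 3                      # rank (number of basis functions per axis)
x = np.linspace(0, 1, 5)   # 5 collocation points on the x-axis
t = np.linspace(0, 1, 4)   # 4 collocation points on the t-axis

# Stand-ins for the outputs of two independent trunk networks
fx = np.stack([np.sin((k + 1) * np.pi * x) for k in range(r)], axis=-1)  # shape (5, r)
gt = np.stack([np.exp(-(k + 1) * t) for k in range(r)], axis=-1)         # shape (4, r)

# Outer-product contraction over the rank index yields the solution
# on the full 5x4 grid: u[i, j] = sum_k fx[i, k] * gt[j, k]
u = np.einsum('xr,tr->xt', fx, gt)  # shape (5, 4)
```

Only the 1D coordinates along each axis are ever fed through a network; the full grid appears only in the final contraction.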

Our preprint provides a universal approximation theorem for SepONet, proving that it generalizes to arbitrary operator learning problems. For a variety of 1D time-dependent PDEs, SepONet exhibits accuracy scaling similar to PI-DeepONet, but with up to 112x faster training and an 82x reduction in GPU memory usage. For 2D time-dependent PDEs, SepONet makes accurate predictions at scales where PI-DeepONet fails. The full test scaling results, as a function of the number of collocation points and the number of input functions, are shown below. These results may be reproduced using our scripts.
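Much of the efficiency gain follows from how collocation inputs scale: a conventional trunk network must evaluate every coordinate tuple on the $d$-dimensional grid, while SepONet's $d$ trunk networks each see only the 1D coordinates along their axis. A back-of-the-envelope comparison (illustrative numbers only, not figures from the paper):

```python
# Illustrative scaling comparison for a grid with n points per axis in d dimensions
n, d = 128, 2
full_grid_inputs = n ** d   # coordinate tuples a single trunk network must process
separable_inputs = d * n    # 1D coordinates processed by d independent trunk networks
print(full_grid_inputs, separable_inputs)  # 16384 vs. 256
```

The gap widens rapidly with dimension, which is why the savings are most dramatic for 2D time-dependent problems.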

Figures: the SepONet architecture for a 2-dimensional coordinate grid; SepONet vs. PI-DeepONet as the number of collocation points varies; SepONet vs. PI-DeepONet as the number of input functions varies.

Code Overview

A SepONet model can be imported using:

import jax
import separable_operator_networks as sepop
d = ... # replace with problem dimension
branch_dim = ... # replace with input shape for branch network (MLP by default)
key = jax.random.key(0)

model = sepop.models.SepONet(d, branch_dim, key=key)

Other model classes such as PINN, SPINN, and DeepONet are implemented in the sepop.models submodule. All models are subclasses of eqx.Module (see Equinox), so they work with eqx.filter_vmap and eqx.filter_grad and support easily customizable training routines via optax (see sepop.train.train_loop(...) for a simple optax training loop). PDE instances, loss functions, and other helper functions can be imported from the corresponding examples in the sepop.pde submodule (such as sepop.pde.advection).

Test data can be generated using the Python scripts in /scripts/generate_test_data. Test cases can be run using the scripts in /scripts/main_scripts and /scripts/scale_tests.

Citation

@misc{yu2024separableoperatornetworks,
title={Separable Operator Networks}, 
author={Xinling Yu and Sean Hooten and Ziyue Liu and Yequan Zhao and Marco Fiorentino and Thomas Van Vaerenbergh and Zheng Zhang},
year={2024},
eprint={2407.11253},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.11253}, 
}

Authors

Sean Hooten (sean dot hooten at hpe dot com)
Xinling Yu (xyu644 at ucsb dot edu)

License

MIT (see LICENSE.md)

References

[1] X. Yu, S. Hooten, Z. Liu, Y. Zhao, M. Fiorentino, T. Van Vaerenbergh, and Z. Zhang. Separable Operator Networks. arXiv preprint arXiv:2407.11253 (2024).
[2] J. Cho, S. Nam, H. Yang, S.-B. Yun, Y. Hong, and E. Park. Separable PINN: Mitigating the Curse of Dimensionality in Physics-Informed Neural Networks. arXiv preprint arXiv:2211.08761 (2023).
