Separable operator models for extreme-scale learning of parametric PDEs
Separable Operator Networks (SepONet)
This is the official repository for separable operator networks (SepONet), originally introduced in the preprint [1].
Installation
This package depends on JAX. It is recommended to install JAX with GPU/TPU compatibility before installing this library; the CPU-only version of JAX is installed by default.
Please install with pip:
pip install separable-operator-networks
Alternatively, you may specify the [cuda12] extra to install jax[cuda12] automatically:
pip install separable-operator-networks[cuda12]
Description
Operator learning has become a powerful tool in machine learning for modeling complex physical systems governed by partial differential equations (PDEs). Although Deep Operator Networks (DeepONet) show promise, they require extensive data acquisition. Physics-informed DeepONets (PI-DeepONet) mitigate data scarcity but suffer from inefficient training processes. We introduce Separable Operator Networks (SepONet), a novel framework that significantly enhances the efficiency of physics-informed operator learning. SepONet uses independent trunk networks to learn basis functions separately for different coordinate axes, enabling faster and more memory-efficient training via forward-mode automatic differentiation. The SepONet architecture for a $d=2$ dimensional coordinate grid is depicted below. The architecture is inspired by the method of separation of variables and recent exploration of separable physics-informed neural networks [2] for single instance PDE solutions.
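To make the factorization concrete, the snippet below sketches how a SepONet-style prediction on a 2D tensor-product grid can be assembled from per-axis trunk outputs and a branch output. This is a conceptual sketch only; the function name seponet_prediction and the array shapes are assumptions for illustration, not the library's API.

import jax.numpy as jnp

# Conceptual sketch (not the library API): assemble a SepONet-style prediction
# on a 2D tensor-product grid. Shapes are illustrative assumptions; r is the
# number of learned basis functions (the rank of the separable expansion).
def seponet_prediction(branch_out, trunk_x, trunk_t):
    # branch_out: (r,)    coefficients produced by the branch network
    # trunk_x:    (Nx, r) basis functions evaluated at the x-axis points
    # trunk_t:    (Nt, r) basis functions evaluated at the t-axis points
    # returns:    (Nx, Nt) prediction on the full Nx-by-Nt grid
    return jnp.einsum("r,ir,jr->ij", branch_out, trunk_x, trunk_t)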
Our preprint provides a universal approximation theorem for SepONet, proving that it applies to arbitrary operator learning problems. For a variety of 1D time-dependent PDEs, SepONet achieves accuracy scaling similar to PI-DeepONet, with up to 112x faster training time and 82x lower GPU memory usage. For 2D time-dependent PDEs, SepONet is capable of accurate predictions at scales where PI-DeepONet fails. The full test scaling results, as a function of the number of collocation points and the number of input functions, are shown below. These results may be reproduced using our scripts.
Code Overview
A SepONet model can be instantiated as follows:
import jax
import separable_operator_networks as sepop
d = ... # replace with problem dimension
branch_dim = ... # replace with input shape for branch network (MLP by default)
key = jax.random.key(0)
model = sepop.models.SepONet(d, branch_dim, key=key)
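The memory and speed benefits mentioned above come from forward-mode automatic differentiation: derivatives along each separated axis can be pushed through a 1D trunk with jax.jvp instead of reverse-mode differentiation over the full grid. The snippet below is a minimal, self-contained illustration of jax.jvp on a toy per-axis function; trunk_1d is a placeholder for this sketch, not part of the library.

import jax
import jax.numpy as jnp

# Toy per-axis "trunk" with r = 4 outputs per coordinate (placeholder, not the library).
def trunk_1d(x):
    return jnp.stack([jnp.sin(x), jnp.cos(x), x, x**2], axis=-1)

x = jnp.linspace(0.0, 1.0, 128)
# Forward-mode AD: evaluate the trunk and its derivative with respect to x in one pass.
vals, dvals_dx = jax.jvp(trunk_1d, (x,), (jnp.ones_like(x),))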
Other model classes such as PINN, SPINN, and DeepONet are implemented in the sepop.models submodule. These models are implemented as subclasses of eqx.Module (see equinox), enabling eqx.filter_vmap and eqx.filter_grad, along with easily customizable training routines via optax (see sepop.train.train_loop(...) for a simple optax training loop). PDE instances, loss functions, and other helper functions can be imported from the corresponding examples in the sepop.pde submodule (such as sepop.pde.advection).
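For reference, a minimal optax training loop over an eqx.Module looks roughly like the sketch below. This is a generic pattern, not the implementation of sepop.train.train_loop; fit, loss_fn, and batch are hypothetical placeholders, and loss_fn is assumed to take (model, batch) and return a scalar loss.

import equinox as eqx
import optax

# Generic sketch of an optax training loop for an eqx.Module
# (placeholder names; not the implementation of sepop.train.train_loop).
def fit(model, loss_fn, batch, steps=1000, lr=1e-3):
    opt = optax.adam(lr)
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def step(model, opt_state, batch):
        # Differentiate the scalar loss with respect to the model's array leaves.
        loss, grads = eqx.filter_value_and_grad(loss_fn)(model, batch)
        updates, opt_state = opt.update(
            grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for _ in range(steps):
        model, opt_state, loss = step(model, opt_state, batch)
    return model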
Test data can be generated using the Python scripts in /scripts/generate_test_data. Test cases can be run using the scripts in /scripts/main_scripts and /scripts/scale_tests.
Citation
@misc{yu2024separableoperatornetworks,
title={Separable Operator Networks},
author={Xinling Yu and Sean Hooten and Ziyue Liu and Yequan Zhao and Marco Fiorentino and Thomas Van Vaerenbergh and Zheng Zhang},
year={2024},
eprint={2407.11253},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.11253},
}
Authors
Sean Hooten (sean dot hooten at hpe dot com)
Xinling Yu (xyu644 at ucsb dot edu)
License
MIT (see LICENSE.md)
References
[1] X. Yu, S. Hooten, Z. Liu, Y. Zhao, M. Fiorentino, T. Van Vaerenbergh, and Z. Zhang. Separable Operator Networks. arXiv preprint arXiv:2407.11253 (2024).
[2] J. Cho, S. Nam, H. Yang, S.-B. Yun, Y. Hong, and E. Park. Separable PINN: Mitigating the Curse of Dimensionality in Physics-Informed Neural Networks. arXiv preprint arXiv:2211.08761 (2023).