Physics-Informed Neural Networks for Scientific Machine Learning in JAX.
PINNx: Physics-Informed Neural Networks for Scientific Machine Learning in JAX
PINNx is a library for scientific machine learning and physics-informed learning in JAX.
It is a rewrite of DeepXDE, enhanced by our brain modeling ecosystem.
For example, it leverages:
- brainstate for just-in-time compilation,
- brainunit for dimensional analysis,
- braintools for checkpointing, loss functions, and other utilities.
Quickstart
Define a PINN with explicit variables and physical units.
import braintools
import brainunit as u
import pinnx

# geometry: x in [-1, 1] (meters), t in [0, 0.99] (seconds)
geometry = pinnx.geometry.GeometryXTime(
    geometry=pinnx.geometry.Interval(-1, 1.),
    timedomain=pinnx.geometry.TimeDomain(0, 0.99)
).to_dict_point(x=u.meter, t=u.second)

uy = u.meter / u.second                         # unit of the solution y
v = 0.01 / u.math.pi * u.meter ** 2 / u.second  # viscosity

# boundary conditions: y = 0 on the spatial boundary, y(x, 0) = -sin(pi * x)
bc = pinnx.icbc.DirichletBC(lambda x: {'y': 0. * uy})
ic = pinnx.icbc.IC(lambda x: {'y': -u.math.sin(u.math.pi * x['x'] / u.meter) * uy})

# PDE equation (Burgers): dy/dt + y * dy/dx - v * d2y/dx2 = 0
# `approximator` is defined below; Python looks it up only when `pde` is called.
def pde(x, y):
    jacobian = approximator.jacobian(x)
    hessian = approximator.hessian(x)
    dy_x = jacobian['y']['x']
    dy_t = jacobian['y']['t']
    dy_xx = hessian['y']['x']['x']
    residual = dy_t + y['y'] * dy_x - v * dy_xx
    return residual

# neural network: unitful dict input -> array -> FNN -> unitful dict output
approximator = pinnx.nn.Model(
    pinnx.nn.DictToArray(x=u.meter, t=u.second),
    pinnx.nn.FNN(
        [geometry.dim] + [20] * 3 + [1],
        "tanh",
        braintools.init.KaimingUniform()
    ),
    pinnx.nn.ArrayToDict(y=uy)
)

# problem
problem = pinnx.problem.TimePDE(
    geometry,
    pde,
    [bc, ic],
    approximator,
    num_domain=2540,
    num_boundary=80,
    num_initial=160,
)

# training: Adam first, then L-BFGS refinement
trainer = pinnx.Trainer(problem)
trainer.compile(braintools.optim.Adam(1e-3)).train(iterations=15000)
trainer.compile(braintools.optim.LBFGS(1e-3)).train(2000, display_every=500)
trainer.saveplot(issave=True, isplot=True)
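The quickstart trains a network for the viscous Burgers equation, u_t + u u_x - ν u_xx = 0 with ν = 0.01/π, on x ∈ [-1, 1] and t ∈ [0, 0.99], with u(±1, t) = 0 and u(x, 0) = -sin(πx). To show what the PDE residual and the training loss amount to, here is a minimal plain-JAX sketch of the same problem. It is not pinnx's implementation: units are dropped, init_mlp, mlp, burgers_residual and pinn_loss are hypothetical helpers, collocation points are drawn uniformly at random (pinnx may use a different sampler), and the three loss terms are weighted equally.

```python
import jax
import jax.numpy as jnp

NU = 0.01 / jnp.pi  # viscosity v from the quickstart, with units dropped


def init_mlp(key, sizes):
    """Random (weight, bias) pairs for a tanh MLP, e.g. sizes=[2, 20, 20, 20, 1]."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, fan_in, fan_out in zip(keys, sizes[:-1], sizes[1:]):
        bound = jnp.sqrt(6.0 / fan_in)  # Kaiming-uniform-style bound (an assumption)
        w = jax.random.uniform(k, (fan_in, fan_out), minval=-bound, maxval=bound)
        params.append((w, jnp.zeros(fan_out)))
    return params


def mlp(params, x, t):
    """Scalar network output u(x, t) for scalar inputs x and t."""
    h = jnp.stack([x, t])
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]


def burgers_residual(u_fn, x, t):
    """u_t + u * u_x - NU * u_xx at one point, computed by automatic differentiation."""
    u = u_fn(x, t)
    u_x = jax.grad(u_fn, argnums=0)(x, t)
    u_t = jax.grad(u_fn, argnums=1)(x, t)
    u_xx = jax.grad(jax.grad(u_fn, argnums=0), argnums=0)(x, t)
    return u_t + u * u_x - NU * u_xx


def pinn_loss(params, key, num_domain=2540, num_boundary=80, num_initial=160):
    """Mean-squared PDE, boundary and initial-condition residuals on random points."""
    u_fn = lambda x, t: mlp(params, x, t)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    # Interior collocation points in [-1, 1] x [0, 0.99].
    xd = jax.random.uniform(k1, (num_domain,), minval=-1.0, maxval=1.0)
    td = jax.random.uniform(k2, (num_domain,), minval=0.0, maxval=0.99)
    r_pde = jax.vmap(lambda x, t: burgers_residual(u_fn, x, t))(xd, td)
    # Dirichlet BC u(+-1, t) = 0: alternate points between x = -1 and x = +1.
    xb = jnp.where(jnp.arange(num_boundary) % 2 == 0, -1.0, 1.0)
    tb = jax.random.uniform(k3, (num_boundary,), minval=0.0, maxval=0.99)
    r_bc = jax.vmap(u_fn)(xb, tb)
    # Initial condition u(x, 0) = -sin(pi * x).
    xi = jax.random.uniform(k4, (num_initial,), minval=-1.0, maxval=1.0)
    r_ic = jax.vmap(u_fn)(xi, jnp.zeros(num_initial)) + jnp.sin(jnp.pi * xi)
    return jnp.mean(r_pde ** 2) + jnp.mean(r_bc ** 2) + jnp.mean(r_ic ** 2)


params = init_mlp(jax.random.PRNGKey(0), [2, 20, 20, 20, 1])
loss_and_grad = jax.jit(jax.value_and_grad(pinn_loss))
loss, grads = loss_and_grad(params, jax.random.PRNGKey(1))
# One plain gradient-descent step; the quickstart uses Adam and then L-BFGS instead.
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
print(float(loss))
```

Nesting jax.grad keeps the second derivative readable at the cost of some recomputation; jax.vmap then batches the per-point residuals so the whole loss can be jit-compiled and differentiated with respect to the network parameters.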
Installation
- Install the stable version with pip:
pip install pinnx --upgrade
- Install pinnx on CPU, GPU, or TPU with JAX, following the JAX installation instructions:
pip install pinnx[cpu] # for CPU
pip install pinnx[cuda12] # for NVIDIA GPUs with CUDA 12
pip install pinnx[cuda13] # for NVIDIA GPUs with CUDA 13
pip install pinnx[tpu] # for Google TPUs
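After installing one of the extras, a quick way to confirm which backend JAX actually picked up is to query it directly. This sketch uses only standard JAX calls (jax.default_backend and jax.devices), not a pinnx API.

```python
import jax

backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()          # devices available on the default backend

print(f"JAX backend: {backend}")
for d in devices:
    print(d.platform, d)
```

If a CUDA extra was installed but the backend still reports "cpu", JAX most likely did not find a usable GPU runtime.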
Documentation
The official documentation is hosted on Read the Docs: https://pinnx.readthedocs.io/
See also the ecosystem
pinnx is one part of our brain modeling ecosystem: https://brainmodeling.readthedocs.io/
Download files
File details
Details for the file pinnx-0.0.3.tar.gz.
File metadata
- Download URL: pinnx-0.0.3.tar.gz
- Upload date:
- Size: 96.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 755dab2973ce5c0f61cefe220c1aca9d9cb710192002aad3f8614705ae26b97f |
| MD5 | 474b2165fa06aa0e7947a7f89a1a2faa |
| BLAKE2b-256 | 51e52188f155e9e0c8d0c19cc47a4e94b93c6f70b749d9d689c9be122e8f4a6a |
File details
Details for the file pinnx-0.0.3-py3-none-any.whl.
File metadata
- Download URL: pinnx-0.0.3-py3-none-any.whl
- Upload date:
- Size: 122.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a0e120f15a8f6847941b12ffcd063c8649e6538cb8396781cb8ef5685445337a |
| MD5 | 85ff30a01ac33c6057e4a1940b936b61 |
| BLAKE2b-256 | a321612b69bdb361a70e60e060387989e4480ee9f496c0314c797f399db9f9f1 |
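To check a downloaded file against the published digests, you can hash it locally with Python's standard hashlib module. This is a minimal sketch: compute_digest, verify and EXPECTED_SHA256 are illustrative names, not part of pinnx or pip. The BLAKE2b-256 entries correspond to BLAKE2b with a 32-byte digest. pip can also enforce hashes itself in hash-checking mode (--require-hashes, with --hash=sha256:... entries in a requirements file).

```python
import hashlib
from pathlib import Path

# SHA256 digests published for pinnx 0.0.3 (copied from the file details above).
EXPECTED_SHA256 = {
    "pinnx-0.0.3.tar.gz": "755dab2973ce5c0f61cefe220c1aca9d9cb710192002aad3f8614705ae26b97f",
    "pinnx-0.0.3-py3-none-any.whl": "a0e120f15a8f6847941b12ffcd063c8649e6538cb8396781cb8ef5685445337a",
}


def compute_digest(path, algorithm="sha256", chunk_size=1 << 16):
    """Hex digest of a file, read in chunks so large files are not loaded at once."""
    if algorithm.lower() == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)  # BLAKE2b truncated to 256 bits
    else:
        h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path):
    """True if the file's SHA256 matches the published digest for its file name."""
    name = Path(path).name
    if name not in EXPECTED_SHA256:
        raise KeyError(f"no published SHA256 for {name!r}")
    return compute_digest(path) == EXPECTED_SHA256[name]
```

For example, verify("pinnx-0.0.3.tar.gz") returns True only if the local file is byte-for-byte identical to the released source distribution.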