Quasigeostrophic model in JAX (port of PyQG)
PyQG JAX Port
This is a partial port of PyQG to JAX which enables GPU acceleration, batching, automatic differentiation, etc.
- Documentation: https://pyqg-jax.readthedocs.io/en/latest/
- Source Code: https://github.com/karlotness/pyqg-jax
- Bug Reports: https://github.com/karlotness/pyqg-jax/issues
⚠️ Warning: this is a partial, early stage port. There may be bugs and other numerical issues. The API may evolve as work continues.
Installation
Install from PyPI using pip:
$ python -m pip install pyqg-jax
or from conda-forge:
$ conda install -c conda-forge pyqg-jax
This should install required dependencies, but JAX itself may require special attention, particularly for GPU support. Follow the JAX installation instructions.
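After installing, you can check which backend JAX is using, for example whether a GPU is visible. The snippet below is a quick sanity check. It uses only standard JAX functions and does not depend on pyqg-jax itself:

```python
import jax

# Name of the platform JAX uses by default, e.g. "cpu", "gpu", or "tpu"
backend = jax.default_backend()

# All devices visible to JAX on the default platform
devices = jax.devices()

print(f"JAX {jax.__version__} default backend: {backend}")
for device in devices:
    print("  ", device)
```

If this reports only CPU devices on a machine with a GPU, JAX's GPU-enabled build is probably not installed.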
Usage
Documentation is a work in progress. The parameters of the QGModel implemented here are the same as for the model in the original PyQG, so consult the pyqg documentation for details.
However, a few overarching changes make the models JAX-compatible:
- The model state is now a separate, immutable object rather than a set of attributes on the QGModel class.
- Time-stepping is now separated from the models. Use steppers.AB3Stepper for the same time stepping as in the original QGModel.
- Random initialization requires an explicit key variable, as with all JAX random number generation.
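Initialization takes an explicit key, so you can build an ensemble of independent initial states by splitting one key and mapping over the pieces with jax.vmap. The sketch below is generic. make_ensemble is a hypothetical helper, not part of pyqg-jax, and toy_create stands in for a real initializer. Passing stepped_model.create_initial_state instead assumes that method works under jax.vmap. The batching support mentioned above suggests it does, but this README does not state it.

```python
import jax
import jax.numpy as jnp


def make_ensemble(create_fn, key, num_members):
    """Hypothetical helper: build `num_members` states from independent keys.

    `create_fn` maps a single PRNG key to a state (any pytree). The result is
    the same pytree with a new leading batch axis of length `num_members`.
    """
    # Split the single key into one independent key per ensemble member
    keys = jax.random.split(key, num_members)
    # Vectorize the initializer over the leading axis of `keys`
    return jax.vmap(create_fn)(keys)


# Toy initializer (hypothetical) standing in for a model's create_initial_state
def toy_create(key):
    return {"q": jax.random.normal(key, (2, 4, 4))}


ensemble = make_ensemble(toy_create, jax.random.key(0), 3)
print(ensemble["q"].shape)  # (3, 2, 4, 4)
```

Each member gets its own key, so the members differ. Re-running with the same starting key reproduces the same ensemble.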
The QGModel uses double precision (float64) values for part of its computation regardless of the precision setting, so make sure 64-bit support is enabled in JAX. See the documentation for details. One option is to set the following environment variable before importing JAX:
$ export JAX_ENABLE_X64=True
or, in a Jupyter notebook, use the %env magic (%env JAX_ENABLE_X64=True).
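You can also enable 64-bit mode from Python by updating JAX's configuration instead of setting the environment variable. Run this at startup, before creating any JAX arrays. The final lines confirm that default floating-point arrays are now 64-bit:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit (float64/complex128) support; do this before creating any arrays
jax.config.update("jax_enable_x64", True)

# Default floating-point arrays should now be double precision
x = jnp.zeros(3)
print(x.dtype)  # float64
```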
Short Example
A short example initializing a QGModel
, adding a parameterization,
and taking a single step (for more, see the
examples in
the documentation).
>>> import pyqg_jax
>>> import jax
>>> # Construct model, parameterization, and time-stepper
>>> stepped_model = pyqg_jax.steppers.SteppedModel(
... model=pyqg_jax.parameterizations.smagorinsky.apply_parameterization(
... pyqg_jax.qg_model.QGModel(),
... constant=0.08,
... ),
... stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
... )
>>> # Initialize the model state (wrapped in stepper and parameterization state)
>>> stepper_state = stepped_model.create_initial_state(
... jax.random.key(0)
... )
>>> # Compute next state
>>> next_stepper_state = stepped_model.step_model(stepper_state)
>>> # Unwrap the result from the stepper and parameterization
>>> next_param_state = next_stepper_state.state
>>> next_model_state = next_param_state.model_state
>>> final_q = next_model_state.q
For repeated time-stepping, combine step_model with jax.lax.scan.
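Here is a minimal sketch of that pattern. roll_out is a hypothetical helper, not part of pyqg-jax. It repeatedly applies any step function, such as stepped_model.step_model from the example above, and uses jax.lax.scan to stack every intermediate state along a new leading axis. A toy step function is used so the snippet runs on its own:

```python
import jax
import jax.numpy as jnp


def roll_out(step_fn, initial_state, num_steps):
    """Hypothetical helper: apply `step_fn` `num_steps` times with jax.lax.scan.

    Returns the final state and a trajectory in which every leaf of the state
    pytree gains a leading time axis of length `num_steps`.
    """

    def loop_fn(carry, _x):
        next_state = step_fn(carry)
        # New carry for the next iteration, and the value recorded for this step
        return next_state, next_state

    # xs=None with an explicit length runs the loop body num_steps times
    return jax.lax.scan(loop_fn, initial_state, None, length=num_steps)


# Toy step function (hypothetical): add 1 to every entry of the state
final_state, trajectory = roll_out(lambda s: s + 1.0, jnp.zeros(2), 3)
print(final_state)       # [3. 3.]
print(trajectory[:, 0])  # [1. 2. 3.]
```

With the short example above, roll_out(stepped_model.step_model, stepper_state, 10) should return a stacked stepper state whose .state.model_state.q has a leading time axis of length 10, assuming the stepper state is an ordinary JAX pytree. Keeping every step can use a lot of memory for long runs. In that case loop_fn can record only the fields you need. Wrapping the rollout in jax.jit compiles the whole loop. num_steps must then stay a static Python integer, because it sets the loop length.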
License
This software is distributed under the MIT license. See LICENSE.txt for the license text.
File details
Details for the file pyqg_jax-0.8.1.tar.gz.
File metadata
- Download URL: pyqg_jax-0.8.1.tar.gz
- Upload date:
- Size: 35.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/4.0.2 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | cca2ff57429fbcdc259fde400d159c3f64a53569cd27d566f066bb42d1012839
MD5 | f66b96866df359d4ee0e1aec7a8ee253
BLAKE2b-256 | a95c94662bdeb36ff0decfd23e98d0f25627b95047218c020db8ce35c939187f
File details
Details for the file pyqg_jax-0.8.1-py3-none-any.whl.
File metadata
- Download URL: pyqg_jax-0.8.1-py3-none-any.whl
- Upload date:
- Size: 35.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/4.0.2 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | b88f2fee6e5a41fc2c62ac26746da4e013eb41b3741af902c03243a43f93e196
MD5 | b1304ce155d545115c41cd265da99652
BLAKE2b-256 | 687e06039b3a11ee4aab983015f8501849afa59dce74694cfae91767ed7557a1
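To check a downloaded file against the SHA256 digests listed above, you can recompute its hash with Python's standard library. The sketch reads the file in chunks so large files are not loaded into memory at once. sha256_of_file is a hypothetical helper name, and the path in the commented example is only illustrative:

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at `path`, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns an empty bytes object
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Example (path is illustrative):
# expected = "cca2ff57429fbcdc259fde400d159c3f64a53569cd27d566f066bb42d1012839"
# assert sha256_of_file("pyqg_jax-0.8.1.tar.gz") == expected
```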