Quasigeostrophic model in JAX (port of PyQG)

PyQG JAX Port

This is a partial port of PyQG to JAX, which enables GPU acceleration, batching, and automatic differentiation, among other features.

⚠️ Warning: this is a partial, early-stage port. There may be bugs and other numerical issues. The API may evolve as work continues.

Installation

Install from PyPI using pip:

$ python -m pip install pyqg-jax

or from conda-forge:

$ conda install -c conda-forge pyqg-jax

This should install the required dependencies, but JAX itself may require special attention, particularly for GPU support. Follow the JAX installation instructions.
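After installation, it can help to confirm which backend JAX actually picked up. The following is a minimal sketch using only standard JAX calls (nothing here is specific to pyqg-jax):

```python
import jax

# List the devices JAX can see; a CPU-only install typically reports a CPU device
devices = jax.devices()
# Name of the default backend, for example "cpu" or "gpu"
backend = jax.default_backend()

print(f"JAX {jax.__version__} is using the {backend!r} backend")
for device in devices:
    print(" ", device)
```

If a GPU was expected but only a CPU device is listed, the JAX installation itself most likely needs to be revisited.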

Usage

Documentation is a work in progress. The parameters of the QGModel implemented here are the same as for the model in the original PyQG, so consult the pyqg documentation for details.

However, there are a few overarching changes used to make the models JAX-compatible:

  1. The model state is now a separate, immutable object rather than a set of attributes on the QGModel class.

  2. Time-stepping is now separated from the models. Use steppers.AB3Stepper for the same time stepping as in the original QGModel.

  3. Random initialization requires an explicit key variable, as with all JAX random number generation.
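The explicit-key and immutability conventions above are general JAX behavior and can be sketched without the model itself. In the example below, `toy_initial_state` is a hypothetical stand-in for a random initializer and is not part of pyqg-jax:

```python
import jax
import jax.numpy as jnp


def toy_initial_state(key, shape=(2, 4, 4)):
    # Hypothetical stand-in for a random initializer (not a pyqg-jax function)
    return 1e-6 * jax.random.normal(key, shape)


key = jax.random.key(0)
# Split the key to obtain independent random streams instead of reusing it
key_a, key_b = jax.random.split(key)

state_a = toy_initial_state(key_a)
state_a_again = toy_initial_state(key_a)
state_b = toy_initial_state(key_b)

# The same key reproduces the same values; different keys give different values
same = bool(jnp.array_equal(state_a, state_a_again))
different = not bool(jnp.array_equal(state_a, state_b))

# JAX arrays are immutable: an "update" returns a new array and leaves the old one alone
updated = state_a.at[0].set(0.0)
```

Passing the same key always reproduces the same initial condition, so each member of an ensemble needs its own key from `jax.random.split`.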

The QGModel uses double precision (float64) values for part of its computation regardless of the precision setting, so make sure 64-bit support is enabled in JAX. See the documentation for details. One option is to set the following environment variable:

$ export JAX_ENABLE_X64=True

or use the %env magic in a Jupyter notebook.
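Another option, sketched here with the standard JAX configuration call, is to enable 64-bit mode from Python. This should run at program startup, before any JAX arrays are created:

```python
import jax

# Enable 64-bit mode; do this before creating any JAX arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 when 64-bit mode is active
```

Without this setting, JAX defaults to 32-bit values and the array above would be float32.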

Short Example

The following short example initializes a QGModel, adds a parameterization, and takes a single step (for more, see the examples in the documentation).

>>> import pyqg_jax
>>> import jax
>>> # Construct model, parameterization, and time-stepper
>>> stepped_model = pyqg_jax.steppers.SteppedModel(
...     model=pyqg_jax.parameterizations.smagorinsky.apply_parameterization(
...         pyqg_jax.qg_model.QGModel(),
...         constant=0.08,
...     ),
...     stepper=pyqg_jax.steppers.AB3Stepper(dt=3600.0),
... )
>>> # Initialize the model state (wrapped in stepper and parameterization state)
>>> stepper_state = stepped_model.create_initial_state(
...     jax.random.key(0)
... )
>>> # Compute next state
>>> next_stepper_state = stepped_model.step_model(stepper_state)
>>> # Unwrap the result from the stepper and parameterization
>>> next_param_state = next_stepper_state.state
>>> next_model_state = next_param_state.model_state
>>> final_q = next_model_state.q

For repeated time-stepping, combine step_model with jax.lax.scan.
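The scan pattern can be sketched generically as follows. The helper `roll_out` and the `toy_step` function are hypothetical names introduced for illustration and are not part of pyqg-jax; the toy step stands in for a real model step so the snippet runs on its own:

```python
import jax
import jax.numpy as jnp


def roll_out(step_fn, init_state, num_steps):
    """Apply step_fn num_steps times; return the final state and every intermediate state."""

    def body(carry, _):
        next_state = step_fn(carry)
        # Carry the new state forward and also record it in the stacked output
        return next_state, next_state

    return jax.lax.scan(body, init_state, None, length=num_steps)


def toy_step(state):
    # Hypothetical stand-in for a model step: halve every value
    return 0.5 * state


final_state, trajectory = roll_out(toy_step, jnp.ones(3), num_steps=4)
print(trajectory.shape)  # (4, 3): one row per recorded step
```

With the objects from the short example, `step_fn` would presumably be `stepped_model.step_model` and `init_state` would be `stepper_state`, assuming the stepper state can be carried through `jax.lax.scan` as a JAX pytree. Recording every step can use a lot of memory for long runs, so returning `None` or a reduced quantity as the second output of `body` may be preferable.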

License

This software is distributed under the MIT license. See LICENSE.txt for the license text.
