
Didactic Gaussian processes in Jax.

Project description

Quickstart | Install guide | Documentation | Slack Community

GPJax aims to provide a low-level interface to Gaussian process (GP) models in Jax, structured to give researchers maximum flexibility in extending the code to suit their own needs. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.

Package support

GPJax was founded by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas Pinder and Daniel Dodd.

We would be delighted to receive contributions from interested individuals and groups. To learn how you can get involved, please read our guide for contributing. If you have any questions, we encourage you to open an issue. For broader conversations, such as best GP fitting practices or questions about the mathematics of GPs, we invite you to open a discussion.

Feel free to join our Slack Channel, where we can discuss the development of GPJax and broader support for Gaussian process modelling.

Supported methods and interfaces

Notebook examples

Guides for customisation

Conversion between .ipynb and .py

The examples above are stored in the examples directory in the double-percent (py:percent) format. See the jupytext command-line (using-cli) documentation for more information.

  • To convert example.py to example.ipynb, run:
jupytext --to notebook example.py
  • To convert example.ipynb to example.py, run:
jupytext --to py:percent example.ipynb
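The py:percent format is plain Python in which cells are delimited by `# %%` comment lines, with `# %% [markdown]` marking a Markdown cell whose lines are commented out. As a rough illustration of that layout (not how jupytext itself parses files), the hypothetical helper `split_percent_cells` below splits such a script into `(cell_type, source)` pairs. It ignores any text before the first marker and any cell options other than `[markdown]`.

```python
import re

# A cell starts at a line beginning with "# %%"; an optional "[markdown]" tag
# marks a Markdown cell. Other cell options are not handled in this sketch.
CELL_MARKER = re.compile(r"^# %%(?:\s+\[(\w+)\])?")


def split_percent_cells(source: str) -> list[tuple[str, str]]:
    """Split py:percent text into (cell_type, body) pairs (hypothetical helper)."""
    cells = []
    cell_type, lines = None, []
    for line in source.splitlines():
        match = CELL_MARKER.match(line)
        if match:
            # Close the previous cell before starting a new one.
            if cell_type is not None:
                cells.append((cell_type, "\n".join(lines).strip()))
            cell_type = "markdown" if match.group(1) == "markdown" else "code"
            lines = []
        elif cell_type is not None:
            lines.append(line)
    # Flush the final cell.
    if cell_type is not None:
        cells.append((cell_type, "\n".join(lines).strip()))
    return cells
```

For example, applying it to the text of example.py lists the Markdown and code cells in order, subject to the simplifications above.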

Simple example

Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.

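import gpjax as gpx
from jax import jit
import jax.numpy as jnp
import jax.random as jr
import jaxkern as jk
import optax as ox

key = jr.PRNGKey(123)
# Draw the inputs and the noise from independent subkeys.
key, xkey, ykey = jr.split(key, 3)

f = lambda x: 10 * jnp.sin(x)

n = 50
# Sort along axis 0 so the inputs are ordered (the default axis=-1 has length 1 here).
x = jnp.sort(jr.uniform(key=xkey, minval=-3.0, maxval=3.0, shape=(n, 1)), axis=0)
y = f(x) + jr.normal(ykey, shape=(n, 1))
D = gpx.Dataset(X=x, y=y)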

The function of interest here, $f(\cdot)$, is sinusoidal, but our observations of it have been perturbed by Gaussian noise. We aim to recover this latent function using a Gaussian process.
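Written out, the simulation above corresponds to the following generative model, where $\mathcal{U}$ denotes the uniform distribution and the noise has unit variance because `jr.normal` draws standard normal samples:

$$x_i \sim \mathcal{U}(-3, 3), \qquad y_i = f(x_i) + \varepsilon_i, \qquad f(x) = 10\sin(x), \qquad \varepsilon_i \sim \mathcal{N}(0, 1), \qquad i = 1, \dots, n,$$

with $n = 50$ and $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{n}$.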

1. Constructing the prior and posterior

We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.

prior = gpx.Prior(kernel=jk.RBF())
likelihood = gpx.Gaussian(num_datapoints=n)
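To make the prior concrete without depending on GPJax or JaxKern internals, here is a minimal NumPy sketch of a zero-mean GP prior with an RBF kernel, $k(x, x') = \sigma_f^2 \exp\big(-\lVert x - x'\rVert^2 / (2\ell^2)\big)$, with lengthscale $\ell$ and variance $\sigma_f^2$. The function names, the default hyperparameters and the small diagonal jitter added for a stable Cholesky factorisation are our own choices; JaxKern's `RBF` may parameterise the kernel differently.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """k(x, x') = variance * exp(-||x - x'||^2 / (2 * lengthscale^2)) for inputs of shape (n, d)."""
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def sample_prior(x, num_samples, rng, jitter=1e-6, **kernel_params):
    """Draw function values from a zero-mean GP prior at inputs x of shape (n, d)."""
    K = rbf_kernel(x, x, **kernel_params) + jitter * np.eye(x.shape[0])
    L = np.linalg.cholesky(K)
    z = rng.standard_normal((x.shape[0], num_samples))
    return L @ z  # each column is one draw f ~ N(0, K)
```

Each column of `sample_prior(x, 3, np.random.default_rng(0))` is one function drawn from the prior at the inputs `x`; a shorter `lengthscale` produces wigglier draws.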

Just as we would write on paper, the posterior is constructed as the product of the prior and the likelihood.

posterior = prior * likelihood
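With a Gaussian likelihood this product is conjugate, so the posterior is available in closed form. In standard GP notation (ours, not necessarily GPJax's), with training inputs $\mathbf{X}$, targets $\mathbf{y}$, noise variance $\sigma^2$, test inputs $\mathbf{X}_*$ and kernel matrices $\mathbf{K}_{\cdot\cdot}$:

$$p(f \mid \mathcal{D}) \propto p(\mathbf{y} \mid f)\, p(f), \qquad p(f) = \mathcal{GP}(0, k), \qquad p(\mathbf{y} \mid f) = \mathcal{N}\big(\mathbf{y} \mid f(\mathbf{X}), \sigma^2 \mathbf{I}\big),$$

$$p\big(f(\mathbf{X}_*) \mid \mathcal{D}\big) = \mathcal{N}\Big(\mathbf{K}_{*\mathbf{X}} (\mathbf{K}_{\mathbf{X}\mathbf{X}} + \sigma^2 \mathbf{I})^{-1} \mathbf{y},\; \mathbf{K}_{**} - \mathbf{K}_{*\mathbf{X}} (\mathbf{K}_{\mathbf{X}\mathbf{X}} + \sigma^2 \mathbf{I})^{-1} \mathbf{K}_{\mathbf{X}*}\Big).$$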

2. Learning hyperparameters

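Equipped with the posterior, we seek to learn the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood. We do this below, using Jax's just-in-time (JIT) compilation to accelerate training. Setting negative=True returns the negated marginal log-likelihood, so that the objective can be minimised.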

mll = jit(posterior.marginal_log_likelihood(D, negative=True))
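For reference, for a zero-mean GP with Gaussian noise variance $\sigma^2$ the quantity being optimised is $\log p(\mathbf{y}) = -\tfrac{1}{2}\mathbf{y}^\top (\mathbf{K} + \sigma^2\mathbf{I})^{-1}\mathbf{y} - \tfrac{1}{2}\log\lvert\mathbf{K} + \sigma^2\mathbf{I}\rvert - \tfrac{n}{2}\log 2\pi$, where $\mathbf{K}$ is the kernel matrix on the training inputs. The NumPy sketch below computes its negation via a Cholesky factorisation. It assumes an RBF kernel and uses hypothetical function names; it illustrates the formula and is not GPJax's implementation.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """RBF kernel matrix for inputs of shape (n1, d) and (n2, d)."""
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def negative_mll(x, y, lengthscale, variance, noise):
    """Negative log marginal likelihood -log N(y | 0, K + noise * I) of a zero-mean GP."""
    n = x.shape[0]
    Ky = rbf_kernel(x, x, lengthscale, variance) + noise * np.eye(n)
    L = np.linalg.cholesky(Ky)  # Ky = L L^T
    # alpha = Ky^{-1} y via two solves against the Cholesky factor.
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    data_fit = 0.5 * float(y.ravel() @ alpha.ravel())
    half_log_det = np.sum(np.log(np.diag(L)))  # equals 0.5 * log|Ky|
    return data_fit + half_log_det + 0.5 * n * np.log(2 * np.pi)
```

Working through the Cholesky factor avoids forming $(\mathbf{K} + \sigma^2\mathbf{I})^{-1}$ explicitly and gives the log-determinant almost for free.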

For optimisation, we use Optax's Adam optimiser.

opt = ox.adam(learning_rate=1e-3)

We define an initial parameter state through the initialise callable.

parameter_state = gpx.initialise(posterior, key=key)

Finally, we run an optimisation loop using the Adam optimiser via the fit callable.

inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500)
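The fit callable runs an Adam optimisation loop. The NumPy sketch below shows the mechanics of such a loop using the standard Adam update with bias-corrected moment estimates and the usual defaults (b1=0.9, b2=0.999, eps=1e-8). It is not GPJax's or Optax's code: `adam_fit` is a hypothetical helper that takes a gradient function and a plain array of parameters rather than a GPJax parameter state.

```python
import numpy as np


def adam_fit(grad_fn, params, learning_rate=1e-3, num_iters=500,
             b1=0.9, b2=0.999, eps=1e-8):
    """Minimise an objective with Adam, given its gradient function grad_fn(params)."""
    m = np.zeros_like(params)  # running mean of gradients (first moment)
    v = np.zeros_like(params)  # running mean of squared gradients (second moment)
    for t in range(1, num_iters + 1):
        g = grad_fn(params)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        # Correct the bias towards zero caused by initialising m and v at zero.
        m_hat = m / (1 - b1**t)
        v_hat = v / (1 - b2**t)
        params = params - learning_rate * m_hat / (np.sqrt(v_hat) + eps)
    return params
```

For example, `adam_fit(lambda p: 2.0 * (p - 3.0), np.array([0.0]), learning_rate=0.1)` moves the parameter towards 3.0, the minimiser of $(p - 3)^2$. In practice the gradient would come from automatic differentiation (e.g. `jax.grad`) of an objective such as the negative marginal log-likelihood.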

3. Making predictions

Using our learned hyperparameters, we can obtain the posterior distribution of the latent function at novel test points.

learned_params, _ = inference_state.unpack()
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

latent_distribution = posterior(learned_params, D)(xtest)
predictive_distribution = likelihood(learned_params, latent_distribution)

predictive_mean = predictive_distribution.mean()
predictive_cov = predictive_distribution.covariance()
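The predictive distribution above follows the standard conjugate GP equations. The NumPy sketch below computes them directly for an RBF kernel: the latent posterior mean and covariance at the test inputs, then the predictive covariance obtained by adding the Gaussian observation-noise variance to the diagonal. The function name and keyword arguments are our own and a zero prior mean is assumed; the aim is to show what predictive_mean and predictive_cov represent, not to mirror GPJax's code.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """RBF kernel matrix for inputs of shape (n1, d) and (n2, d)."""
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def predict(x, y, xtest, lengthscale=1.0, variance=1.0, noise=1.0):
    """Predictive mean and covariance of noisy outputs at xtest for a conjugate GP."""
    Ky = rbf_kernel(x, x, lengthscale, variance) + noise * np.eye(x.shape[0])
    Ks = rbf_kernel(xtest, x, lengthscale, variance)       # K_{*X}
    Kss = rbf_kernel(xtest, xtest, lengthscale, variance)  # K_{**}
    L = np.linalg.cholesky(Ky)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))    # (K + noise I)^{-1} y
    V = np.linalg.solve(L, Ks.T)                           # L^{-1} K_{X*}
    latent_mean = Ks @ alpha
    latent_cov = Kss - V.T @ V                             # K_{**} - K_{*X} Ky^{-1} K_{X*}
    # The Gaussian likelihood adds the noise variance to the latent covariance.
    predictive_cov = latent_cov + noise * np.eye(xtest.shape[0])
    return latent_mean.ravel(), predictive_cov
```

Far from the training data the kernel terms vanish, so the predictive mean reverts to the prior mean of zero and the predictive variance approaches the kernel variance plus the noise variance.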

Installation

Stable version

The latest stable version of GPJax can be installed via pip:

pip install gpjax

Note

We recommend you check your installation version:

python -c 'import gpjax; print(gpjax.__version__)'

Development version

Warning

This version is possibly unstable and may contain bugs.

Clone a copy of the repository to your local machine and run the setup configuration in development mode.

git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
python setup.py develop

Note

We advise you to create a virtual environment before installing:

conda create -n gpjax_experimental python=3.10.0
conda activate gpjax_experimental

and recommend you check your installation passes the supplied unit tests:

python -m pytest tests/

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper.

@article{Pinder2022,
  doi = {10.21105/joss.04455},
  url = {https://doi.org/10.21105/joss.04455},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {75},
  pages = {4455},
  author = {Thomas Pinder and Daniel Dodd},
  title = {GPJax: A Gaussian Process Framework in JAX},
  journal = {Journal of Open Source Software}
}


Download files


Source Distribution

gpjax-nightly-0.5.9.dev20230312.tar.gz (60.8 kB view details)

Uploaded Source

Built Distribution

gpjax_nightly-0.5.9.dev20230312-py3-none-any.whl (48.0 kB view details)

Uploaded Python 3

File details

Details for the file gpjax-nightly-0.5.9.dev20230312.tar.gz.

File metadata

File hashes

Hashes for gpjax-nightly-0.5.9.dev20230312.tar.gz
Algorithm Hash digest
SHA256 d39e6e35828b5c1d08dd431fdfaeacd9a4a8060e9b67ab3cf0dbd1bba00caa89
MD5 776c6be6a22aea03a422e15200347e87
BLAKE2b-256 8d697cbc659dcb06e08be38b2271682718a12c6aa1ecef785b151138da99037f
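To check that a downloaded file matches the SHA256 digest listed above, you can hash it locally. A minimal sketch using Python's standard hashlib module follows; the helper names are hypothetical, and the expected digest is copied from the table above for the source distribution.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Published SHA256 digest for gpjax-nightly-0.5.9.dev20230312.tar.gz (from the table above).
EXPECTED_SHA256 = "d39e6e35828b5c1d08dd431fdfaeacd9a4a8060e9b67ab3cf0dbd1bba00caa89"


def verify_download(path, expected=EXPECTED_SHA256):
    """True only if the local file's SHA256 digest equals the expected one."""
    return sha256_of_file(path) == expected
```

For example, `verify_download("gpjax-nightly-0.5.9.dev20230312.tar.gz")` returns True only if the local file's digest matches the published one.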


File details

Details for the file gpjax_nightly-0.5.9.dev20230312-py3-none-any.whl.

File metadata

File hashes

Hashes for gpjax_nightly-0.5.9.dev20230312-py3-none-any.whl
Algorithm Hash digest
SHA256 822e62d6bab73dab241676cc42e2f5417ecbbefbc1f7f6f6e67525f9efd131d4
MD5 d62f5e92aa12e1f6be358438ddb5fb46
BLAKE2b-256 1aa2816095899c27c0db62c76a4f67731a368819c820e1c20ffc273741306e38

