
Didactic Gaussian processes in Jax.



GPJax aims to provide a low-level interface to Gaussian process (GP) models in Jax, structured to give researchers maximum flexibility in extending the code to suit their own needs. We define a GP prior in GPJax by specifying a mean and kernel function and multiply this by a likelihood function to construct the posterior. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.

Package support

GPJax was created by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas and Daniel Dodd.

We would be delighted to review pull requests (PRs) from new contributors. Before contributing, please read our guide for contributing. If you do not have the capacity to open a PR, or you would like guidance on how best to structure a PR, then please open an issue. For broader discussions on best practices for fitting GPs, technical questions surrounding the mathematics of GPs, or anything else that you feel doesn't quite constitute an issue, please start a discussion thread in our discussion tracker.

We have recently set up a Slack channel where we hope to facilitate discussions around the development of GPJax and broader support for Gaussian process modelling. If you'd like to join the channel, then please follow this invitation link which will take you to the GPJax Slack community.

Supported methods and interfaces

Examples

Guides for customisation

Simple example

This simple regression example aims to illustrate how closely GPJax's API resembles the way we write the mathematics of Gaussian processes.

After importing the necessary dependencies, we'll simulate some data.

import gpjax as gpx
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox

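key = jr.PRNGKey(123)
# Use separate keys so the inputs and the noise are drawn independently.
key, x_key, noise_key = jr.split(key, 3)

x = jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(50,)).sort().reshape(-1, 1)
y = jnp.sin(x) + jr.normal(noise_key, shape=x.shape) * 0.05
training = gpx.Dataset(X=x, y=y)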

The function of interest here is sinusoidal, but our observations of it have been perturbed by independent zero-mean Gaussian noise. We aim to utilise a Gaussian process to try and recover this latent function.

We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.

prior = gpx.Prior(kernel = gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints = x.shape[0])
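To make the kernel choice concrete, the following is a minimal sketch of a radial basis function (squared-exponential) kernel written in plain Jax. It is not GPJax's implementation; the function name `rbf_kernel` and the parameter names `lengthscale` and `variance` are assumptions made for illustration.

```python
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    """Illustrative RBF kernel matrix between the rows of x and y.

    x has shape (n, d) and y has shape (m, d); the result has shape (n, m).
    """
    # Pairwise squared Euclidean distances via broadcasting.
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


x = jnp.linspace(-3.0, 3.0, 5).reshape(-1, 1)
K = rbf_kernel(x, x)  # (5, 5) Gram matrix of the inputs with themselves
```

The Gram matrix is symmetric, and its diagonal equals the kernel variance because every point is at distance zero from itself.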

The posterior is then constructed by the product of our prior with our likelihood.

posterior = prior * likelihood

Equipped with the posterior, we proceed to train the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood.

We begin by defining a set of initial parameter values through the initialise callable.

parameter_state = gpx.initialise(posterior, key=key)
params, trainables, constrainer, unconstrainer = parameter_state.unpack()
params = gpx.transform(params, unconstrainer)
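The idea behind moving parameters to an unconstrained space can be sketched in plain Jax. Positive quantities such as a lengthscale or variance are mapped onto the whole real line so that a gradient step can never produce an invalid value. The sketch below assumes a softplus bijection; the bijectors that GPJax actually applies may differ, and the names `constrain` and `unconstrain` are hypothetical.

```python
import jax
import jax.numpy as jnp


def constrain(unconstrained):
    """Map real-valued parameters to positive values with softplus."""
    return jax.tree_util.tree_map(jax.nn.softplus, unconstrained)


def unconstrain(constrained):
    """Inverse softplus: map positive parameters back to the real line."""
    return jax.tree_util.tree_map(lambda v: jnp.log(jnp.expm1(v)), constrained)


# Hypothetical parameter dictionary, for illustration only.
params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(0.5)}
raw = unconstrain(params)     # values an optimiser can update freely
roundtrip = constrain(raw)    # recovers the original positive values
```

Optimising `raw` and applying `constrain` afterwards mirrors the pattern in the example, where the parameters are unconstrained before training and constrained again before prediction.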

Next, we define the marginal log-likelihood, adding Jax's just-in-time (JIT) compilation to accelerate training. Notice that this is the first instance of incorporating data into our model. Model building works this way in principle too, where we first define our prior model, then observe some data and use this data to build a posterior.

mll = jit(posterior.marginal_log_likelihood(training, constrainer, negative=True))
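For a Gaussian likelihood the marginal log-likelihood is available in closed form, and it may help to see what such an objective computes. The following is a sketch in plain Jax under stated assumptions: a zero mean function, an RBF kernel, and a parameter dictionary with the hypothetical keys `lengthscale`, `variance` and `obs_noise`. It is not GPJax's internal code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    """Illustrative RBF kernel matrix between the rows of x and y."""
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def negative_mll(params, x, y):
    """Negative log marginal likelihood of a zero-mean GP with Gaussian noise."""
    n = x.shape[0]
    K = rbf_kernel(x, x, params["lengthscale"], params["variance"])
    Ky = K + params["obs_noise"] * jnp.eye(n)
    chol = cho_factor(Ky, lower=True)
    alpha = cho_solve(chol, y)                          # Ky^{-1} y
    data_fit = 0.5 * jnp.sum(y * alpha)                 # 0.5 * y^T Ky^{-1} y
    complexity = jnp.sum(jnp.log(jnp.diag(chol[0])))    # 0.5 * log|Ky|
    constant = 0.5 * n * jnp.log(2.0 * jnp.pi)
    return data_fit + complexity + constant


x = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
y = jnp.sin(x)
params = {
    "lengthscale": jnp.array(1.0),
    "variance": jnp.array(1.0),
    "obs_noise": jnp.array(0.1),
}

nll = jax.jit(negative_mll)(params, x, y)      # compiled objective value
grads = jax.grad(negative_mll)(params, x, y)   # gradient w.r.t. every parameter
```

Because the function is written with Jax primitives only, `jit` and `grad` apply to it directly, which is the property the example relies on when it compiles and differentiates `mll`.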

Finally, we use the Adam optimiser from Optax and run an optimisation loop.

opt = ox.adam(learning_rate=0.01)
opt_state = opt.init(params)

for _ in range(100):
  grads = grad(mll)(params)
  updates, opt_state = opt.update(grads, opt_state)
  params = ox.apply_updates(params, updates)
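To show roughly what each iteration of such a loop does, here is a hand-written sketch of the Adam update rule in plain Jax, applied to a toy quadratic loss. It is for illustration only: the loss, the function `adam_step` and its default hyperparameters are assumptions, and in practice the Optax optimiser used above would be preferred.

```python
import jax
import jax.numpy as jnp


def loss(params):
    # Toy quadratic with its minimum at a = 2, b = -1 (illustrative only).
    return (params["a"] - 2.0) ** 2 + (params["b"] + 1.0) ** 2


def adam_step(params, grads, m, v, t, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update applied leaf-by-leaf to a parameter pytree."""
    tmap = jax.tree_util.tree_map
    # Exponential moving averages of the gradient and its square.
    m = tmap(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    v = tmap(lambda v_, g: b2 * v_ + (1 - b2) * g**2, v, grads)
    # Bias correction for the zero-initialised averages.
    m_hat = tmap(lambda m_: m_ / (1 - b1**t), m)
    v_hat = tmap(lambda v_: v_ / (1 - b2**t), v)
    params = tmap(
        lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps), params, m_hat, v_hat
    )
    return params, m, v


params = {"a": jnp.array(0.0), "b": jnp.array(0.0)}
m = jax.tree_util.tree_map(jnp.zeros_like, params)
v = jax.tree_util.tree_map(jnp.zeros_like, params)

initial_loss = loss(params)
for t in range(1, 201):
    grads = jax.grad(loss)(params)
    params, m, v = adam_step(params, grads, m, v, t)
final_loss = loss(params)
```

The structure matches the loop in the example: compute gradients, turn them into updates using the optimiser state, then apply the updates to the parameters.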

Now that our parameters are optimised, we transform these back to their original constrained space. Using their learned values, we can obtain the posterior distribution of the latent function at novel test points.

final_params = gpx.transform(params, constrainer)

xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

latent_distribution = posterior(training, final_params)(xtest)
predictive_distribution = likelihood(latent_distribution, final_params)

predictive_mean = predictive_distribution.mean()
predictive_stddev = predictive_distribution.stddev()
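For a Gaussian likelihood, the predictive mean and standard deviation have a closed form, which can be sketched in plain Jax as follows. This assumes a zero mean function and an RBF kernel with fixed, hand-picked hyperparameters; the function `gp_predict` and its arguments are hypothetical and do not reflect GPJax's internals.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    """Illustrative RBF kernel matrix between the rows of x and y."""
    sq_dist = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_predict(x, y, xtest, lengthscale=1.0, variance=1.0, obs_noise=0.01):
    """Predictive mean and standard deviation of a zero-mean GP regressor."""
    n = x.shape[0]
    Kxx = rbf_kernel(x, x, lengthscale, variance) + obs_noise * jnp.eye(n)
    Kxs = rbf_kernel(x, xtest, lengthscale, variance)   # shape (n, m)
    chol = cho_factor(Kxx, lower=True)
    # Mean: K_*x (K_xx + noise I)^{-1} y
    mean = (Kxs.T @ cho_solve(chol, y)).squeeze(-1)
    # Latent variance: k(x*, x*) - K_*x (K_xx + noise I)^{-1} K_x*
    latent_var = variance - jnp.sum(Kxs * cho_solve(chol, Kxs), axis=0)
    # Guard against tiny negative values from floating-point round-off.
    latent_var = jnp.maximum(latent_var, 0.0)
    # The predictive variance adds the observation noise back in.
    return mean, jnp.sqrt(latent_var + obs_noise)


x = jnp.linspace(-3.0, 3.0, 20).reshape(-1, 1)
y = jnp.sin(x)
xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)

mean, stddev = gp_predict(x, y, xtest)
```

The distinction between `latent_var` and the returned standard deviation mirrors the two steps in the example: the latent distribution describes the function itself, while the predictive distribution also accounts for observation noise.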

Installation

Stable version

To install the latest stable version of GPJax run

pip install gpjax

Development version

To install the latest (possibly unstable) version, follow the steps below. It is by no means compulsory, but we advise running them inside a virtual environment.

git clone https://github.com/thomaspinder/GPJax.git
cd GPJax
python setup.py develop

We then recommend you check your installation using the supplied unit tests.

python -m pytest tests/

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper. A sample BibTeX entry is given below:

@article{Pinder2022,
  doi = {10.21105/joss.04455},
  url = {https://doi.org/10.21105/joss.04455},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {75},
  pages = {4455},
  author = {Thomas Pinder and Daniel Dodd},
  title = {GPJax: A Gaussian Process Framework in JAX},
  journal = {Journal of Open Source Software}
}

