GPJax

Gaussian processes in Jax.


Quickstart | Install guide | Documentation

GPJax aims to provide a low-level interface to Gaussian process (GP) models in Jax, structured to give researchers maximum flexibility to extend the code to suit their own needs. In GPJax, we define a GP prior by specifying a mean and kernel function, then multiply it by a likelihood function to construct the posterior. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.
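Schematically, this posterior-as-product construction is just Bayes' rule applied to the latent function:

```latex
p(f \mid \mathbf{y}) \propto p(\mathbf{y} \mid f)\, p(f), \qquad f \sim \mathcal{GP}\bigl(m(\cdot), k(\cdot,\cdot)\bigr)
```

Here $p(f)$ is the GP prior defined by the mean function $m$ and kernel $k$, and $p(\mathbf{y} \mid f)$ is the likelihood of the observations.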

Supported methods and interfaces

Examples

Guides for customisation

Simple example

This simple regression example illustrates how closely GPJax's API resembles the way we write the mathematics of Gaussian processes.

After importing the necessary dependencies, we'll simulate some data.

import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.experimental import optimizers
from jax import jit

key = jr.PRNGKey(123)

x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(50,)).sort().reshape(-1, 1)
y = jnp.sin(x) + jr.normal(key, shape=x.shape)*0.05
training = gpx.Dataset(X=x, y=y)

The function of interest here is sinusoidal, but our observations of it have been perturbed by independent zero-mean Gaussian noise. We aim to use a Gaussian process to recover this latent function.

We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.

prior = gpx.Prior(kernel = gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints = x.shape[0])
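For intuition, the radial basis function (squared-exponential) kernel behind `gpx.RBF` computes k(x, x') = σ² exp(−(x − x')² / (2ℓ²)). A minimal NumPy sketch of this formula (an illustration, not GPJax's actual implementation):

```python
import numpy as np

def rbf_kernel(x1, x2, variance=1.0, lengthscale=1.0):
    # Pairwise squared distances between two sets of 1-D inputs,
    # computed via broadcasting.
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dists / lengthscale**2)

x = np.linspace(-3.0, 3.0, 5)
K = rbf_kernel(x, x)
# K is a symmetric 5x5 Gram matrix with the variance on its diagonal.
```

Nearby inputs get kernel values close to the variance, while distant inputs decay towards zero, which is what encodes smoothness in the prior.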

The posterior is then constructed as the product of our prior and our likelihood.

posterior = prior * likelihood

Equipped with the posterior, we proceed to train the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood.

We begin by defining a set of initial parameter values through the initialise callable.

params, _, constrainer, unconstrainer = gpx.initialise(posterior)
params = gpx.transform(params, unconstrainer)
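The constrainer/unconstrainer pair exists because hyperparameters such as lengthscales and variances must stay strictly positive: we optimise on the unconstrained real line and map back afterwards, so gradient steps can never produce invalid values. A minimal illustration using a log/exp bijection (GPJax's actual bijectors may differ, e.g. softplus):

```python
import numpy as np

def unconstrain(theta):
    # Map a positive hyperparameter to the unconstrained real line.
    return np.log(theta)

def constrain(phi):
    # Map back to the positive reals after gradient steps.
    return np.exp(phi)

lengthscale = 0.5
phi = unconstrain(lengthscale)
phi -= 0.1            # a step on the real line can go anywhere
recovered = constrain(phi)  # yet the result is always > 0
```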

Next, we define the marginal log-likelihood, adding Jax's just-in-time (JIT) compilation to accelerate training. Notice that this is the first point at which data enters the model. This mirrors model building in principle: we first define our prior model, then observe some data and use it to build a posterior.

mll = jit(posterior.marginal_log_likelihood(training, constrainer, negative=True))
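`jit` traces the wrapped function once and compiles it with XLA; subsequent calls reuse the compiled version while returning identical values. A self-contained sketch on a toy objective (independent of GPJax):

```python
import jax
import jax.numpy as jnp

def neg_log_density(x):
    # Toy quadratic standing in for the negative marginal log-likelihood.
    return 0.5 * jnp.sum(x ** 2)

fast = jax.jit(neg_log_density)

x = jnp.arange(3.0)
# The jitted and un-jitted versions agree numerically.
assert jnp.allclose(fast(x), neg_log_density(x))
```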

Finally, we utilise Jax's built-in Adam optimiser and run an optimisation loop.

opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(params)

def step(i, opt_state):
    params = get_params(opt_state)
    gradients = jax.grad(mll)(params)
    return opt_update(i, gradients, opt_state)

for i in range(100):
    opt_state = step(i, opt_state)
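The optimiser triple hides the actual Adam update rule. For reference, a from-scratch NumPy sketch of a single Adam step, following the standard Kingma & Ba update (not the internals of `jax.experimental.optimizers`):

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and its square.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    # Bias correction for the zero-initialised moments.
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Minimise f(x) = x^2 from x = 1.0; the gradient is 2x.
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 501):
    x, m, v = adam_step(x, 2 * x, m, v, t)
```

After a few hundred steps the parameter settles near the minimum at zero, which is the same behaviour the optimiser loop above exploits on the marginal log-likelihood.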

Now that our parameters are optimised, we transform these back to their original constrained space. Using their learned values, we can obtain the posterior distribution of the latent function at novel test points.

final_params = gpx.transform(get_params(opt_state), constrainer)

xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

latent_distribution = posterior(training, final_params)(xtest)
predictive_distribution = likelihood(latent_distribution, final_params)

predictive_mean = predictive_distribution.mean()
predictive_stddev = predictive_distribution.stddev()
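The predictive mean and standard deviation are typically combined into a central credible band, e.g. mean ± 1.96·stddev for roughly 95% coverage under the Gaussian predictive. A small NumPy sketch with illustrative (made-up) values:

```python
import numpy as np

predictive_mean = np.array([0.1, -0.4, 0.7])     # illustrative values
predictive_stddev = np.array([0.2, 0.1, 0.3])

lower = predictive_mean - 1.96 * predictive_stddev
upper = predictive_mean + 1.96 * predictive_stddev
# Each band is symmetric about the mean and widens with the stddev.
```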

Installation

Stable version

To install the latest stable version of GPJax run

pip install gpjax

Development version

To install the latest, possibly unstable, version, follow the steps below. While not compulsory, we advise running all of the below inside a virtual environment.

git clone https://github.com/thomaspinder/GPJax.git
cd GPJax
python setup.py develop

We then recommend you check your installation using the supplied unit tests.

python -m pytest tests/
