probit is a simple and accessible Gaussian process implementation in JAX
probit
probit is a simple and accessible Gaussian process package in JAX. Thank you to nPlan, who are supporting this project.
probit uses MLKernels for the GP prior; see the MLKernels documentation for the available means and kernels, which are built with a compositional design.
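MLKernels builds kernels by composition: a base kernel such as EQ() can be stretched, scaled and added to other kernels, as in signal_variance * EQ().stretch(lengthscale). The plain-JAX sketch below mimics that idea with small hypothetical helpers (eq, matern12, stretch, scale, add). It is not the MLKernels API, only an illustration of how a composed kernel evaluates to a Gram matrix, and it assumes that stretching means dividing the inputs by the lengthscale.

```python
import jax.numpy as jnp


# Hypothetical helpers (not the MLKernels API): each kernel is a function
# k(x1, x2) -> Gram matrix, with inputs of shape (N, D) and (M, D).
def _sq_dist(x1, x2):
    return jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)


def eq():
    """Exponentiated-quadratic kernel with unit lengthscale and unit variance."""
    return lambda x1, x2: jnp.exp(-0.5 * _sq_dist(x1, x2))


def matern12():
    """Matern-1/2 (exponential) kernel with unit lengthscale and unit variance."""
    return lambda x1, x2: jnp.exp(-jnp.sqrt(_sq_dist(x1, x2)))


def stretch(k, lengthscale):
    """Divide the inputs by the lengthscale before evaluating k."""
    return lambda x1, x2: k(x1 / lengthscale, x2 / lengthscale)


def scale(k, variance):
    """Multiply the kernel by a signal variance."""
    return lambda x1, x2: variance * k(x1, x2)


def add(k1, k2):
    """The sum of two kernels is again a valid kernel."""
    return lambda x1, x2: k1(x1, x2) + k2(x1, x2)


# Analogue of 2.0 * EQ().stretch(0.3) + 0.5 * Matern12()
kernel = add(scale(stretch(eq(), 0.3), 2.0), scale(matern12(), 0.5))
x = jnp.linspace(0.0, 1.0, 5)[:, None]
K = kernel(x, x)
print(K.shape)  # (5, 5)
```

Because every helper returns a plain function, compositions nest freely; the diagonal of K is 2.0 + 0.5 = 2.5, the sum of the two signal variances.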
TLDR:
>>> from probit.approximators import LaplaceGP as GP
>>> from probit.utilities import log_gaussian_likelihood
>>> from mlkernels import EQ
>>>
>>> def prior(prior_parameters):
...     lengthscale, signal_variance = prior_parameters
...     # Here you can define the kernel that defines the Gaussian process
...     return signal_variance * EQ().stretch(lengthscale).periodic(0.5)
>>>
>>> # X, y: training inputs and targets; X_test: test inputs (arrays you supply)
>>> gaussian_process = GP(data=(X, y), prior=prior, log_likelihood=log_gaussian_likelihood)
>>> likelihood_parameters = 1.0
>>> prior_parameters = (1.0, 1.0)
>>> parameters = (prior_parameters, likelihood_parameters)
>>> weight, precision = gaussian_process.approximate_posterior(parameters)
>>> predictive_mean, predictive_variance = gaussian_process.predict(
...     X_test,
...     parameters, weight, precision)
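With a Gaussian likelihood the log-posterior over the latent function is quadratic, so the Laplace approximation is exact and the predictive mean and variance have closed forms. The plain-JAX sketch below writes out the standard GP regression predictive equations for comparison. It is not probit's implementation: eq_kernel and gp_predict are hypothetical names, and the sketch makes no claim about what probit's weight and precision arrays contain.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def eq_kernel(x1, x2, lengthscale, signal_variance):
    """signal_variance * exp(-|x1 - x2|^2 / (2 lengthscale^2)) as an (N, M) matrix."""
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return signal_variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_predict(X, y, X_test, lengthscale, signal_variance, noise_std):
    """Posterior mean and variance of the latent function at X_test."""
    K = eq_kernel(X, X, lengthscale, signal_variance)
    factor = cho_factor(K + noise_std**2 * jnp.eye(X.shape[0]), lower=True)
    K_cross = eq_kernel(X, X_test, lengthscale, signal_variance)  # (N, M)
    alpha = cho_solve(factor, y)  # (K + noise_std^2 I)^{-1} y
    mean = K_cross.T @ alpha
    v = cho_solve(factor, K_cross)  # (K + noise_std^2 I)^{-1} K_cross
    # The EQ kernel has k(x, x) = signal_variance on the diagonal.
    variance = signal_variance - jnp.sum(K_cross * v, axis=0)
    return mean, variance


X = jnp.linspace(0.0, 1.0, 10)[:, None]
y = jnp.sin(2.0 * jnp.pi * X[:, 0])
X_test = jnp.array([[0.25], [10.0]])  # one point inside the data, one far away
mean, variance = gp_predict(X, y, X_test, lengthscale=0.3, signal_variance=1.0, noise_std=0.05)
# The observation variance adds the noise, as in obs_variance = variance + noise_variance.
obs_variance = variance + 0.05**2
```

Far from the data (x = 10) the prediction falls back to the prior: mean 0 and variance equal to signal_variance.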
Installation
The package requires Python 3.8+. It is recommended to install it in a new Python virtual environment.
probit depends on JAX. Because the JAX installation differs depending on your CUDA version, probit does not list JAX as a dependency in setup.py.
First, follow the JAX installation instructions to install JAX with the relevant accelerator support.
Then, run pip install probit
or, for developers:
- Clone the repository
git clone git@github.com:bb515/probit.git
- Install using pip
pip install -e .
from the root directory of the repository (see setup.py for the requirements that this command installs).
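Because JAX is installed separately from probit, it can be worth checking which backend JAX picked up before running the examples. The short check below uses only standard JAX calls (jax.__version__, jax.default_backend(), jax.devices()); on a machine without a supported accelerator the backend is expected to be "cpu".

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # e.g. "cpu" or "gpu"
print("Devices:", jax.devices())

# A tiny computation confirms that arrays can be created and reduced.
total = jnp.arange(4.0).sum()
print("0 + 1 + 2 + 3 =", float(total))  # 6.0
```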
Examples
You can find examples of how to use the package under examples/.
Regression and hyperparameter optimization
Run the regression example by typing python examples/regression.py.
>>> import jax.numpy as jnp
>>> from jax import random
>>> from mlkernels import EQ
>>> from varz import Vars
>>> from varz.jax import minimise_l_bfgs_b
>>> from probit.approximators import LaplaceGP as GP
>>> from probit.utilities import log_gaussian_likelihood
>>> # generate_data and plot are data and plotting helpers used by
>>> # examples/regression.py; they are not imported here.
>>>
>>> def prior(prior_parameters):
...     lengthscale, signal_variance = prior_parameters
...     # Here you can define the kernel that defines the Gaussian process
...     return signal_variance * EQ().stretch(lengthscale).periodic(0.5)
>>>
>>> # Generate data
>>> key = random.PRNGKey(0)
>>> noise_std = 0.2
>>> (X, y, X_show, f_show, N_show) = generate_data(
...     key, N_train=20,
...     kernel=prior((1.0, 1.0)), noise_std=noise_std,
...     N_show=1000)
>>>
>>> gaussian_process = GP(data=(X, y), prior=prior, log_likelihood=log_gaussian_likelihood)
>>> evidence = gaussian_process.objective()
>>>
>>> vs = Vars(jnp.float32)
>>>
>>> def model(vs):
...     p = vs.struct
...     return (p.lengthscale.positive(), p.signal_variance.positive()), (p.noise_std.positive(),)
>>>
>>> def objective(vs):
...     return evidence(model(vs))
>>>
>>> # Approximate posterior
>>> parameters = model(vs)
>>> weight, precision = gaussian_process.approximate_posterior(parameters)
>>> mean, variance = gaussian_process.predict(
...     X_show, parameters, weight, precision)
>>> noise_variance = vs.struct.noise_std()**2
>>> obs_variance = variance + noise_variance
>>> plot((X, y), (X_show, f_show), mean, obs_variance, fname="readme_regression_before.png")
>>> print("Before optimization, \nparams={}".format(parameters))
Before optimization,
params=((Array(0.10536897, dtype=float32), Array(0.2787192, dtype=float32)), (Array(0.6866876, dtype=float32),))
>>> minimise_l_bfgs_b(objective, vs)
>>> parameters = model(vs)
>>> print("After optimization, \nparams={}".format(parameters))
After optimization,
params=((Array(1.354531, dtype=float32), Array(0.48594338, dtype=float32)), (Array(0.1484054, dtype=float32),))
>>> # Approximate posterior
>>> weight, precision = gaussian_process.approximate_posterior(parameters)
>>> mean, variance = gaussian_process.predict(
...     X_show, parameters, weight, precision)
>>> noise_variance = vs.struct.noise_std()**2
>>> obs_variance = variance + noise_variance
>>> plot((X, y), (X_show, f_show), mean, obs_variance, fname="readme_regression_after.png")
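In the regression example, objective wraps gaussian_process.objective() and L-BFGS-B minimises it; for a Gaussian likelihood the natural quantity to minimise is the negative log marginal likelihood (negative log evidence). The plain-JAX sketch below writes that quantity out and minimises it over log-parameters, an exp reparametrisation that keeps the parameters positive in the same spirit as varz's .positive(). Plain gradient descent is swapped in for L-BFGS-B only to keep the sketch dependency-free; the function names, toy data and step size are assumptions, not probit's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats import multivariate_normal


def eq_kernel(x1, x2, lengthscale, signal_variance):
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return signal_variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gram(log_params, X):
    """Noisy Gram matrix C = K + noise_std^2 I; exp keeps every parameter positive."""
    lengthscale, signal_variance, noise_std = jnp.exp(log_params)
    K = eq_kernel(X, X, lengthscale, signal_variance)
    return K + noise_std**2 * jnp.eye(X.shape[0])


def negative_log_marginal_likelihood(log_params, X, y):
    """0.5 y^T C^{-1} y + 0.5 log|C| + 0.5 N log(2 pi)."""
    C = gram(log_params, X)
    L, lower = cho_factor(C, lower=True)
    alpha = cho_solve((L, lower), y)
    half_log_det = jnp.sum(jnp.log(jnp.diag(L)))  # 0.5 log|C|
    return 0.5 * y @ alpha + half_log_det + 0.5 * y.shape[0] * jnp.log(2.0 * jnp.pi)


# Toy data (an assumption for illustration)
key = jax.random.PRNGKey(0)
X = jnp.linspace(-2.0, 2.0, 20)[:, None]
y = jnp.sin(2.0 * X[:, 0]) + 0.2 * jax.random.normal(key, (20,))

loss_and_grad = jax.jit(jax.value_and_grad(negative_log_marginal_likelihood))
log_params = jnp.zeros(3)  # lengthscale = signal_variance = noise_std = 1.0
initial_loss, _ = loss_and_grad(log_params, X, y)

# Cross-check against the Gaussian log-density of y under N(0, C).
reference = -multivariate_normal.logpdf(y, jnp.zeros(20), gram(jnp.zeros(3), X))

for _ in range(300):
    loss, grad = loss_and_grad(log_params, X, y)
    log_params = log_params - 0.01 * grad  # plain gradient descent step
final_loss, _ = loss_and_grad(log_params, X, y)
print(jnp.exp(log_params))  # optimised (lengthscale, signal_variance, noise_std)
```

Working in log space turns a constrained problem (all parameters positive) into an unconstrained one, which is what lets a generic optimiser step freely.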
Ordinal regression and hyperparameter optimization
Run the ordinal regression example by typing python examples/classification.py.
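The classification example fits an ordinal probit model: J classes are separated by cutpoints, which the printed parameters show stored with -inf and inf at the ends. Under the standard ordinal probit model (Chu and Ghahramani, 2005), if the latent function at a test point has approximate posterior N(mean, variance) and the likelihood noise has standard deviation noise_std, each class probability is a difference of normal CDFs evaluated at neighbouring cutpoints. The sketch below computes these probabilities, together with simple versions of the metrics the example prints. ordinal_predictive and ordinal_metrics are hypothetical names; probit_predictive_distributions and calculate_metrics may define things differently.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def ordinal_predictive(cutpoints, noise_std, mean, variance):
    """Class probabilities P(y = j), j = 0..J-1, for latent f ~ N(mean, variance).

    cutpoints has J + 1 entries, starting at -inf and ending at +inf.
    """
    scale = jnp.sqrt(noise_std**2 + variance)  # (M,)
    z = (cutpoints[None, :] - mean[:, None]) / scale[:, None]  # (M, J + 1)
    cdf = norm.cdf(z)
    return cdf[:, 1:] - cdf[:, :-1]  # (M, J)


def ordinal_metrics(y_true, probs):
    """Simple error metrics for the predicted classes argmax_j P(y = j)."""
    y_pred = jnp.argmax(probs, axis=1)
    log_probs_true = jnp.log(probs[jnp.arange(y_true.shape[0]), y_true])
    return {
        "mean_zero_one_error": float(jnp.mean(y_pred != y_true)),
        "mean_absolute_error": float(jnp.mean(jnp.abs(y_pred - y_true))),
        "log_pred_probability": float(jnp.sum(log_probs_true)),
    }


cutpoints = jnp.array([-jnp.inf, -0.5, 0.5, jnp.inf])  # J = 3 classes
mean = jnp.array([-2.0, 0.0, 2.0])
variance = jnp.array([0.1, 0.1, 0.1])
probs = ordinal_predictive(cutpoints, noise_std=0.3, mean=mean, variance=variance)
metrics = ordinal_metrics(jnp.array([0, 1, 1]), probs)
print(metrics)
```

Because the outer cutpoints are infinite, each row of probabilities telescopes to Phi(inf) - Phi(-inf) = 1, and the mean absolute error penalises predictions by how many classes they are off, which suits ordered labels.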
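>>> import jax.numpy as jnp
>>> from jax import random, vmap, value_and_grad
>>> from mlkernels import EQ, Matern12
>>> from varz import Vars
>>> from varz.jax import minimise_l_bfgs_b
>>> # Approximator stands for one of the classes in probit.approximators (e.g. LaplaceGP).
>>> # log_probit_likelihood, probit_predictive_distributions, generate_data, plot,
>>> # plot_obj, calculate_metrics and colors are used as in examples/classification.py;
>>> # see that script for where they are defined or imported.
>>>
>>> # Generate data
>>> J = 3  # use a value of J=2 for GP binary classification
>>> key = random.PRNGKey(1)
>>> noise_variance = 0.4
>>> signal_variance = 1.0
>>> lengthscale = 1.0
>>> kernel = signal_variance * Matern12().stretch(lengthscale)
>>> (N_show, X, g_true, y, cutpoints,
...     X_test, y_test,
...     X_show, f_show) = generate_data(key,
...     N_train_per_class=10, N_test_per_class=100,
...     J=J, kernel=kernel, noise_variance=noise_variance,
...     N_show=1000, jitter=1e-6)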
>>>
>>> # Create a misspecified model, using a kernel
>>> # other than the one used to generate the data
>>> def prior(prior_parameters):
...     # Here you can define the kernel that defines the Gaussian process
...     return signal_variance * EQ().stretch(prior_parameters)
>>>
>>> classifier = Approximator(data=(X, y), prior=prior,
...     log_likelihood=log_probit_likelihood,
...     tolerance=1e-5  # tolerance for the jaxopt fixed-point resolution
...     )
>>> negative_evidence_lower_bound = classifier.objective()
>>>
>>> vs = Vars(jnp.float32)
>>>
>>> def model(vs):
...     p = vs.struct
...     noise_std = jnp.sqrt(noise_variance)
...     return (p.lengthscale.positive(1.2)), (noise_std, cutpoints)
>>>
>>> def objective(vs):
...     return negative_evidence_lower_bound(model(vs))
>>> # Approximate posterior
>>> parameters = model(vs)
>>> weight, precision = classifier.approximate_posterior(parameters)
>>> mean, variance = classifier.predict(
...     X_show,
...     parameters,
...     weight, precision)
>>> obs_variance = variance + noise_variance
>>> predictive_distributions = probit_predictive_distributions(
...     parameters[1],
...     mean, variance)
>>> plot(X_show, predictive_distributions, mean,
...     obs_variance, X_show, f_show, X, y, g_true,
...     J, colors, fname="readme_classification_before")
>>> # Evaluate model
>>> mean, variance = classifier.predict(
...     X_test,
...     parameters,
...     weight, precision)
>>> predictive_distributions = probit_predictive_distributions(
...     parameters[1],
...     mean, variance)
>>> print("\nEvaluation of model:")
>>> calculate_metrics(y_test, predictive_distributions)
>>> print("Before optimization, \nparameters={}".format(parameters))
Evaluation of model:
116 sum incorrect
184 sum correct
mean_absolute_error=0.41
log_pred_probability=-140986.54
mean_zero_one_error=0.39
Before optimization,
parameters=(Array(1.2, dtype=float32), (Array(0.63245553, dtype=float64, weak_type=True), Array([ -inf, -0.54599167, 0.50296235, inf], dtype=float64)))
>>> minimise_l_bfgs_b(objective, vs)
>>> parameters = model(vs)
>>> print("After optimization, \nparameters={}".format(parameters))
After optimization,
parameters=(Array(0.07389855, dtype=float32), (Array(0.63245553, dtype=float64, weak_type=True), Array([ -inf, -0.54599167, 0.50296235, inf], dtype=float64)))
>>> # Approximate posterior
>>> weight, precision = classifier.approximate_posterior(parameters)
>>> mean, variance = classifier.predict(
...     X_show,
...     parameters,
...     weight, precision)
>>> obs_variance = variance + noise_variance
>>> predictive_distributions = probit_predictive_distributions(
...     parameters[1],
...     mean, variance)
>>> plot(X_show, predictive_distributions, mean,
...     obs_variance, X_show, f_show, X, y, g_true,
...     J, colors, fname="readme_classification_after")
>>> # Evaluate model
>>> mean, variance = classifier.predict(
...     X_test,
...     parameters,
...     weight, precision)
>>> predictive_distributions = probit_predictive_distributions(
...     parameters[1],
...     mean, variance)
>>> print("\nEvaluation of model:")
>>> calculate_metrics(y_test, predictive_distributions)
Evaluation of model:
106 sum incorrect
194 sum correct
mean_absolute_error=0.36
log_pred_probability=-161267.49
mean_zero_one_error=0.35
>>> # Evaluate the objective and its gradient on a grid of lengthscales
>>> nelbo = lambda x: negative_evidence_lower_bound(((x), (jnp.sqrt(noise_variance), cutpoints)))
>>> fg = vmap(value_and_grad(nelbo))
>>>
>>> domain = ((-2, 2), None)
>>> resolution = (50, None)
>>> x = jnp.logspace(
...     domain[0][0], domain[0][1], resolution[0])
>>> xlabel = r"lengthscale, $\ell$"
>>> xscale = "log"
>>> phis = jnp.log(x)
>>>
>>> fgs = fg(x)
>>> fs = fgs[0]
>>> gs = fgs[1]
>>> plot_obj(vs.struct.lengthscale(), lengthscale, x, fs, gs, domain, xlabel, xscale)
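The last step composes vmap with value_and_grad, so a single call returns the objective value and its derivative at every point of a log-spaced lengthscale grid. The same pattern is shown below on a self-contained stand-in objective: the Gaussian-likelihood negative log marginal likelihood as a function of the lengthscale only. The objective, toy data and fixed hyperparameters are assumptions chosen for illustration, not the example's negative_evidence_lower_bound.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def negative_log_marginal_likelihood(lengthscale, X, y, signal_variance=1.0, noise_std=0.2):
    """Stand-in objective: GP regression evidence as a function of the lengthscale."""
    sq_dist = jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    C = signal_variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    C = C + noise_std**2 * jnp.eye(X.shape[0])
    L = jnp.linalg.cholesky(C)
    alpha = cho_solve((L, True), y)  # C^{-1} y
    return 0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L))) + 0.5 * y.shape[0] * jnp.log(2.0 * jnp.pi)


X = jnp.linspace(-2.0, 2.0, 15)[:, None]
y = jnp.sin(2.0 * X[:, 0])

objective = lambda ell: negative_log_marginal_likelihood(ell, X, y)
fg = jax.vmap(jax.value_and_grad(objective))  # batched (value, gradient)

lengthscales = jnp.logspace(-2, 2, 50)  # 10^-2 ... 10^2, as in the example
values, grads = fg(lengthscales)
```

vmap traces the scalar objective once and vectorises it, which avoids a Python loop over the 50 grid points; the gradients make it easy to check that an optimiser's solution sits where the slope crosses zero.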
Doesn't haves
References
Algorithms in this package were ported from pre-existing code. In particular, the code was ported from the following papers and repositories:
Laplace approximation http://www.gatsby.ucl.ac.uk/~chuwei/ordinalregression.html
@article{Chu2005,
author = {Chu, Wei and Ghahramani, Zoubin},
year = {2005},
month = {07},
pages = {1019-1041},
title = {Gaussian Processes for Ordinal Regression.},
volume = {6},
journal = {Journal of Machine Learning Research},
howpublished = {\url{http://www.gatsby.ucl.ac.uk/~chuwei/ordinalregression.html}}}
Variational inference via a factorizing assumption and free-form minimization
@article{Girolami2005,
author="M. Girolami and S. Rogers",
journal="Neural Computation",
title="Variational Bayesian Multinomial Probit Regression with Gaussian Process Priors",
year="2006",
volume="18",
number="8",
pages="1790-1817"}
and
@Misc{King2005,
title = {Variational Inference in Gaussian Processes via Probabilistic Point Assimilation},
author = {King, Nathaniel J. and Lawrence, Neil D.},
year = {2005},
number = {CS-05-06},
url = {http://inverseprobability.com/publications/king-ppa05.html}}
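The Laplace approximation (the first reference above) replaces the non-Gaussian posterior with a Gaussian centred at its mode. For the binary (J = 2) probit case, the mode can be found with the standard Newton iteration given as Algorithm 3.1 in Rasmussen and Williams, Gaussian Processes for Machine Learning (2006). The plain-JAX sketch below follows that algorithm with labels y in {-1, +1}; it is not the code ported into probit, and laplace_binary_probit and the toy data are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import norm


def laplace_binary_probit(K, y, num_iters=20):
    """Newton iterations (num_iters >= 1) for the posterior mode of a GP with likelihood Phi(y f).

    Returns the mode f_hat and the Laplace approximation to log p(y | X).
    """
    n = y.shape[0]
    f = jnp.zeros(n)
    for _ in range(num_iters):
        # ratio = N(f) / Phi(y f), computed in log space for numerical stability
        ratio = jnp.exp(norm.logpdf(f) - norm.logcdf(y * f))
        grad = y * ratio  # d/df log Phi(y f)
        W = ratio**2 + y * f * ratio  # -d^2/df^2 log Phi(y f), positive
        sqrt_W = jnp.sqrt(W)
        B = jnp.eye(n) + sqrt_W[:, None] * K * sqrt_W[None, :]
        L = jnp.linalg.cholesky(B)
        b = W * f + grad
        c = solve_triangular(L, sqrt_W * (K @ b), lower=True)
        a = b - sqrt_W * solve_triangular(L.T, c, lower=False)
        f = K @ a
    log_evidence = -0.5 * a @ f + jnp.sum(norm.logcdf(y * f)) - jnp.sum(jnp.log(jnp.diag(L)))
    return f, log_evidence


# Toy 1D data: label +1 to the right of zero, -1 to the left
X = jnp.linspace(-3.0, 3.0, 20)
y = jnp.where(X > 0, 1.0, -1.0)
K = jnp.exp(-0.5 * (X[:, None] - X[None, :]) ** 2) + 1e-6 * jnp.eye(20)
f_hat, log_evidence = laplace_binary_probit(K, y)
```

Working with B = I + W^(1/2) K W^(1/2) rather than inverting K keeps the Cholesky factorisation well conditioned, since the eigenvalues of B are all at least 1.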
An implicit functions tutorial was used to define the fixed-point layer.
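A fixed-point layer solves z = f(a, z) and differentiates the solution through the implicit function theorem rather than through the solver's iterations. The sketch below shows a common way to write such a layer with jax.custom_vjp, assuming f is a contraction in z. It is an illustration only, not probit's implementation (the classification example's comments indicate that probit uses jaxopt for the fixed-point resolution); fixed_point, TOL and the toy map are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

TOL = 1e-5  # convergence tolerance on successive iterates (assumed)


def _iterate(f, a, z_init):
    """Plain fixed-point iteration z <- f(a, z) until successive iterates agree."""
    def cond_fun(carry):
        z_prev, z = carry
        return jnp.max(jnp.abs(z - z_prev)) > TOL

    def body_fun(carry):
        _, z = carry
        return z, f(a, z)

    _, z_star = jax.lax.while_loop(cond_fun, body_fun, (z_init, f(a, z_init)))
    return z_star


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, z_init):
    return _iterate(f, a, z_init)


def _fixed_point_fwd(f, a, z_init):
    z_star = _iterate(f, a, z_init)
    return z_star, (a, z_star)


def _fixed_point_bwd(f, residuals, z_star_bar):
    a, z_star = residuals
    # Implicit function theorem: solve u = z_star_bar + (df/dz)^T u by iteration,
    # then pull u back through df/da.
    _, vjp_z = jax.vjp(lambda z: f(a, z), z_star)
    u = _iterate(lambda _, u: z_star_bar + vjp_z(u)[0], None, z_star_bar)
    _, vjp_a = jax.vjp(lambda a_: f(a_, z_star), a)
    (a_bar,) = vjp_a(u)
    return a_bar, jnp.zeros_like(z_star)  # no gradient flows to the initial guess


fixed_point.defvjp(_fixed_point_fwd, _fixed_point_bwd)


def f(a, z):
    # A contraction in z: |df/dz| = |0.5 sin z| <= 0.5
    return a + 0.5 * jnp.cos(z)


a = jnp.array([0.3, 1.0])
z_star = fixed_point(f, a, jnp.zeros(2))
# For this elementwise map, implicit differentiation gives dz*/da = 1 / (1 + 0.5 sin z*).
grad_a = jax.grad(lambda a: fixed_point(f, a, jnp.zeros(2)).sum())(a)
```

The backward pass costs one more fixed-point solve and never stores the forward iterations, so memory does not grow with the number of solver steps.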