probit


probit is a simple and accessible Gaussian process package in Jax.

probit uses MLKernels for the GP prior; see the MLKernels documentation for the available means and kernels and their compositional design.
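For instance, MLKernels kernels compose by scaling, stretching, addition, and transforms such as periodic; a small illustrative sketch (the particular kernels and constants here are arbitrary choices, not probit defaults):

>>> from mlkernels import EQ, Matern12
>>>
>>> # Scale, stretch and sum kernels compositionally; this mirrors the style
>>> # of the priors used in the examples below.
>>> kernel = 0.5 * EQ().stretch(2.0) + Matern12().periodic(0.5)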


TLDR:

>>> from probit.approximators import LaplaceGP as GP
>>> from probit.utilities import log_gaussian_likelihood
>>> from mlkernels import EQ
>>>
>>> def prior(prior_parameters):
>>>     lengthscale, signal_variance = prior_parameters
>>>     # Here you can define the kernel that defines the Gaussian process
>>>     return signal_variance * EQ().stretch(lengthscale).periodic(0.5)
>>>
>>> gaussian_process = GP(data=(X, y), prior=prior, log_likelihood=log_gaussian_likelihood)
>>> likelihood_parameters = (1.0,)  # e.g. the observation noise standard deviation
>>> prior_parameters = (1.0, 1.0)  # (lengthscale, signal_variance)
>>> # As in the examples below, parameters are packed as (prior, likelihood)
>>> parameters = (prior_parameters, likelihood_parameters)
>>> weight, precision = gaussian_process.approximate_posterior(parameters)
>>> predictive_mean, predictive_variance = gaussian_process.predict(
>>>     X_test,
>>>     parameters, weight, precision)
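The TLDR assumes training data X, y and test inputs X_test are already in scope; a minimal sketch of toy data with JAX (the shapes and noise level are illustrative assumptions, not part of the package):

>>> import jax.numpy as jnp
>>> from jax import random
>>>
>>> key = random.PRNGKey(0)
>>> X = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)  # training inputs
>>> y = jnp.sin(3.0 * X[:, 0]) + 0.1 * random.normal(key, (20,))  # noisy targets
>>> X_test = jnp.linspace(-1.5, 1.5, 100).reshape(-1, 1)  # test inputs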

Installation

pip install probit, or for developers:

  • Clone the repository: git clone git@github.com:bb515/probit.git
  • Install in editable mode with pip install -e . from the root directory of the repository (see setup.py for the requirements this command installs).

Examples

You can find examples of how to use the package under examples/.

Regression and hyperparameter optimization

Run the regression example by typing python examples/regression.py.
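The snippet below omits its imports and helper definitions; roughly the following is assumed (generate_data and plot are helpers defined in the example script, and Vars and minimise_l_bfgs_b presumably come from the varz package used for constrained parameter handling):

>>> import jax.numpy as jnp
>>> from jax import random
>>> from mlkernels import EQ
>>> from varz import Vars
>>> from varz.jax import minimise_l_bfgs_b
>>> from probit.approximators import LaplaceGP as GP
>>> from probit.utilities import log_gaussian_likelihood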

>>> def prior(prior_parameters):
>>>     lengthscale, signal_variance = prior_parameters
>>>     # Here you can define the kernel that defines the Gaussian process
>>>     return signal_variance * EQ().stretch(lengthscale).periodic(0.5)
>>>
>>> # Generate data
>>> key = random.PRNGKey(0)
>>> noise_std = 0.2
>>> (X, y, X_show, f_show, N_show) = generate_data(
>>>     key, N_train=20,
>>>     kernel=prior((1.0, 1.0)), noise_std=noise_std,
>>>     N_show=1000)
>>>
>>> gaussian_process = GP(data=(X, y), prior=prior, log_likelihood=log_gaussian_likelihood)
>>> evidence = gaussian_process.objective()
>>>
>>> vs = Vars(jnp.float32)
>>>
>>> def model(vs):
>>>     p = vs.struct
>>>     return (p.lengthscale.positive(), p.signal_variance.positive()), (p.noise_std.positive(),)
>>>
>>> def objective(vs):
>>>     return evidence(model(vs))
>>>
>>> # Approximate posterior
>>> parameters = model(vs)
>>> weight, precision = gaussian_process.approximate_posterior(parameters)
>>> mean, variance = gaussian_process.predict(
>>>     X_show, parameters, weight, precision)
>>> noise_variance = vs.struct.noise_std()**2
>>> obs_variance = variance + noise_variance
>>> plot((X, y), (X_show, f_show), mean, obs_variance, fname="readme_regression_before.png")

[Figure: predictive mean and observation variance before hyperparameter optimization (readme_regression_before.png)]

>>> print("Before optimization, \nparams={}".format(parameters))

Before optimization, params=((Array(0.10536897, dtype=float32), Array(0.2787192, dtype=float32)), (Array(0.6866876, dtype=float32),))

>>> minimise_l_bfgs_b(objective, vs)
>>> parameters = model(vs)
>>> print("After optimization, \nparams={}".format(parameters))

After optimization, params=((Array(1.354531, dtype=float32), Array(0.48594338, dtype=float32)), (Array(0.1484054, dtype=float32),))

>>> # Approximate posterior
>>> weight, precision = gaussian_process.approximate_posterior(parameters)
>>> mean, variance = gaussian_process.predict(
>>>     X_show, parameters, weight, precision)
>>> noise_variance = vs.struct.noise_std()**2
>>> obs_variance = variance + noise_variance
>>> plot((X, y), (X_show, f_show), mean, obs_variance, fname="readme_regression_after.png")

[Figure: predictive mean and observation variance after hyperparameter optimization (readme_regression_after.png)]

Ordinal regression and hyperparameter optimization

Run the ordinal regression example by typing python examples/classification.py.
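As in the regression example, the imports and helpers are provided by the example script; a rough sketch of what is assumed (the choice of Approximator alias is an assumption, the import path of probit_predictive_distributions is assumed, and generate_data, plot, plot_obj, calculate_metrics and colors are helpers defined in the example script):

>>> import jax.numpy as jnp
>>> from jax import random, vmap, value_and_grad
>>> from mlkernels import EQ, Matern12
>>> from varz import Vars
>>> from varz.jax import minimise_l_bfgs_b
>>> from probit.approximators import LaplaceGP as Approximator  # assumption: which approximator the example uses
>>> from probit.utilities import log_probit_likelihood, probit_predictive_distributions  # assumed import path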

>>> # Generate data
>>> J = 3  # use a value of J=2 for GP binary classification
>>> key = random.PRNGKey(1)
>>> noise_variance = 0.4
>>> signal_variance = 1.0
>>> lengthscale = 1.0
>>> kernel = signal_variance * Matern12().stretch(lengthscale)
>>> (N_show, X, g_true, y, cutpoints,
>>> X_test, y_test,
>>> X_show, f_show) = generate_data(key,
>>>     N_train_per_class=10, N_test_per_class=100,
>>>     J=J, kernel=kernel, noise_variance=noise_variance,
>>>     N_show=1000, jitter=1e-6)
>>>
>>> # Initiate a misspecified model, using a kernel
>>> # other than the one used to generate data
>>> def prior(prior_parameters):
>>>     # Here you can define the kernel that defines the Gaussian process
>>>     return signal_variance * EQ().stretch(prior_parameters)
>>>
>>> classifier = Approximator(data=(X, y), prior=prior,
>>>     log_likelihood=log_probit_likelihood,
>>>     tolerance=1e-5  # tolerance for the jaxopt fixed-point resolution
>>>     )
>>> negative_evidence_lower_bound = classifier.objective()
>>>
>>> vs = Vars(jnp.float32)
>>>
>>> def model(vs):
>>>     p = vs.struct
>>>     noise_std = jnp.sqrt(noise_variance)
>>>     return (p.lengthscale.positive(1.2)), (noise_std, cutpoints)
>>>
>>> def objective(vs):
>>>     return negative_evidence_lower_bound(model(vs))
>>>
>>> # Approximate posterior
>>> parameters = model(vs)
>>> weight, precision = classifier.approximate_posterior(parameters)
>>> mean, variance = classifier.predict(
>>>     X_show,
>>>     parameters,
>>>     weight, precision)
>>> obs_variance = variance + noise_variance
>>> predictive_distributions = probit_predictive_distributions(
>>>     parameters[1],
>>>     mean, variance)
>>> plot(X_show, predictive_distributions, mean,
>>>     obs_variance, X_show, f_show, X, y, g_true,
>>>     J, colors, fname="readme_classification_before")

[Figures: ordinal predictive distributions and latent posterior before hyperparameter optimization (readme_classification_before)]

>>>
>>> # Evaluate model
>>> mean, variance = classifier.predict(
>>>     X_test,
>>>     parameters,
>>>     weight, precision)
>>> predictive_distributions = probit_predictive_distributions(
>>>     parameters[1],
>>>     mean, variance)
>>> print("\nEvaluation of model:")
>>> calculate_metrics(y_test, predictive_distributions)
>>> print("Before optimization, \nparameters={}".format(parameters))

Evaluation of model:
116 sum incorrect
184 sum correct
mean_absolute_error=0.41
log_pred_probability=-140986.54
mean_zero_one_error=0.39

Before optimization, parameters=(Array(1.2, dtype=float32), (Array(0.63245553, dtype=float64, weak_type=True), Array([ -inf, -0.54599167, 0.50296235, inf], dtype=float64)))

>>>
>>> minimise_l_bfgs_b(objective, vs)
>>> parameters = model(vs)
>>> print("After optimization, \nparameters={}".format(model(vs)))

After optimization, parameters=(Array(0.07389855, dtype=float32), (Array(0.63245553, dtype=float64, weak_type=True), Array([ -inf, -0.54599167, 0.50296235, inf], dtype=float64)))

>>>
>>> # Approximate posterior
>>> parameters = model(vs)
>>> weight, precision = classifier.approximate_posterior(parameters)
>>> mean, variance = classifier.predict(
>>>     X_show,
>>>     parameters,
>>>     weight, precision)
>>> obs_variance = variance + noise_variance
>>> predictive_distributions = probit_predictive_distributions(
>>>     parameters[1],
>>>     mean, variance)
>>> plot(X_show, predictive_distributions, mean,
>>>     obs_variance, X_show, f_show, X, y, g_true,
>>>     J, colors, fname="readme_classification_after")

[Figures: ordinal predictive distributions and latent posterior after hyperparameter optimization (readme_classification_after)]

>>> # Evaluate model
>>> mean, variance = classifier.predict(
>>>     X_test,
>>>     parameters,
>>>     weight, precision)
>>> obs_variance = variance + noise_variance
>>> predictive_distributions = probit_predictive_distributions(
>>>     parameters[1],
>>>     mean, variance)
>>> print("\nEvaluation of model:")
>>> calculate_metrics(y_test, predictive_distributions)

Evaluation of model:
106 sum incorrect
194 sum correct
mean_absolute_error=0.36
log_pred_probability=-161267.49
mean_zero_one_error=0.35

>>> nelbo = lambda x : negative_evidence_lower_bound(((x), (jnp.sqrt(noise_variance), cutpoints)))
>>> fg = vmap(value_and_grad(nelbo))
>>>
>>> domain = ((-2, 2), None)
>>> resolution = (50, None)
>>> x = jnp.logspace(
>>>     domain[0][0], domain[0][1], resolution[0])
>>> xlabel = r"lengthscale, $\ell$"
>>> xscale = "log"
>>> phis = jnp.log(x)
>>>
>>> fgs = fg(x)
>>> fs = fgs[0]
>>> gs = fgs[1]
>>> plot_obj(vs.struct.lengthscale(), lengthscale, x, fs, gs, domain, xlabel, xscale)

[Figures: negative evidence lower bound and its gradient as a function of the lengthscale]

Doesn't have

References

Algorithms in this package were ported from pre-existing code. In particular, the code was ported from the following papers and repositories:

Laplace approximation: http://www.gatsby.ucl.ac.uk/~chuwei/ordinalregression.html

@article{Chu2005,
  author = {Chu, Wei and Ghahramani, Zoubin},
  title = {Gaussian Processes for Ordinal Regression},
  journal = {Journal of Machine Learning Research},
  year = {2005},
  month = {07},
  volume = {6},
  pages = {1019-1041},
  howpublished = {\url{http://www.gatsby.ucl.ac.uk/~chuwei/ordinalregression.html}}}

Variational inference via a factorizing assumption and free-form minimization:

@article{Girolami2005,
  author = {Girolami, Mark and Rogers, Simon},
  title = {Variational Bayesian Multinomial Probit Regression with Gaussian Process Priors},
  journal = {Neural Computation},
  year = {2006},
  volume = {18},
  number = {8},
  pages = {1790-1817}}

and

@misc{King2005,
  author = {King, Nathaniel J. and Lawrence, Neil D.},
  title = {Variational Inference in Gaussian Processes via Probabilistic Point Assimilation},
  year = {2005},
  number = {CS-05-06},
  url = {http://inverseprobability.com/publications/king-ppa05.html}}

An implicit functions tutorial was used to define the fixed-point layer.
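The classification example above passes a tolerance to a jaxopt fixed-point solver: the approximate posterior is the fixed point of an update map, and gradients flow through the solve by implicit differentiation rather than by unrolling iterations. A minimal sketch of that pattern with jaxopt (the update function T and its parameter are illustrative toys, not probit's actual posterior update):

>>> import jax
>>> import jax.numpy as jnp
>>> from jaxopt import FixedPointIteration
>>>
>>> def T(w, lengthscale):
>>>     # Toy contraction whose fixed point is sqrt(lengthscale);
>>>     # probit's real update acts on the posterior weights instead.
>>>     return 0.5 * (w + lengthscale / jnp.maximum(w, 1e-6))
>>>
>>> solver = FixedPointIteration(fixed_point_fun=T, maxiter=100, tol=1e-5,
>>>                              implicit_diff=True)
>>> w_star = solver.run(jnp.array(1.0), 2.0).params
>>> # With implicit_diff=True, gradients w.r.t. the parameter are computed via
>>> # the implicit function theorem, so jax.grad differentiates through the solve.
>>> dw_dl = jax.grad(lambda l: solver.run(jnp.array(1.0), l).params)(2.0)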

