
Lattice kernel for scalable Gaussian processes in GPyTorch


Simplex-GPs

This repository hosts the code for SKIing on Simplices: Kernel Interpolation on the Permutohedral Lattice for Scalable Gaussian Processes (Simplex-GPs) by Sanyam Kapoor, Marc Finzi, Ke Alexander Wang, Andrew Gordon Wilson.

The Idea

Fast matrix-vector multiplies (MVMs) are the cornerstone of modern scalable Gaussian processes. By building upon the approximation proposed by Structured Kernel Interpolation (SKI), and leveraging advances in fast high-dimensional image filtering, Simplex-GPs approximate the computation of the kernel matrices by tiling the space using a sparse permutohedral lattice, instead of a rectangular grid.

The matrix-vector product implied by the kernel operations in SKI is then approximated via three stages: splat (projection onto the permutohedral lattice), blur (applying the blur operation as a matrix-vector product), and slice (re-projecting back into the original space).

This alleviates the curse of dimensionality associated with SKI operations, allowing them to scale beyond ~5 dimensions, and yields competitive runtime and memory costs at little cost in downstream performance. See our manuscript for complete details.
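Schematically, the approximate MVM is a composition of three sparse operations: if S is the splat matrix and B the blur matrix, then K v ≈ Sᵀ B S v. The sketch below illustrates the shape of that computation using random sparse stand-ins for S and B; it is not the package's actual lattice construction, only the structure of the three stages:

import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
n, m = 1000, 50  # n data points, m lattice vertices

# Stand-ins: each point splats onto a few lattice vertices (sparse rows),
# and the blur couples nearby vertices along the lattice directions.
S = sp.random(m, n, density=0.05, format="csr", random_state=0)
B = sp.eye(m, format="csr") + 0.5 * sp.random(m, m, density=0.05, format="csr", random_state=1)

def lattice_mvm(v):
    # splat -> blur -> slice: the three stages of the approximate MVM
    return S.T @ (B @ (S @ v))

v = rng.standard_normal(n)
Kv = lattice_mvm(v)  # cost scales with the number of nonzeros, not n**2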

Usage

The lattice kernels are packaged as GPyTorch modules, and can be used as a fast approximation to either the RBFKernel or the MaternKernel. The corresponding replacement modules are RBFLattice and MaternLattice.

Using the RBFLattice kernel requires changing only a single line of code:

import gpytorch as gp
from gpytorch_lattice_kernel import RBFLattice

class SimplexGPModel(gp.models.ExactGP):
  def __init__(self, train_x, train_y):
    likelihood = gp.likelihoods.GaussianLikelihood()
    super().__init__(train_x, train_y, likelihood)

    self.mean_module = gp.means.ConstantMean()
    self.covar_module = gp.kernels.ScaleKernel(
-      gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
+      RBFLattice(ard_num_dims=train_x.size(-1), order=1)
    )

  def forward(self, x):
    mean_x = self.mean_module(x)
    covar_x = self.covar_module(x)
    return gp.distributions.MultivariateNormal(mean_x, covar_x)

The GPyTorch Regression Tutorial provides a simple example on toy data, where this kernel can be used as a drop-in replacement.
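For completeness, here is a minimal training sketch for the model above, following the standard GPyTorch exact-GP workflow (it assumes train_x and train_y tensors are already defined):

import torch

model = SimplexGPModel(train_x, train_y)
mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
for _ in range(100):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)  # negative marginal log-likelihood
    loss.backward()
    optimizer.step()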

Install

To use the kernel in your code, install the package as:

pip install gpytorch-lattice-kernel

NOTE: The kernel is compiled lazily from source using CMake. If the compilation fails, you may need to install a more recent version of CMake. Additionally, ninja is required for compilation. One way to install both is:

conda install -c conda-forge cmake ninja

Local Setup

For a local development setup, create the conda environment:

$ conda env create -f environment.yml

Remember to add the root of the project to PYTHONPATH if it is not already there:

$ export PYTHONPATH="$(pwd):${PYTHONPATH}"

Test

To verify the code is working as expected, a simple test script is provided that compares the training marginal likelihood achieved by Simplex-GPs against that of Exact-GPs. Run it as:

python tests/train_snelson.py

The Snelson 1-D toy dataset is used. A copy is available in snelson.csv.
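In the same spirit, a hypothetical sketch of how such a comparison could be set up (the real test lives in tests/train_snelson.py; the two-column CSV layout below is an assumption):

import numpy as np
import torch
import gpytorch as gp

# Assumed layout: snelson.csv stores (x, y) pairs in two columns.
data = np.loadtxt("snelson.csv", delimiter=",")
train_x = torch.as_tensor(data[:, :1], dtype=torch.float32)
train_y = torch.as_tensor(data[:, 1], dtype=torch.float32)

def training_mll(model):
    # Training marginal log-likelihood of a fitted exact-GP model,
    # the quantity compared between Simplex-GPs and Exact-GPs.
    model.train()
    mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
    with torch.no_grad():
        return mll(model(train_x), train_y).item()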

Results

The proposed kernel can be used with GPyTorch as usual. An example script to reproduce results:

python experiments/train_simplexgp.py --dataset=elevators --data-dir=<path/to/uci/data/mat/files>

We use Fire to handle CLI arguments. All arguments of the main function are therefore valid arguments to the CLI.
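As an illustration of the Fire pattern (a hypothetical script, not the experiment code), every keyword parameter of the main function becomes a CLI flag:

import fire

def main(dataset="elevators", data_dir=".", epochs=100):
    # Invoked as: python script.py --dataset=elevators --data-dir=/path
    print(f"dataset={dataset} data_dir={data_dir} epochs={epochs}")

if __name__ == "__main__":
    fire.Fire(main)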

All figures in the paper can be reproduced via notebooks.

NOTE: The UCI dataset mat files are available here.

License

Apache 2.0
