skgpytorch
GPyTorch models in a scikit-learn-style wrapper.
Example
```python
import torch
from skgpytorch.models import ExactGPRegressor
from skgpytorch.metrics import mean_squared_error, negative_log_predictive_density
from gpytorch.kernels import RBFKernel, ScaleKernel

# Generate toy training and test data
train_x = torch.rand(10, 1)
train_y = torch.rand(10)
test_x = torch.rand(10, 1)
test_y = torch.rand(10)

# Define a model: scaled RBF kernel with one lengthscale per input dimension (ARD)
kernel = ScaleKernel(RBFKernel(ard_num_dims=train_x.shape[1]))
gp = ExactGPRegressor(train_x, train_y, kernel, random_state=0, device="cpu")

# Fit the model: 10 iterations per restart, 2 restarts, logging every 2 iterations
gp.fit(n_iters=10, verbose=True, n_restarts=2, verbose_gap=2)

# Get the predictions, either as mean and variance...
# f_mean, f_var = gp.predict(test_x)
# ...or as the full predictive distribution
pred_dist = gp.predict(test_x, dist_only=True)

# Calculate metrics
print("MSE:", mean_squared_error(pred_dist, test_x, test_y))
print("NLPD:", negative_log_predictive_density(pred_dist, test_x, test_y))
```
Output:
```
Restart: 0, Iter: 0, Loss: 1.0135, Best Loss: inf
Restart: 0, Iter: 2, Loss: 0.9371, Best Loss: inf
Restart: 0, Iter: 4, Loss: 0.8644, Best Loss: inf
Restart: 0, Iter: 6, Loss: 0.7978, Best Loss: inf
Restart: 0, Iter: 8, Loss: 0.7382, Best Loss: inf
Restart: 1, Iter: 0, Loss: 0.9626, Best Loss: 0.6819
Restart: 1, Iter: 2, Loss: 0.8948, Best Loss: 0.6819
Restart: 1, Iter: 4, Loss: 0.8239, Best Loss: 0.6819
Restart: 1, Iter: 6, Loss: 0.7537, Best Loss: 0.6819
Restart: 1, Iter: 8, Loss: 0.6880, Best Loss: 0.6819
MSE: 0.08736331760883331
NLPD: 0.49492106437683103
```
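The example reports two metrics. As a rough guide to what they measure, here is a plain-Python sketch of their textbook definitions, computed from per-point predictive means and variances: MSE averages the squared error of the predictive mean, and NLPD averages -log N(y | mu, sigma^2) over the test points. This is not skgpytorch's implementation. The helper names `mse` and `gaussian_nlpd` are hypothetical, and the library may differ in details such as whether observation noise is added to the predictive variance or whether the NLPD is averaged or summed.

```python
import math
from typing import Sequence


def mse(mean: Sequence[float], y: Sequence[float]) -> float:
    """Average squared difference between predictive means and targets."""
    if len(mean) != len(y):
        raise ValueError("mean and y must have the same length")
    return sum((m - t) ** 2 for m, t in zip(mean, y)) / len(y)


def gaussian_nlpd(
    mean: Sequence[float], var: Sequence[float], y: Sequence[float]
) -> float:
    """Average negative log density of y under independent Gaussians N(mean, var)."""
    if not (len(mean) == len(var) == len(y)):
        raise ValueError("mean, var and y must have the same length")
    total = 0.0
    for m, v, t in zip(mean, var, y):
        if v <= 0:
            raise ValueError("predictive variances must be positive")
        # -log N(t | m, v) = 0.5 * log(2 * pi * v) + (t - m)^2 / (2 * v)
        total += 0.5 * math.log(2 * math.pi * v) + (t - m) ** 2 / (2 * v)
    return total / len(y)


# Toy usage with made-up predictions (not output from skgpytorch)
mu = [0.0, 1.0]
sigma2 = [1.0, 1.0]
y_true = [0.0, 2.0]
print("MSE:", mse(mu, y_true))
print("NLPD:", gaussian_nlpd(mu, sigma2, y_true))
```

Lower is better for both. Unlike MSE, NLPD also penalizes predictive variances that are too small (overconfident) or too large for the observed errors.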
Source Distribution
skgpytorch-0.1.3.tar.gz (268.1 kB)
File details
Details for the file skgpytorch-0.1.3.tar.gz.
File metadata
- Download URL: skgpytorch-0.1.3.tar.gz
- Upload date:
- Size: 268.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 8821b1609b3e402481efec8fc1a2c02ef2221241aebe298ae97eecaa33fef01b |
| MD5 | 40bb3d18540f78aa5e67b12f598cc7c4 |
| BLAKE2b-256 | 8c644a1651f950d9f8390c6ac5987f5a250978911830df1c096767cb42d83dc3 |
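To check a downloaded source archive against the SHA256 digest listed in the table, one option is Python's standard `hashlib` module. This is a minimal sketch, assuming the archive has been saved locally as `skgpytorch-0.1.3.tar.gz`. The helper names `sha256_of_file` and `verify_sha256` are illustrative, not part of skgpytorch.

```python
import hashlib
import hmac

# SHA256 digest of skgpytorch-0.1.3.tar.gz, copied from the table above
EXPECTED_SHA256 = "8821b1609b3e402481efec8fc1a2c02ef2221241aebe298ae97eecaa33fef01b"


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 in fixed-size chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read until fh.read() returns b"" (end of file)
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """Return True if the file's SHA-256 digest matches expected_hex (case-insensitive)."""
    return hmac.compare_digest(sha256_of_file(path), expected_hex.lower())
```

For example, `verify_sha256("skgpytorch-0.1.3.tar.gz", EXPECTED_SHA256)` should return `True` for an unmodified download. Reading in chunks keeps memory use constant regardless of the archive size.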