skgpytorch
GPyTorch models in a scikit-learn-style wrapper.
Example
```python
import torch
from gpytorch.kernels import RBFKernel, ScaleKernel
from skgpytorch.models import ExactGPRegressor
from skgpytorch.metrics import mean_squared_error, negative_log_predictive_density

# Generate random toy training and test data
train_x = torch.rand(10, 1)
train_y = torch.rand(10)
test_x = torch.rand(10, 1)
test_y = torch.rand(10)

# Define a model: a scaled RBF kernel with one lengthscale per input dimension
kernel = ScaleKernel(RBFKernel(ard_num_dims=train_x.shape[1]))
gp = ExactGPRegressor(train_x, train_y, kernel, random_state=0, device="cpu")

# Fit the model: 10 iterations per restart, 2 restarts, logging every 2 iterations
gp.fit(n_iters=10, verbose=True, n_restarts=2, verbose_gap=2)

# Get the predictions, either as mean and variance ...
# f_mean, f_var = gp.predict(test_x)
# ... or as a predictive distribution
pred_dist = gp.predict(test_x, dist_only=True)

# Calculate metrics
print("MSE:", mean_squared_error(pred_dist, test_x, test_y))
print("NLPD:", negative_log_predictive_density(pred_dist, test_x, test_y))
```
Training log printed by `gp.fit` with `verbose=True`:

```text
Restart: 0, Iter: 0, Loss: 1.0135, Best Loss: inf
Restart: 0, Iter: 2, Loss: 0.9371, Best Loss: inf
Restart: 0, Iter: 4, Loss: 0.8644, Best Loss: inf
Restart: 0, Iter: 6, Loss: 0.7978, Best Loss: inf
Restart: 0, Iter: 8, Loss: 0.7382, Best Loss: inf
Restart: 1, Iter: 0, Loss: 0.9626, Best Loss: 0.6819
Restart: 1, Iter: 2, Loss: 0.8948, Best Loss: 0.6819
Restart: 1, Iter: 4, Loss: 0.8239, Best Loss: 0.6819
Restart: 1, Iter: 6, Loss: 0.7537, Best Loss: 0.6819
Restart: 1, Iter: 8, Loss: 0.6880, Best Loss: 0.6819
```
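In the training log, `Best Loss` stays at `inf` throughout restart 0 and reads 0.6819 throughout restart 1, which suggests it records the best loss from restarts that have already finished. The sketch below illustrates that restart pattern under stated assumptions: it is a hypothetical helper (`fit_with_restarts`), not skgpytorch's implementation, and it runs plain gradient descent on a toy scalar loss instead of optimizing GP hyperparameters. Only the log format is copied from the output above.

```python
import random


def fit_with_restarts(loss_fn, grad_fn, n_restarts=2, n_iters=10, lr=0.1,
                      verbose=True, verbose_gap=2, seed=0):
    """Hypothetical restart loop: plain gradient descent on a scalar parameter.

    Each restart draws a fresh random starting point; the lowest final loss
    across restarts is kept. Illustration only, not skgpytorch code.
    """
    rng = random.Random(seed)
    best_loss, best_theta = float("inf"), None
    for restart in range(n_restarts):
        theta = rng.uniform(-3.0, 3.0)  # fresh random initialisation per restart
        for it in range(n_iters):
            loss = loss_fn(theta)
            if verbose and it % verbose_gap == 0:
                # Best Loss only changes once a restart has finished
                print(f"Restart: {restart}, Iter: {it}, "
                      f"Loss: {loss:.4f}, Best Loss: {best_loss:.4f}")
            theta -= lr * grad_fn(theta)  # one gradient-descent step
        final_loss = loss_fn(theta)
        if final_loss < best_loss:
            best_loss, best_theta = final_loss, theta
    return best_theta, best_loss


# Usage: minimise the toy loss (theta - 1)^2 with two restarts
theta, loss = fit_with_restarts(lambda t: (t - 1.0) ** 2,
                                lambda t: 2.0 * (t - 1.0))
```

Because each restart starts from a different random point, keeping the best final loss makes the result less sensitive to a poor initialisation; more restarts can only lower (never raise) the best loss found with the same seed.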
Metric values printed at the end:

```text
MSE: 0.08736331760883331
NLPD: 0.49492106437683103
```
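The example scores predictions with skgpytorch's `mean_squared_error` and `negative_log_predictive_density`. As a hedged sketch of what these metrics usually compute for a Gaussian predictive distribution, the snippet below works on plain Python lists of predictive means, variances and targets. The helper names `gaussian_mse` and `gaussian_nlpd` are hypothetical, and the sketch assumes both metrics are averaged over test points; whether skgpytorch averages or sums, and whether its predictive variance includes observation noise, is not stated here.

```python
import math


def gaussian_mse(mean, y):
    """Mean squared error between predictive means and targets."""
    if len(mean) != len(y) or not y:
        raise ValueError("mean and y must be non-empty and the same length")
    return sum((m - t) ** 2 for m, t in zip(mean, y)) / len(y)


def gaussian_nlpd(mean, var, y):
    """Average negative log density of each y_i under N(mean_i, var_i)."""
    if not (len(mean) == len(var) == len(y)) or not y:
        raise ValueError("mean, var and y must be non-empty and the same length")
    total = 0.0
    for m, v, t in zip(mean, var, y):
        if v <= 0:
            raise ValueError("predictive variances must be positive")
        # -log N(t | m, v) = 0.5 * log(2*pi*v) + (t - m)^2 / (2v)
        total += 0.5 * math.log(2.0 * math.pi * v) + (t - m) ** 2 / (2.0 * v)
    return total / len(y)


# Usage with made-up numbers (not the values from the run above)
print(gaussian_mse([0.0, 0.0], [1.0, 3.0]))  # 5.0
print(gaussian_nlpd([0.0, 0.0], [1.0, 1.0], [1.0, 3.0]))
```

Unlike MSE, NLPD also penalises miscalibrated uncertainty: an overconfident (too small) variance makes the squared-error term blow up, while an overly large variance inflates the log term.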
Source Distribution
skgpytorch-0.1.5.tar.gz (268.2 kB)
File metadata
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/32.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.10.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 185baa2a79a101311dd10cba4edd415d3f4af12a1fdf623c87565ed1d14f82c9 |
| MD5 | c962387750ddc83b39d576edb463fa41 |
| BLAKE2b-256 | 5c8ec1c153f5f969b7619ff4b336eeb94b08258d7ff5bfe5b8dcb7f5973a1ff5 |