skgpytorch
GPyTorch models wrapped in a scikit-learn-style interface.
Example
```python
import torch
from skgpytorch.models import ExactGPRegressor
from skgpytorch.metrics import mean_squared_error, negative_log_predictive_density
from gpytorch.kernels import RBFKernel, ScaleKernel

# Generate toy data and define a model
train_x = torch.rand(10, 1)
train_y = torch.rand(10)
test_x = torch.rand(10, 1)
test_y = torch.rand(10)
kernel = ScaleKernel(RBFKernel(ard_num_dims=train_x.shape[1]))
gp = ExactGPRegressor(train_x, train_y, kernel)

# Fit the model (this supports batch training of GP models as well)
# With verbose_gap=2, the loss is printed every 2nd iteration (see the output below)
gp.fit(n_epochs=10, verbose=True, n_restarts=1, verbose_gap=2, batch_size=10, lr=0.1, random_state=0)

# Get the predictions
pred_dist = gp.predict(test_x)

# Access properties of the predictive distribution
pred_mean = pred_dist.mean  # Mean
pred_var = pred_dist.variance  # Variance
pred_stddev = pred_dist.stddev  # Standard deviation
lower, upper = pred_dist.confidence_region()  # mean +/- 2 std (approx. 95% region)

# Calculate metrics (soon this will be implemented in gpytorch itself)
print("MSE:", mean_squared_error(pred_dist, test_y))
print("NLPD:", negative_log_predictive_density(pred_dist, test_y))
```
Output:

```
Restart: 0, Iter: 0, Loss: 1.0135, Best Loss: inf
Restart: 0, Iter: 2, Loss: 0.9371, Best Loss: inf
Restart: 0, Iter: 4, Loss: 0.8644, Best Loss: inf
Restart: 0, Iter: 6, Loss: 0.7978, Best Loss: inf
Restart: 0, Iter: 8, Loss: 0.7382, Best Loss: inf
Restart: 1, Iter: 0, Loss: 0.9626, Best Loss: 0.6819
Restart: 1, Iter: 2, Loss: 0.8948, Best Loss: 0.6819
Restart: 1, Iter: 4, Loss: 0.8239, Best Loss: 0.6819
Restart: 1, Iter: 6, Loss: 0.7537, Best Loss: 0.6819
Restart: 1, Iter: 8, Loss: 0.6880, Best Loss: 0.6819
MSE: 0.08736331760883331
NLPD: 0.49492106437683103
```
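The kernel in the example, `ScaleKernel(RBFKernel(ard_num_dims=...))`, is GPyTorch's RBF kernel with one lengthscale per input dimension (automatic relevance determination), multiplied by an output scale. Written out with output scale $s$, lengthscales $\ell_1, \dots, \ell_D$ and $D$ = `train_x.shape[1]` (here 1), it computes the following; $s$ and the $\ell_d$ are among the hyperparameters that `fit` is expected to learn.

$$
k(\mathbf{x}, \mathbf{x}') = s \, \exp\!\left( -\frac{1}{2} \sum_{d=1}^{D} \frac{(x_d - x'_d)^2}{\ell_d^2} \right)
$$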
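The training log reports, for each restart, the loss every `verbose_gap` iterations next to the best loss found so far. Below is a minimal, hypothetical sketch of a random-restart loop that prints a log in the same format; it is not skgpytorch's implementation. Two assumptions are read off the log rather than the library: `n_restarts=1` yields restarts 0 and 1, and the best loss is only updated once a restart finishes (which is why it stays `inf` throughout restart 0). The names `fit_with_restarts`, `loss_fn` and `grad_fn` are illustrative, and a toy quadratic with plain gradient descent stands in for the GP's training objective.

```python
import math
import random


def fit_with_restarts(loss_fn, grad_fn, n_restarts=1, n_epochs=10, lr=0.1,
                      verbose_gap=2, random_state=0):
    """Hypothetical random-restart gradient descent with skgpytorch-style logging.

    Runs n_restarts + 1 independent optimizations (restart 0 .. n_restarts),
    each from a fresh random initialization, and keeps the parameter with
    the lowest final loss.
    """
    rng = random.Random(random_state)
    best_loss, best_param = math.inf, None
    log = []
    for restart in range(n_restarts + 1):
        param = rng.uniform(-2.0, 2.0)  # fresh random initialization
        for it in range(n_epochs):
            loss = loss_fn(param)
            if it % verbose_gap == 0:
                log.append(f"Restart: {restart}, Iter: {it}, "
                           f"Loss: {loss:.4f}, Best Loss: {best_loss:.4f}")
            param -= lr * grad_fn(param)  # one gradient-descent step
        final_loss = loss_fn(param)
        # The best loss is only updated after a restart has finished
        if final_loss < best_loss:
            best_loss, best_param = final_loss, param
    return best_param, best_loss, log


# Toy objective with its minimum at 1.0
best_param, best_loss, log = fit_with_restarts(
    loss_fn=lambda p: (p - 1.0) ** 2,
    grad_fn=lambda p: 2.0 * (p - 1.0),
)
print("\n".join(log))
```

Formatting `math.inf` with `:.4f` prints `inf`, which reproduces the `Best Loss: inf` entries of the first restart.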
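The example computes MSE and NLPD with `skgpytorch.metrics` and the band with `confidence_region()`. As a hand-rolled sketch, assuming the predictive distribution is summarized by independent Gaussian marginals with means $\mu_i$ and variances $\sigma_i^2$ (the library may instead use the full joint covariance), the three quantities look like this on plain Python lists. The helpers `mse`, `gaussian_nlpd` and `two_sigma_region` are hypothetical stand-ins, not skgpytorch functions; the $\pm 2\sigma$ band matches what GPyTorch's `confidence_region()` returns.

```python
import math


def mse(mean, y):
    """Average squared difference between predictive means and targets."""
    return sum((m - t) ** 2 for m, t in zip(mean, y)) / len(y)


def gaussian_nlpd(mean, var, y):
    """Average negative log density of y under independent N(mean_i, var_i).

    Assumption: only the marginal (diagonal) variances are used, not the
    full joint covariance of the predictive distribution.
    """
    total = 0.0
    for m, v, t in zip(mean, var, y):
        total += 0.5 * math.log(2 * math.pi * v) + (t - m) ** 2 / (2 * v)
    return total / len(y)


def two_sigma_region(mean, var):
    """Mean +/- 2 standard deviations, an approximate 95% band."""
    lower = [m - 2 * math.sqrt(v) for m, v in zip(mean, var)]
    upper = [m + 2 * math.sqrt(v) for m, v in zip(mean, var)]
    return lower, upper


# Hypothetical predictive means/variances and targets for three test points
mean = [0.2, 0.5, 0.9]
var = [0.04, 0.09, 0.01]
y = [0.3, 0.4, 1.0]
print("MSE:", mse(mean, y))
print("NLPD:", gaussian_nlpd(mean, var, y))
print("Approx. 95% region:", two_sigma_region(mean, var))
```

Unlike MSE, NLPD also scores the predicted variances: a miss with a very small variance is penalized far more heavily than the same miss with a wide predictive band.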
Download files
Source Distribution
skgpytorch-0.2.0.tar.gz (8.6 kB)
File details
Details for the file skgpytorch-0.2.0.tar.gz.
File metadata
- Download URL: skgpytorch-0.2.0.tar.gz
- Upload date:
- Size: 8.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/32.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.11.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.10.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1c2d13715786ce18169997afcb040ca239ea0f47b2507706d33603166afa4dcc
MD5 | ea3e29da21a10d837666ff558df19a4e
BLAKE2b-256 | e3f79284f60991b7fb6fdc89b50c924271591a1eda7bc18b865ba1fec0a8c286
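To check a downloaded copy of the archive against the published digests, a minimal sketch using only Python's standard `hashlib` could look like the following. The helper names `digest_chunks` and `verify_file` are illustrative, and the local file path is whatever the download was saved as; BLAKE2b-256 here means BLAKE2b with a 32-byte digest.

```python
import hashlib

# Published digests for skgpytorch-0.2.0.tar.gz, copied from the table above
EXPECTED = {
    "SHA256": "1c2d13715786ce18169997afcb040ca239ea0f47b2507706d33603166afa4dcc",
    "MD5": "ea3e29da21a10d837666ff558df19a4e",
    "BLAKE2b-256": "e3f79284f60991b7fb6fdc89b50c924271591a1eda7bc18b865ba1fec0a8c286",
}


def digest_chunks(chunks):
    """Hash an iterable of byte chunks with all three algorithms in one pass."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        # BLAKE2b-256: BLAKE2b with a 32-byte (256-bit) digest
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),
    }
    for chunk in chunks:
        for h in hashers.values():
            h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify_file(path, expected=EXPECTED, chunk_size=1 << 16):
    """Return {algorithm: bool} comparing a local file with the expected digests."""
    with open(path, "rb") as f:
        # Read the file in fixed-size chunks so large files need not fit in memory
        actual = digest_chunks(iter(lambda: f.read(chunk_size), b""))
    return {name: actual[name] == digest for name, digest in expected.items()}
```

For an intact download, `verify_file("skgpytorch-0.2.0.tar.gz")` should report `True` for all three algorithms.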