
A Python library for uncertainty estimation in supervised learning tasks


UncertaintyPlayground


Installation

Requirements:

  • Python >= 3.8
  • PyTorch == 2.0.1
  • GPyTorch == 1.10
  • NumPy == 1.24.0
  • Seaborn == 0.12.2
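Because the requirements above are exact pins, it can help to check an existing environment before installing. The following is a minimal sketch, not part of UncertaintyPlayground. It assumes the PyPI distribution names `torch`, `gpytorch`, `numpy` and `seaborn`, and `check_pins` is a hypothetical helper. It relies only on `importlib.metadata`, which is in the standard library from Python 3.8.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the requirements list; the PyPI distribution names
# ("torch", "gpytorch", ...) are assumptions about how each package is published.
PINNED = {
    "torch": "2.0.1",
    "gpytorch": "1.10",
    "numpy": "1.24.0",
    "seaborn": "0.12.2",
}


def check_pins(pins, lookup=version):
    """Return {name: (expected, installed)} for every pin that is not satisfied.

    `installed` is None when the package is missing. A local build tag such as
    "+cu117" on PyTorch wheels is ignored when comparing.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None
        base = installed.split("+", 1)[0] if installed else None
        if base != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in check_pins(PINNED).items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
```

The `lookup` parameter defaults to `importlib.metadata.version`. It can be replaced by a stub to test the comparison logic without touching the real environment.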

Use PyPI to install the package:

pip install uncertaintyplayground

or, alternatively, to use the development version, install directly from GitHub:

pip install git+https://github.com/unco3892/UncertaintyPlayground.git

Usage

You can train the models and visualize their results as follows (this example uses the California Housing dataset from scikit-learn):

from uncertaintyplayground.trainers.svgp_trainer import SparseGPTrainer
from uncertaintyplayground.trainers.mdn_trainer import MDNTrainer
from uncertaintyplayground.predplot.svgp_predplot import compare_distributions_svgpr
from uncertaintyplayground.predplot.mdn_predplot import compare_distributions_mdn
from uncertaintyplayground.predplot.grid_predplot import plot_results_grid
import torch
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Load the California Housing dataset
california = fetch_california_housing()

# Convert X and y to numpy arrays of float32
X = np.array(california.data, dtype=np.float32)
y = np.array(california.target, dtype=np.float32)

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(1)

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# SVGPR: Initialize and train an SVGPR model with 100 inducing points
california_trainer_svgp = SparseGPTrainer(X_train, y_train, num_inducing_points=100, num_epochs=30, batch_size=512, lr=0.1, patience=3)
california_trainer_svgp.train()

# MDN: Initialize and train an MDN model
california_trainer_mdn = MDNTrainer(X_train, y_train, num_epochs=100, lr=0.001, dense1_units=50, n_gaussians=10)
california_trainer_mdn.train()

# SVGPR: Visualize the SVGPR's predictions for multiple instances
plot_results_grid(trainer=california_trainer_svgp, compare_func=compare_distributions_svgpr, X_test=X_test, Y_test=y_test, indices=[900, 500], ncols=2)

# MDN: Visualize the MDN's predictions for multiple instances
plot_results_grid(trainer=california_trainer_mdn, compare_func=compare_distributions_mdn, X_test=X_test, Y_test=y_test, indices=[900, 500], ncols=2)
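The `patience=3` argument passed to `SparseGPTrainer` suggests patience-based early stopping. The sketch below illustrates what such a parameter commonly means. It is an assumption, not the library's code: the `EarlyStopping` class and `stopping_epoch` helper are hypothetical, and the sketch assumes the monitored quantity is a loss where lower is better.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop once the monitored loss has not
    improved by more than `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss and return True if training should stop."""
        if loss < self.best - self.min_delta:
            # Improvement: remember the new best and reset the counter.
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def stopping_epoch(losses, patience=3):
    """Return the index of the epoch at which training would stop, or None."""
    stopper = EarlyStopping(patience=patience)
    for epoch, loss in enumerate(losses):
        if stopper.step(loss):
            return epoch
    return None


# The loss stops improving after epoch 2, so with patience=3 training
# stops at epoch 5.
print(stopping_epoch([1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.6]))  # prints 5
```

A small patience such as 3 keeps training short when the loss plateaus, at the risk of stopping before a late improvement, like the 0.6 at epoch 6 above.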

You can find another example for MDN in the examples folder.
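An MDN, such as the one trained above with `n_gaussians=10`, represents the predictive distribution of y given x as a weighted sum of Gaussians. The sketch below shows one way to summarize such a mixture into a mean and a variance. It assumes you already have the mixture weights `pi` (summing to 1), component means `mu` and standard deviations `sigma` as NumPy arrays. How to extract them from `MDNTrainer` is not shown, and `mixture_moments` is a hypothetical helper, not part of the library.

```python
import numpy as np


def mixture_moments(pi, mu, sigma):
    """Mean and variance of a one-dimensional Gaussian mixture.

    pi, mu, sigma: arrays of shape (n_samples, n_gaussians) holding the mixture
    weights (assumed to sum to 1 along the last axis), component means and
    component standard deviations.
    """
    pi = np.asarray(pi, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # Mixture mean: weighted average of the component means.
    mean = np.sum(pi * mu, axis=-1, keepdims=True)
    # Law of total variance: average within-component variance plus the
    # spread of the component means around the mixture mean.
    var = np.sum(pi * (sigma**2 + (mu - mean) ** 2), axis=-1)
    return np.squeeze(mean, axis=-1), var


# Two equally weighted components at 0 and 2, each with unit std:
# the mean is [1.0] and the variance is [2.0].
mean, var = mixture_moments([[0.5, 0.5]], [[0.0, 2.0]], [[1.0, 1.0]])
```

The variance hides the shape of the mixture: a bimodal prediction and a wide unimodal one can share the same mean and variance, which is why plotting the full density is often more informative.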

Contributors

This library is maintained by Ilia Azizi (University of Lausanne). Other contributors are welcome to join! Feel free to get in touch (contact links are on my website).

Citation

If you use this package in your research, please cite our work:

UncertaintyPlayground: A Fast and Simplified Python Library for Uncertainty Estimation, Ilia Azizi, arXiv:2310.15281

@misc{azizi2023uncertaintyplayground,
      title={UncertaintyPlayground: A Fast and Simplified Python Library for Uncertainty Estimation}, 
      author={Ilia Azizi},
      year={2023},
      eprint={2310.15281},
      archivePrefix={arXiv},
      primaryClass={stat.ML}
}

License

This project is licensed under the MIT License.

Download files


Source Distribution

UncertaintyPlayground-0.1.2.tar.gz (16.5 kB)

Built Distribution


UncertaintyPlayground-0.1.2-py3-none-any.whl (26.4 kB)

File details

Details for the file UncertaintyPlayground-0.1.2.tar.gz.

File metadata

  • Download URL: UncertaintyPlayground-0.1.2.tar.gz
  • Size: 16.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.12

File hashes

Hashes for UncertaintyPlayground-0.1.2.tar.gz
  • SHA256: 78dbd102e83d3e8a37c4e392a5bbbf2516bfa1113aa6bc88288b5f6ae801882d
  • MD5: a6255adf049c73b35acf9b71d9c7cdba
  • BLAKE2b-256: 3aec1797423e5eb491cdc002620b63f6506f8518331bde726c43912f19077e6e
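A downloaded file can be checked against the published SHA256 digest with Python's standard `hashlib` module. The sketch below streams the file in chunks so large downloads are not loaded into memory at once. `sha256_of` and `verify` are hypothetical helpers, and the commented call assumes the source distribution sits in the current directory. The same check works for the wheel's digest. pip can also enforce digests itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib

# SHA256 digest published for the source distribution.
EXPECTED_SDIST_SHA256 = "78dbd102e83d3e8a37c4e392a5bbbf2516bfa1113aa6bc88288b5f6ae801882d"


def sha256_of(path, chunk_size=1 << 16):
    """Read a file in chunks and return its SHA256 hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected_hex.strip().lower()


# Example, assuming the sdist was downloaded to the current directory:
# verify("UncertaintyPlayground-0.1.2.tar.gz", EXPECTED_SDIST_SHA256)
```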


File details

Details for the file UncertaintyPlayground-0.1.2-py3-none-any.whl.


File hashes

Hashes for UncertaintyPlayground-0.1.2-py3-none-any.whl
  • SHA256: 8d07e365c9806af72dc1bfd9e5e536230cb7804a8b0e3ab405db5bb3e4f3513b
  • MD5: 65eb539bcb74f91e42c902bfe8bd1bc0
  • BLAKE2b-256: a230ba47df7a21f0aa7d930e6163a8c7fc0173753f847091997b32aeda992612

