W-DBO Algorithm for Dynamic Bayesian Optimization
A Wasserstein Distance-Based Dynamic Bayesian Optimization Algorithm (W-DBO)
TL;DR
W-DBO is a Dynamic Bayesian Optimization (DBO) algorithm. It is well suited to optimizing dynamic (time-varying) black-box functions.
In fact, W-DBO is the first DBO algorithm able to simultaneously (i) capture complex spatio-temporal dynamics and (ii) remove stale and/or irrelevant observations from its dataset. As a consequence, it thrives in any DBO task, even those with infinite time horizons, providing at the same time good performance and a high sampling frequency.
W-DBO is authored by Anthony Bardou, Patrick Thiran and Giovanni Ranieri. The code in this repository is based on BoTorch [1] and Eigen [2].
Citing this Work
If you use this code or would like to reference this work, please cite the paper [3] using the following BibTeX entry:
@article{bardou2024too,
  title={This Too Shall Pass: Removing Stale Observations in Dynamic Bayesian Optimization},
  author={Bardou, Anthony and Thiran, Patrick and Ranieri, Giovanni},
  journal={arXiv preprint arXiv:2405.14540},
  year={2024}
}
Installation
W-DBO has its own PyPI package. To install it, open your favorite command line and run
pip install wdbo-algo
Alternatively, you can download the code from this repository.
Quick Start
In this section, we provide all the relevant details to use W-DBO and a minimal working example.
W-DBO can be used by instantiating the wdbo_algo.optimizer.WDBOOptimizer class as follows:
from wdbo_algo.optimizer import WDBOOptimizer
import gpytorch

wdbo_optimizer = WDBOOptimizer(
    spatial_domain,  # Bounds of the spatial domain, array of shape (d, 2)
    spatial_kernel, temporal_kernel,  # gpytorch.kernels classes capturing space and time dynamics
    spatial_kernel_args=[],  # Optional arguments to pass when instantiating the gpytorch spatial kernel
    temporal_kernel_args=[],  # Optional arguments to pass when instantiating the gpytorch temporal kernel
    n_initial_observations=15,  # Number of observations to collect before removing irrelevant observations
    alpha=0.25,  # Hyperparameter controlling the removal of irrelevant observations
)
We describe the arguments in more detail below. Let $f : \mathcal{S} \times \mathcal{T} \to \mathbb{R}$, where $\mathcal{S} \subseteq \mathbb{R}^d$ is the $d$-dimensional space domain and $\mathcal{T} \subseteq \mathbb{R}$ is the time domain. Then,
- spatial_domain describes the space domain $\mathcal{S}$. More precisely, it is an array of shape $(d, 2)$ describing a $d$-dimensional hyperrectangle. The infimum and supremum for the $i$th dimension are in spatial_domain[i-1, 0] and spatial_domain[i-1, 1], respectively.
- spatial_kernel is a class from gpytorch.kernels describing the correlation of function values in the spatial domain. Right now, only the two most popular kernels are supported, namely gpytorch.kernels.RBFKernel (the squared-exponential covariance function) and gpytorch.kernels.MaternKernel (the Matérn covariance function).
- temporal_kernel is a class from gpytorch.kernels describing the correlation of function values in the temporal domain. Right now, only the two most popular kernels are supported, namely gpytorch.kernels.RBFKernel (the squared-exponential covariance function) and gpytorch.kernels.MaternKernel (the Matérn covariance function).
- spatial_kernel_args can be used to pass arguments to the constructor of the gpytorch.kernels class. It is ignored for spatial_kernel=gpytorch.kernels.RBFKernel, but must contain the nu parameter (half-integer only) for spatial_kernel=gpytorch.kernels.MaternKernel.
- temporal_kernel_args can be used to pass arguments to the constructor of the gpytorch.kernels class. It is ignored for temporal_kernel=gpytorch.kernels.RBFKernel, but must contain the nu parameter (half-integer only) for temporal_kernel=gpytorch.kernels.MaternKernel.
- n_initial_observations is the number of observations to collect before beginning the removal of irrelevant observations. This number must be large enough so that the hyperparameters of the covariance functions can be decently approximated by MLE.
- alpha is the hyperparameter controlling the removal of irrelevant observations. The larger $\alpha$, the more observations are removed. The recommended value in [3] is alpha=0.25.
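As a rough illustration of the argument formats described above, the following sketch builds a spatial_domain array of shape $(d, 2)$ and checks a Matérn nu value before it is passed as a kernel argument. The helpers make_spatial_domain and check_matern_nu are hypothetical and not part of wdbo_algo; the sketch assumes, as in the example further below, that the argument list holds nu as its only element, and it relies on gpytorch.kernels.MaternKernel accepting only nu in {0.5, 1.5, 2.5}.

```python
import numpy as np

# Hypothetical helpers for illustration only; they are not part of wdbo_algo.
SUPPORTED_MATERN_NU = (0.5, 1.5, 2.5)  # values accepted by gpytorch.kernels.MaternKernel


def make_spatial_domain(lower, upper):
    """Stack per-dimension bounds into an array of shape (d, 2)."""
    domain = np.column_stack([lower, upper]).astype(float)
    if np.any(domain[:, 0] >= domain[:, 1]):
        raise ValueError("each lower bound must be strictly below its upper bound")
    return domain


def check_matern_nu(kernel_args):
    """Return nu if the argument list holds exactly one supported value."""
    if len(kernel_args) != 1 or kernel_args[0] not in SUPPORTED_MATERN_NU:
        raise ValueError(f"nu must be one of {SUPPORTED_MATERN_NU}")
    return kernel_args[0]


# Two-dimensional hyperrectangle [-1, 1.5] x [-1, 1.5]
spatial_domain = make_spatial_domain([-1.0, -1.0], [1.5, 1.5])
print(spatial_domain.shape)    # (2, 2)
print(check_matern_nu([2.5]))  # 2.5
```

Row i-1 of the resulting array holds the infimum and supremum of the $i$th dimension, which matches the indexing convention given for spatial_domain.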
The wdbo_algo.optimizer.WDBOOptimizer class has three important methods:
- next_query(t) searches for an input to query at time t by trading off exploration and exploitation of the objective function,
- tell(x, t, y) adds a new observation $((\bm x, t), y)$ to the dataset, where $\bm x \in \mathcal{S}$, $t \in \mathcal{T}$, $y = f(\bm x, t) + \epsilon$ and $\epsilon \sim \mathcal{N}(0, \sigma^2)$,
- clean(t) removes irrelevant observations from the dataset.
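To make the calling order of these three methods concrete, here is a minimal sketch that drives them on a simulated clock. RandomSearchStub is a hypothetical stand-in that only mirrors the method names: it queries uniformly at random and drops observations older than a fixed window, which is a crude placeholder and not the relevancy-based removal performed by W-DBO. The run_loop helper and the toy objective are likewise assumptions made for illustration.

```python
import random


class RandomSearchStub:
    """Stand-in exposing the same three method names as WDBOOptimizer.

    This is NOT W-DBO: queries are uniform random and clean() simply drops
    observations older than a fixed window.
    """

    def __init__(self, spatial_domain, window=1.0, seed=0):
        self.spatial_domain = spatial_domain  # list of (low, high) pairs
        self.window = window                  # maximum age of kept observations
        self.dataset = []                     # list of (x, t, y) tuples
        self._rng = random.Random(seed)

    def next_query(self, t):
        return [self._rng.uniform(lo, hi) for lo, hi in self.spatial_domain]

    def tell(self, x, t, y):
        self.dataset.append((list(x), t, y))

    def clean(self, t):
        self.dataset = [obs for obs in self.dataset if t - obs[1] <= self.window]


def run_loop(optimizer, objective, start_time, end_time, step):
    """Drive the query / tell / clean cycle on a simulated clock."""
    t = start_time
    while t < end_time:
        x = optimizer.next_query(t)
        optimizer.tell(x, t, objective(x, t))
        optimizer.clean(t)
        t += step
    return optimizer


stub = run_loop(
    RandomSearchStub([(-1.0, 1.5), (-1.0, 1.5)], window=1.0),
    objective=lambda x, t: -sum((xi - t / 4.0) ** 2 for xi in x),
    start_time=0.0, end_time=4.0, step=0.25,
)
print(len(stub.dataset))  # 5 observations remain within the window
```

A stub of this kind can be handy for testing the surrounding experiment code on a machine where the real optimizer is not installed.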
Putting it all together, here is a minimal working example:
import numpy as np
from time import sleep, time
import gpytorch
from wdbo_algo.optimizer import WDBOOptimizer


# The expensive, noisy objective function
def my_noisy_objective_function(x, t):
    sleep(0.1)
    x = list(x) + [5.0 * t / 8.0 - 1.0]
    print(x)
    f = 0.0
    for i in range(2):
        f += 100 * ((x[i + 1] - x[i] ** 2) ** 2) + (1 - x[i]) ** 2
    return -f + np.random.normal(0.0, np.sqrt(0.1))


if __name__ == "__main__":
    # Spatial domain
    spatial_domain = np.array([[-1.0, 1.5] for _ in range(2)])

    # Optimizer instantiation
    optimizer = WDBOOptimizer(
        spatial_domain,
        gpytorch.kernels.MaternKernel, gpytorch.kernels.MaternKernel,
        spatial_kernel_args=[2.5], temporal_kernel_args=[2.5],
        n_initial_observations=15
    )

    # Start time
    start_time = 0.0
    # Experiment duration (minutes)
    xp_duration = 4.0
    # End time
    end_time = start_time + xp_duration

    # Optimization loop
    current_time = start_time
    # As long as the experiment is not over
    while current_time < end_time:
        start = time()
        # Find a relevant input to query
        x = optimizer.next_query(current_time)
        # Query the objective function
        y = my_noisy_objective_function(x, current_time)
        # Add the new observation to the dataset
        optimizer.tell(x, current_time, y)
        # Clean the dataset from irrelevant observations
        optimizer.clean(current_time)
        end = time()
        current_time += (end - start) / 60.0
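The objective in this example is a negated three-dimensional Rosenbrock function whose third coordinate is driven by time through the map $t \mapsto 5t/8 - 1$, so that the 4-minute experiment sweeps that coordinate across $[-1, 1.5]$, the same interval as the spatial bounds. The sketch below, using the hypothetical helpers time_coordinate and noiseless_objective (not part of wdbo_algo), strips the noise and the artificial delay to make this easy to check by hand.

```python
def time_coordinate(t):
    """Map the time t (in minutes) to the third Rosenbrock coordinate."""
    return 5.0 * t / 8.0 - 1.0


def noiseless_objective(x, t):
    """Noise-free, delay-free version of the example objective."""
    z = list(x) + [time_coordinate(t)]
    rosenbrock = sum(
        100.0 * (z[i + 1] - z[i] ** 2) ** 2 + (1.0 - z[i]) ** 2 for i in range(2)
    )
    return -rosenbrock


# The time coordinate spans the same interval as the spatial bounds
print(time_coordinate(0.0), time_coordinate(4.0))  # -1.0 1.5

# The Rosenbrock optimum (1, 1, 1) is reached when 5t/8 - 1 = 1, i.e. t = 3.2
print(noiseless_objective([1.0, 1.0], 3.2))
```

At t = 3.2 the noise-free value at x = (1, 1) is zero up to floating-point error, while at t = 0 the point x = (0, 0) evaluates to -102, which shows how the same input can change value as time passes.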
Troubleshooting
Unavailability on macOS
W-DBO relies on C++ code bound to Python (for more details, please see the dedicated GitHub repository).
To compute the relevancy of observations, W-DBO makes heavy use of Bessel functions, which are implemented neither in the standard library shipped with Apple Clang nor in the one shipped with Clang++. Consequently, the C++ bindings are not available on macOS. You can still run this code within a Docker container or a virtual machine.
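A script that may be launched on several platforms can guard against this case before importing the package. The following is only a sketch: wdbo_bindings_expected is a hypothetical helper, and it relies solely on the standard sys.platform value, which is "darwin" on macOS.

```python
import sys


def wdbo_bindings_expected(platform_name=sys.platform):
    """Return False on macOS, where the C++ bindings are reported unavailable."""
    return platform_name != "darwin"


if not wdbo_bindings_expected():
    print("macOS detected: run W-DBO inside a Docker container or a virtual machine.")
```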
References
[1] Balandat, M., Karrer, B., Jiang, D., Daulton, S., Letham, B., Wilson, A. G., & Bakshy, E. (2020). BoTorch: A framework for efficient Monte-Carlo Bayesian optimization. Advances in neural information processing systems, 33, 21524-21538.
[2] G. Guennebaud, B. Jacob, et al. Eigen. http://eigen.tuxfamily.org 3.1 (2010)
[3] Bardou, A., Thiran, P., & Ranieri, G. (2024). This Too Shall Pass: Removing Stale Observations in Dynamic Bayesian Optimization. arXiv preprint arXiv:2405.14540.