A BoTorch wrapper for solving multi-objective optimization problems, with an implementation of the qPOTS algorithm.
qPOTS: Batch Pareto Optimal Thompson Sampling
This repository contains the code for qPOTS, a batch multi-objective Bayesian optimization algorithm. The accompanying paper is available on arXiv.
This repository is maintained by the Computational Complex Engineered Systems Design Laboratory (CSDL) at Penn State.
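"Pareto optimal" in the algorithm's name refers to the set of non-dominated trade-offs between objectives. The sketch below illustrates only that concept and is not code from the qPOTS package. It filters a list of objective vectors down to its Pareto-optimal subset and assumes every objective is maximized, which is BoTorch's convention. The helper names `dominates` and `pareto_front` are hypothetical. For tensors, BoTorch already provides `botorch.utils.multi_objective.pareto.is_non_dominated`.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def dominates(a: Sequence[float], b: Sequence[float]) -> bool:
    """Return True if `a` Pareto-dominates `b` when every objective is maximized:
    `a` is no worse than `b` in every objective and strictly better in at least one."""
    return all(ai >= bi for ai, bi in zip(a, b)) and any(ai > bi for ai, bi in zip(a, b))


def pareto_front(points: Sequence[Point]) -> List[Point]:
    """Keep the points that no other point dominates (O(n^2) pairwise check)."""
    # A point never dominates itself (no strict improvement), so no self-check is needed.
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Four objective vectors; (1.0, 1.0) is dominated by (2.0, 2.0).
ys = [(3.0, 1.0), (2.0, 2.0), (1.0, 3.0), (1.0, 1.0)]
print(pareto_front(ys))  # [(3.0, 1.0), (2.0, 2.0), (1.0, 3.0)]
```

The quadratic pairwise check keeps the sketch short. It is adequate for the small candidate sets in a demo, but not for large populations.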
Installing qPOTS
To install qPOTS with pip, run the following command in a terminal:
pip install qPOTS
This installs all required dependencies except the MATLAB Engine, which is needed only for TS-EMO. To install the MATLAB Engine, follow MathWorks' "Install MATLAB Engine for Python" instructions.
Note: The MATLAB Engine requires a local MATLAB installation and must be installed for Python>=3.10, using the engine version that matches your MATLAB release. The other acquisition functions (including qPOTS) are implemented in BoTorch and require only Python>=3.10 and the dependencies that pip installs automatically.
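Because the MATLAB Engine is optional, a script can check for it before selecting TS-EMO. The following sketch uses only the Python standard library. The helper names `python_version_ok` and `matlab_engine_available` are hypothetical and are not part of qPOTS.

```python
import importlib.util
import sys


def python_version_ok(minimum=(3, 10)) -> bool:
    """True if the running interpreter meets the Python>=3.10 requirement."""
    return sys.version_info[:2] >= minimum


def matlab_engine_available() -> bool:
    """True if `matlab.engine` appears importable (needed only for TS-EMO)."""
    # Check the top-level `matlab` package first: find_spec("matlab.engine")
    # imports the parent package and raises if `matlab` itself is missing.
    if importlib.util.find_spec("matlab") is None:
        return False
    try:
        return importlib.util.find_spec("matlab.engine") is not None
    except (ImportError, ValueError):
        return False


print("Python >= 3.10:", python_version_ok())
print("MATLAB Engine available (TS-EMO only):", matlab_engine_available())
```

`find_spec` only locates the module and does not start MATLAB. A `True` result therefore means the engine package is installed, not that a MATLAB session can be launched.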
To build from source, clone the repository and run pip in the top-level directory:
git clone https://github.com/csdlpsu/qpots
cd qpots
pip install .
Quick Start
The script below is a quick demonstration of qPOTS on the two-objective Branin-Currin test problem. You can also run it to test your qPOTS installation.
For more thorough examples of how to use qPOTS, see the examples/ directory.
import time
import warnings

import torch
from botorch.utils.transforms import unnormalize

from qpots.acquisition import Acquisition
from qpots.model_object import ModelObject
from qpots.function import Function
from qpots.utils.utils import expected_hypervolume

warnings.filterwarnings('ignore')  # silence library warnings for a cleaner demo
device = torch.device("cpu")
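# Experiment settings, also passed to the qPOTS acquisition step
args = {
    "ntrain": 20,  # number of initial training points
    "iters": 50,  # number of optimization iterations
    "reps": 20,
    "q": 1,  # batch size (candidates per iteration)
    "wd": ".",
    "ref_point": torch.tensor([-300.0, -18.0]),  # reference point for the hypervolume
    "dim": 2,  # input dimension
    "nobj": 2,  # number of objectives
    "ncons": 0,  # number of constraints
    "nystrom": 0,
    "nychoice": "pareto",
    "ngen": 10,
}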
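# Two-objective Branin-Currin benchmark problem
tf = Function('branincurrin', dim=args["dim"], nobj=args["nobj"])
f = tf.evaluate
bounds = tf.get_bounds()

# Draw the initial design in the unit cube and map it to the problem bounds for evaluation
torch.manual_seed(1023)
train_x = torch.rand([args["ntrain"], args["dim"]], dtype=torch.float64)
train_y = f(unnormalize(train_x, bounds))

# Build and fit the GP surrogate models, then create the acquisition object
gps = ModelObject(train_x=train_x, train_y=train_y, bounds=bounds, nobj=args["nobj"], ncons=0, device=device)
gps.fit_gp()
acq = Acquisition(tf, gps, device=device, q=args["q"])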
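for i in range(args["iters"]):
    # Select new candidate(s) with qPOTS and time the acquisition step
    t1 = time.time()
    newx = acq.qpots(bounds, i, **args)
    t2 = time.time()
    # Evaluate the true objectives at the new candidate(s)
    newy = f(unnormalize(newx.reshape(-1, args["dim"]), bounds))
    hv, _ = expected_hypervolume(gps, ref_point=args['ref_point'])
    print(f"Iteration: {i}, New candidate: {newx}, Time: {t2 - t1}, HV: {hv}")
    # Append the new data, refit the models, and rebuild the acquisition object on them
    train_x = torch.row_stack([train_x, newx.view(-1, args["dim"])])
    train_y = torch.row_stack([train_y, newy])
    gps = ModelObject(train_x, train_y, bounds, args["nobj"], args["ncons"], device=device)
    gps.fit_gp()
    acq = Acquisition(tf, gps, device=device, q=args["q"])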
At each iteration, the script prints the iteration number, the new candidate, the acquisition time, and the hypervolume (HV) to the terminal. If it runs to completion without errors, qPOTS is installed correctly.
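The HV value printed at each iteration is a hypervolume measured against `ref_point`. It is the size of the region of objective space that the Pareto front dominates and that is bounded below by the reference point. BoTorch maximizes objectives, so the reference point should lie below the values of interest in every objective.

The sketch below computes the exact hypervolume for two objectives with a simple sweep. It is a generic illustration, not the package's `expected_hypervolume` helper, whose internals are not described here. The name `hypervolume_2d` is hypothetical. For tensors and more objectives, BoTorch provides `botorch.utils.multi_objective.hypervolume.Hypervolume`.

```python
from typing import Sequence, Tuple


def hypervolume_2d(points: Sequence[Tuple[float, float]], ref: Tuple[float, float]) -> float:
    """Area dominated by `points` and bounded below by `ref` (both objectives maximized)."""
    r1, r2 = ref
    # Only points that strictly dominate the reference point contribute area.
    pts = [(y1, y2) for y1, y2 in points if y1 > r1 and y2 > r2]
    # Sweep from the largest first objective to the smallest.
    pts.sort(reverse=True)
    hv, best_y2 = 0.0, r2
    for y1, y2 in pts:
        if y2 > best_y2:
            # Add the slab between the previous best second objective and this one;
            # its width runs from the reference point out to this point's first objective.
            hv += (y1 - r1) * (y2 - best_y2)
            best_y2 = y2
    return hv


# Staircase front; the dominated point (1.0, 1.0) adds nothing.
print(hypervolume_2d([(3.0, 1.0), (2.0, 2.0), (1.0, 3.0), (1.0, 1.0)], ref=(0.0, 0.0)))  # 6.0
```

The sweep skips dominated points automatically, so the input does not need to be filtered to the Pareto front first. In two dimensions, the cost is dominated by the sort.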
Download files
File details
Details for the file qpots-1.0.3.tar.gz.
File metadata
- Download URL: qpots-1.0.3.tar.gz
- Size: 55.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 784aacf42693ff5844f7ac5ea00e5107662208f8debf57147b961710e8fa4f1b |
| MD5 | 758898bf849ad7933839982c0e65bd9a |
| BLAKE2b-256 | 9e709e90780c58f387d766b1d714d0bc33ec2c64dcfa1589d420876881719d1a |
File details
Details for the file qPOTS-1.0.3-py3-none-any.whl.
File metadata
- Download URL: qPOTS-1.0.3-py3-none-any.whl
- Size: 39.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7d5169c75582c3b7d7955af74ffc252f12676fb09fa7be2a1839d450a031180a |
| MD5 | d5146d2706fadcc10d4ebbe47837d320 |
| BLAKE2b-256 | 79acb2c40e79a66d864383b9a13d2b5a8abc084cad73f73b5dd90fc5433da5d5 |
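The published digests can be used to check that a downloaded file is intact. The following is a minimal standard-library sketch that computes all three listed digests in one pass over the file. BLAKE2b-256 corresponds to `hashlib.blake2b` with a 32-byte digest. The helper names `file_digests` and `verify_sha256` are hypothetical.

```python
import hashlib
from pathlib import Path
from typing import Dict


def file_digests(path, chunk_size: int = 1 << 16) -> Dict[str, str]:
    """Compute SHA256, MD5, and BLAKE2b-256 hex digests of a file in one pass."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),  # 32 bytes = 256 bits
    }
    with Path(path).open("rb") as fh:
        # Read in fixed-size chunks so large files need not fit in memory.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify_sha256(path, expected: str) -> bool:
    """True if the file's SHA256 digest matches the published value (case-insensitive)."""
    return file_digests(path)["SHA256"] == expected.strip().lower()
```

For example, `verify_sha256("qpots-1.0.3.tar.gz", "784aacf42693ff5844f7ac5ea00e5107662208f8debf57147b961710e8fa4f1b")` should return `True` for an intact download of the source distribution. pip's hash-checking mode (`--require-hashes`) performs the same SHA256 check at install time.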