ParetoFlow
ParetoFlow is a Python package for offline multi-objective optimization using Generative Flow Models with Multi Predictors Guidance to approximate the Pareto front.
Installation
conda create -n paretoflow python=3.10
conda activate paretoflow
pip install paretoflow
Usage
We accept .npy files for the input features and labels. Continuous features have shape (n_samples, n_dim), and discrete features have shape (n_samples, seq_len).
The labels are the objective values, with shape (n_samples, n_obj).
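As a minimal sketch of these shape conventions, the snippet below writes toy arrays to x.npy and y.npy and checks them after loading. The random data, the sizes, and the helper check_shapes are illustrative assumptions and are not part of the paretoflow package.

```python
import os
import tempfile

import numpy as np


def check_shapes(x, y):
    """Hypothetical helper: verify x and y are 2-D with matching sample counts."""
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError("x and y must both be 2-D arrays")
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same number of samples")
    return x.shape[0], x.shape[1], y.shape[1]


rng = np.random.default_rng(0)
x_cont = rng.random((128, 30))                # continuous: (n_samples, n_dim)
x_disc = rng.integers(0, 4, size=(128, 12))   # discrete: (n_samples, seq_len)
y = rng.random((128, 2))                      # labels: (n_samples, n_obj)

# Save and reload the continuous features and the labels as .npy files
out_dir = tempfile.mkdtemp()
x_path = os.path.join(out_dir, "x.npy")
y_path = os.path.join(out_dir, "y.npy")
np.save(x_path, x_cont)
np.save(y_path, y)

all_x = np.load(x_path)
all_y = np.load(y_path)
n_samples, n_dim, n_obj = check_shapes(all_x, all_y)
print(n_samples, n_dim, n_obj)
```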
Discrete features must be converted to continuous logits, as described in the ParetoFlow paper. The implementation follows design-bench.
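The conversion can be sketched roughly as follows, in the spirit of the logit encoding used by design-bench: one-hot encode, mix with a uniform prior, take logs, and express each class relative to the first one. The function names, the soft_interpolation value of 0.6, and the final flattening step are assumptions made for illustration; the code that paretoflow actually runs may differ.

```python
import numpy as np


def to_logits(x_int, num_classes, soft_interpolation=0.6):
    """Map integer tokens (n_samples, seq_len) to logits
    of shape (n_samples, seq_len, num_classes - 1)."""
    x_int = np.asarray(x_int)
    one_hot = np.eye(num_classes, dtype=np.float64)[x_int]
    uniform = np.full_like(one_hot, 1.0 / num_classes)
    # Soften the one-hot vectors so that the log is finite everywhere
    soft = soft_interpolation * one_hot + (1.0 - soft_interpolation) * uniform
    log_p = np.log(soft)
    # Express every class relative to class 0, which drops one dimension
    return log_p[..., 1:] - log_p[..., :1]


def to_integers(logits):
    """Invert to_logits by restoring the implicit zero logit of class 0."""
    zeros = np.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
    return np.argmax(np.concatenate([zeros, logits], axis=-1), axis=-1)


x_disc = np.array([[0, 2, 1], [3, 3, 0]])   # (n_samples=2, seq_len=3)
logits = to_logits(x_disc, num_classes=4)   # (2, 3, 3)
x_flat = logits.reshape(len(x_disc), -1)    # (2, 9), one row per design
x_back = to_integers(logits)                # recovers x_disc
```

Flattening to two dimensions keeps the encoded designs in the same (n_samples, n_dim) layout as continuous features.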
Our implementation supports both z-score normalization and min-max normalization. In our paper, z-score normalization is used to train the proxies and the flow matching model, while min-max normalization is used to calculate the hypervolume, aligning with offline-moo.
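For reference, the two normalizations can be written in a few lines of NumPy. This is a generic sketch on toy data: the function names and the eps guard against division by zero are assumptions, not the package's own API.

```python
import numpy as np


def z_score(a, eps=1e-8):
    """Standardize each column to roughly zero mean and unit variance."""
    mean = a.mean(axis=0)
    std = a.std(axis=0)
    return (a - mean) / (std + eps), mean, std


def min_max(a, eps=1e-8):
    """Rescale each column to the range [0, 1]."""
    lo = a.min(axis=0)
    hi = a.max(axis=0)
    return (a - lo) / (hi - lo + eps), lo, hi


rng = np.random.default_rng(0)
y = rng.normal(loc=[3.0, -1.0], scale=[2.0, 0.5], size=(100, 2))

y_z, y_mean, y_std = z_score(y)    # for training proxies and the flow model
y_mm, y_lo, y_hi = min_max(y)      # for hypervolume calculation

# Undo the z-score to map values back to the original scale
y_back = y_z * (y_std + 1e-8) + y_mean
```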
If your data is stored as x.npy and y.npy, the following code trains the proxies and the flow matching model (continuous features are used for illustration; see examples/discrete_examples.py for discrete features):
import numpy as np
from paretoflow import ParetoFlow
# Load the data
all_x = np.load("examples/data/zdt1-x-0.npy")
all_y = np.load("examples/data/zdt1-y-0.npy")
# Use this path when the flow matching and proxy models still need to be trained
# Initialize the ParetoFlow sampler
pf = ParetoFlow(
    task_name="zdt1",
    input_x=all_x,
    input_y=all_y,
    x_lower_bound=np.array([0.0] * all_x.shape[1]),
    x_upper_bound=np.array([1.0] * all_x.shape[1]),
)
# Sample the Pareto Set
res_x, res_y = pf.sample()
print(len(res_x))
Alternatively, you can load pre-trained flow matching and proxy models:
import numpy as np
import torch
from paretoflow import ParetoFlow, VectorFieldNet, FlowMatching, MultipleModels
# Load the data (same files as in the training example)
all_x = np.load("examples/data/zdt1-x-0.npy")
all_y = np.load("examples/data/zdt1-y-0.npy")
# Use this path to load pre-trained flow matching and proxy models
# Create the flow matching model and load the saved model
vnet = VectorFieldNet(all_x.shape[1])
fm_model = FlowMatching(vnet=vnet, sigma=0.0, D=all_x.shape[1], T=1000)
# torch.load returns the saved model object, which replaces the instance above;
# on PyTorch 2.6 or later this call may need weights_only=False
fm_model = torch.load("saved_fm_models/zdt1.model")
# Create the proxies model and load the saved model
proxies_model = MultipleModels(
    n_dim=all_x.shape[1],
    n_obj=all_y.shape[1],
    train_mode="Vanilla",
    hidden_size=[2048, 2048],
    save_dir="saved_proxies/",
    save_prefix="MultipleModels-Vanilla-zdt1",
)
proxies_model.load()
# Initialize the ParetoFlow sampler with the pre-trained models
pf = ParetoFlow(
    task_name="zdt1",
    input_x=all_x,
    input_y=all_y,
    x_lower_bound=np.array([0.0] * all_x.shape[1]),
    x_upper_bound=np.array([1.0] * all_x.shape[1]),
    load_pretrained_fm=True,
    load_pretrained_proxies=True,
    fm_model=fm_model,
    proxies=proxies_model,
)
# Sample the Pareto Set
res_x, res_y = pf.sample()
print(len(res_x))
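To inspect objective values such as the res_y returned by pf.sample(), a common step is to keep only the non-dominated points and, for two objectives, compute the hypervolume against a reference point. The sketch below uses toy values in place of res_y and assumes that all objectives are minimized. The helper names and the reference point (5, 5) are hypothetical and are not part of the package.

```python
import numpy as np


def non_dominated_mask(y):
    """Boolean mask of rows in y that no other row dominates (minimization)."""
    y = np.asarray(y, dtype=float)
    mask = np.ones(y.shape[0], dtype=bool)
    for i in range(y.shape[0]):
        # Row j dominates row i if it is no worse everywhere and better somewhere
        no_worse = np.all(y <= y[i], axis=1)
        better = np.any(y < y[i], axis=1)
        if np.any(no_worse & better):
            mask[i] = False
    return mask


def hypervolume_2d(front, ref):
    """Area dominated by a 2-objective non-dominated front, bounded by ref.

    Assumes minimization and that every point is no worse than ref.
    """
    front = np.asarray(front, dtype=float)
    front = front[np.argsort(front[:, 0])]         # sort by first objective
    next_f1 = np.append(front[1:, 0], ref[0])      # right edge of each strip
    return float(np.sum((next_f1 - front[:, 0]) * (ref[1] - front[:, 1])))


toy_y = np.array([[1.0, 4.0], [2.0, 2.0], [3.0, 3.0], [4.0, 1.0]])
mask = non_dominated_mask(toy_y)                   # [3, 3] is dominated by [2, 2]
front = toy_y[mask]
hv = hypervolume_2d(front, ref=(5.0, 5.0))
print(mask, hv)
```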
More importantly, users can also pass in their own pre-trained flow matching and proxy models. The flow matching model must be an nn.Module object, and two key arguments, vnet and time_embedding, must also be passed in; both are nn.Module objects. vnet is the network that approximates the vector field of the flow matching model, and time_embedding maps continuous time in [0, 1] to the embedding space. See the docstrings of the ParetoFlow class for more details.
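As a rough illustration of what such a pair of modules could look like, the sketch below defines a sinusoidal time embedding and a small MLP vector field in PyTorch. The class names, layer sizes, and the forward(x, t_emb) signature are all assumptions made for this example. The interface that ParetoFlow actually expects should be taken from its docstrings.

```python
import math

import torch
from torch import nn


class SinusoidalTimeEmbedding(nn.Module):
    """Hypothetical embedding: maps t in [0, 1] to a vector of size embed_dim."""

    def __init__(self, embed_dim: int = 16):
        super().__init__()
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even")
        half = embed_dim // 2
        # Fixed, geometrically spaced frequencies (not trained)
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        self.register_buffer("freqs", freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (batch, 1) -> (batch, embed_dim)
        angles = t * self.freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SimpleVectorField(nn.Module):
    """Hypothetical vnet: predicts a velocity for x given a time embedding."""

    def __init__(self, n_dim: int, embed_dim: int = 16, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_dim + embed_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, n_dim),
        )

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([x, t_emb], dim=-1))


time_embedding = SinusoidalTimeEmbedding(embed_dim=16)
vnet = SimpleVectorField(n_dim=30, embed_dim=16)

x = torch.rand(8, 30)     # batch of designs
t = torch.rand(8, 1)      # one time value in [0, 1] per design
emb = time_embedding(t)   # (8, 16)
v = vnet(x, emb)          # (8, 30), same shape as x
```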
Citation
If you find ParetoFlow useful in your research, please consider citing:
@misc{yuan2024paretoflowguidedflowsmultiobjective,
  title={ParetoFlow: Guided Flows in Multi-Objective Optimization},
  author={Ye Yuan and Can Chen and Christopher Pal and Xue Liu},
  year={2024},
  eprint={2412.03718},
  archivePrefix={arXiv},
  primaryClass={cs.CE},
  url={https://arxiv.org/abs/2412.03718},
}