PyTorch implementation of Bezier simplex fitting
pytorch-bsf
The Bezier simplex is a high-dimensional generalization of the Bezier curve. Mathematically, it is a polynomial map from a simplex to a Euclidean space determined by a set of vectors called the control points. This package provides an algorithm to fit a Bezier simplex to given data points.
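To make the definition concrete, the following is a minimal sketch of evaluating a Bezier simplex in plain Python. It assumes the standard Bernstein-basis form, in which each control point is weighted by a multinomial coefficient times a monomial in the simplex coordinates. The function names and the dictionary layout of the control points are hypothetical and are not part of this package's API.

```python
from itertools import product
from math import factorial


def multi_indices(n_params, degree):
    """Yield all tuples of n_params non-negative integers summing to degree."""
    for d in product(range(degree + 1), repeat=n_params):
        if sum(d) == degree:
            yield d


def bezier_simplex(control_points, degree, t):
    """Evaluate a Bezier simplex at a point t on the simplex.

    control_points maps each multi-index d to a vector p_d
    (a hypothetical layout chosen for this sketch).
    """
    n_values = len(next(iter(control_points.values())))
    x = [0.0] * n_values
    for d in multi_indices(len(t), degree):
        # Multinomial coefficient D! / (d_1! ... d_M!)
        coef = factorial(degree)
        for d_i in d:
            coef //= factorial(d_i)
        # Bernstein basis: coef * t_1^d_1 * ... * t_M^d_M
        basis = float(coef)
        for t_i, d_i in zip(t, d):
            basis *= t_i ** d_i
        # Accumulate the weighted control point
        for j, p in enumerate(control_points[d]):
            x[j] += basis * p
    return x


# A degree-2 Bezier curve (a Bezier simplex on the 1-simplex) in the plane.
points = {(2, 0): [0.0, 0.0], (1, 1): [1.0, 2.0], (0, 2): [2.0, 0.0]}
print(bezier_simplex(points, 2, [0.5, 0.5]))  # [1.0, 1.0]
```

With two parameters this reduces to the familiar Bezier curve; adding more parameters gives triangles, tetrahedra, and so on.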
See the following papers for technical details.
- Kobayashi, K., Hamada, N., Sannai, A., Tanaka, A., Bannai, K., & Sugiyama, M. (2019). Bézier Simplex Fitting: Describing Pareto Fronts of Simplicial Problems with Small Samples in Multi-Objective Optimization. Proceedings of the AAAI Conference on Artificial Intelligence, 33(01), 2304-2313. https://doi.org/10.1609/aaai.v33i01.33012304
- Tanaka, A., Sannai, A., Kobayashi, K., & Hamada, N. (2020). Asymptotic Risk of Bézier Simplex Fitting. Proceedings of the AAAI Conference on Artificial Intelligence, 34(03), 2416-2424. https://doi.org/10.1609/aaai.v34i03.5622
Requirements
Python 3.8, 3.9, 3.10.
Installation-free Execution
Download and install the latest Miniconda. Then install MLflow in your conda environment:
conda install mlflow
Run the following command:
mlflow run https://github.com/rafcc/pytorch-bsf \
-P data=data.tsv \
-P label=label.tsv \
-P degree=3
This automatically sets up the environment and runs an experiment in three steps:
- Downloads the latest pytorch-bsf into a temporary directory.
- Creates a new conda environment and installs the dependencies in it.
- Runs an experiment in the temporary directory and environment.
Parameter | Type | Default | Description |
---|---|---|---|
data | path | required | The data file. The file should contain a numerical matrix in the TSV format: each row represents a record that consists of features separated by tabs or spaces. |
label | path | required | The label file. The file should contain a numerical matrix in the TSV format: each row represents a record that consists of outcomes separated by tabs or spaces. |
degree | int (x >= 1) | required | The degree of the Bezier simplex. |
header | int (x >= 0) | 0 | The number of header lines in data/label files. |
delimiter | str | " " | The delimiter of values in data/label files. |
normalize | "max", "std", "quantile" | None | The data normalization: max scales each feature so that the minimum is 0 and the maximum is 1, suitable for uniformly distributed data; std scales so that the mean is 0 and the standard deviation is 1, suitable for nonuniformly distributed data; quantile scales so that the 5%-quantile is 0 and the 95%-quantile is 1, suitable for data containing outliers; None does not perform any scaling, suitable for pre-normalized data. |
split_ratio | float (0.0 < x < 1.0) | 0.5 | The ratio of training data against validation data. |
batch_size | int (x >= 0) | 0 | The size of a minibatch. The default uses all records in a single batch. |
max_epochs | int (x >= 1) | 1000 | The number of epochs after which training stops. |
accelerator | "auto", "cpu", "gpu", etc. | "auto" | The accelerator to use. See PyTorch Lightning documentation. |
devices | int (x >= -1) | None | The number of accelerators to use. By default, all available devices are used. See PyTorch Lightning documentation. |
num_nodes | int (x >= 1) | 1 | The number of compute nodes to use. See PyTorch Lightning documentation. |
strategy | "dp", "ddp", "ddp_spawn", etc. | None | The distributed strategy. See PyTorch Lightning documentation. |
loglevel | int (0 <= x <= 2) | 2 | Which objects are logged. 0: nothing; 1: metrics; 2: metrics and models. |
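The three normalize options can be illustrated on a single feature column. This is a rough sketch of the scalings as the table describes them, not the package's own code: the helper names are hypothetical, and the population standard deviation and the linearly interpolated quantile are assumed conventions that may differ from what pytorch-bsf actually uses.

```python
from statistics import fmean, pstdev


def quantile(sorted_xs, q):
    """Linearly interpolated q-quantile of a sorted list (an assumed convention)."""
    pos = q * (len(sorted_xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_xs) - 1)
    return sorted_xs[lo] + (pos - lo) * (sorted_xs[hi] - sorted_xs[lo])


def normalize(column, method=None):
    """Scale one feature column; a constant column would divide by zero."""
    if method is None:
        return list(column)  # no scaling, for pre-normalized data
    if method == "max":
        offset, scale = min(column), max(column) - min(column)
    elif method == "std":
        offset, scale = fmean(column), pstdev(column)
    elif method == "quantile":
        s = sorted(column)
        offset = quantile(s, 0.05)
        scale = quantile(s, 0.95) - offset
    else:
        raise ValueError(f"unknown method: {method}")
    return [(x - offset) / scale for x in column]


column = [0.0, 1.0, 2.0, 3.0, 100.0]  # the last value is an outlier
for method in ("max", "std", "quantile", None):
    print(method, normalize(column, method))
```

Running it on a column with an outlier shows why the table recommends quantile in that case: max squeezes the ordinary values into a narrow band near 0, while quantile spreads them more evenly.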
Installation
pip install pytorch-bsf
Fitting via CLI
This package provides a command line interface to train a Bezier simplex with a dataset file.
Execute the torch_bsf module:
python -m torch_bsf \
--data=data.tsv \
--label=label.tsv \
--degree=3
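The data.tsv and label.tsv files used in the command above are plain numerical matrices with one record per line. As a sketch, they could be produced as follows. The write_matrix helper is hypothetical, the values are toy data, and the example assumes that the data file holds parameters on a simplex and the label file holds the corresponding outcomes. It uses a single space as the delimiter, which is the default listed in the parameter table.

```python
from pathlib import Path


def write_matrix(path, rows, delimiter=" "):
    """Write a numerical matrix, one record per line (hypothetical helper)."""
    lines = [delimiter.join(f"{v:.6f}" for v in row) for row in rows]
    Path(path).write_text("\n".join(lines) + "\n")


# Toy records: the three vertices and the centroid of a 2-simplex.
data = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1 / 3, 1 / 3, 1 / 3],
]
# Toy outcomes computed elementwise as 1 - t^2.
label = [[1 - t * t for t in row] for row in data]

write_matrix("data.tsv", data)
write_matrix("label.tsv", label)
```

The two files must have the same number of rows, since each row of the label file is the outcome for the same row of the data file. A real fit would need considerably more records than this.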
Fitting via Script
Train a model with fit(), then call the model to predict.
import torch
import torch_bsf

# Prepare training data
ts = torch.tensor(  # parameters on a simplex
    [
        [3/3, 0/3, 0/3],
        [2/3, 1/3, 0/3],
        [2/3, 0/3, 1/3],
        [1/3, 2/3, 0/3],
        [1/3, 1/3, 1/3],
        [1/3, 0/3, 2/3],
        [0/3, 3/3, 0/3],
        [0/3, 2/3, 1/3],
        [0/3, 1/3, 2/3],
        [0/3, 0/3, 3/3],
    ]
)
xs = 1 - ts * ts  # values corresponding to the parameters

# Train a model
bs = torch_bsf.fit(params=ts, values=xs, degree=3)

# Predict by the trained model
t = [[0.2, 0.3, 0.5]]
x = bs(t)
print(f"{t} -> {x}")
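The hand-written parameter list in the script is a regular grid on the 2-simplex: all points whose coordinates are multiples of 1/3 and sum to 1. For other dimensions or resolutions, such a grid could be generated instead of typed out. The simplex_grid function below is a hypothetical helper written in plain Python and is not part of torch_bsf.

```python
from itertools import product


def simplex_grid(n_params, resolution):
    """Return all points with coordinates k/resolution that sum to 1."""
    return [
        [k / resolution for k in ks]
        for ks in product(range(resolution, -1, -1), repeat=n_params)
        if sum(ks) == resolution
    ]


grid = simplex_grid(n_params=3, resolution=3)
print(len(grid))  # 10
for point in grid:
    print(point)
```

With n_params=3 and resolution=3 this yields the same ten points, in the same order, as the list in the script. Wrapping the result with torch.tensor(grid) gives a tensor that can be used as ts.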
Documents
See the documentation for more details: https://rafcc.github.io/pytorch-bsf/
Author
RIKEN AIP-FUJITSU Collaboration Center (RAFCC)
License
MIT