PyTorch implementation of Bezier simplex fitting

pytorch-bsf

The Bezier simplex is a high-dimensional generalization of the Bezier curve. Mathematically, it is a polynomial map from a simplex to a Euclidean space determined by a set of vectors called the control points. This package provides an algorithm to fit a Bezier simplex to given data points.
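
As a rough illustration of the definition above, the sketch below evaluates a Bezier simplex in plain Python, using the standard form from the papers cited in this README: a sum over multi-indices d with |d| = D of (multinomial coefficient of d) * t^d * p_d, where the p_d are the control points. It is independent of this package. The names `multinomial` and `bezier_simplex`, and the dictionary layout for control points, are made up for this example and are not part of torch_bsf.

```python
from math import factorial, prod


def multinomial(index):
    """Multinomial coefficient D! / (d_1! * ... * d_M!) with D = sum(index)."""
    coeff = factorial(sum(index))
    for d in index:
        coeff //= factorial(d)
    return coeff


def bezier_simplex(control_points, t):
    """Evaluate sum_d multinomial(d) * t^d * p_d at a point t on the simplex.

    control_points maps each multi-index d (a tuple of non-negative integers
    summing to the degree) to its control point p_d (a list of floats).
    This layout is an assumption made for this sketch only.
    """
    dim = len(next(iter(control_points.values())))
    value = [0.0] * dim
    for index, point in control_points.items():
        # Bernstein weight of this control point at t.
        weight = multinomial(index) * prod(ti ** di for ti, di in zip(t, index))
        for k in range(dim):
            value[k] += weight * point[k]
    return value


# Degree-1 example: the control points are the vertices of a triangle,
# so the map reduces to linear interpolation between them.
vertices = {(1, 0, 0): [1.0, 0.0], (0, 1, 0): [0.0, 1.0], (0, 0, 1): [0.0, 0.0]}
print(bezier_simplex(vertices, [0.2, 0.3, 0.5]))
```

For a degree-1 simplex the Bernstein weights are just the coordinates of t, which makes the example easy to check by hand. Higher degrees add control points that bend the surface between the vertices.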

See the following papers for technical details.

  • Kobayashi, K., Hamada, N., Sannai, A., Tanaka, A., Bannai, K., & Sugiyama, M. (2019). Bézier Simplex Fitting: Describing Pareto Fronts of Simplicial Problems with Small Samples in Multi-Objective Optimization. Proceedings of the AAAI Conference on Artificial Intelligence, 33(01), 2304-2313. https://doi.org/10.1609/aaai.v33i01.33012304
  • Tanaka, A., Sannai, A., Kobayashi, K., & Hamada, N. (2020). Asymptotic Risk of Bézier Simplex Fitting. Proceedings of the AAAI Conference on Artificial Intelligence, 34(03), 2416-2424. https://doi.org/10.1609/aaai.v34i03.5622

Requirements

Python 3.8, 3.9, 3.10.

Installation-free Execution

Download the latest Miniconda and install it. Then install MLflow in your conda environment:

conda install mlflow

Run the following command:

mlflow run https://github.com/rafcc/pytorch-bsf \
  -P data=data.tsv \
  -P label=label.tsv \
  -P degree=3

This command automatically sets up the environment and runs an experiment in three steps:

  1. Downloads the latest pytorch-bsf into a temporary directory.
  2. Creates a new conda environment and installs the dependencies in it.
  3. Runs an experiment in that temporary directory and environment.

| Parameter | Type | Default | Description |
|---|---|---|---|
| data | path | required | The data file. The file should contain a numerical matrix in the TSV format: each row represents a record that consists of features separated by tabs or spaces. |
| label | path | required | The label file. The file should contain a numerical matrix in the TSV format: each row represents a record that consists of outcomes separated by tabs or spaces. |
| degree | int (x >= 1) | required | The degree of the Bezier simplex. |
| header | int (x >= 0) | 0 | The number of header lines in data/label files. |
| delimiter | str | " " | The delimiter of values in data/label files. |
| normalize | "max", "std", "quantile" | None | The data normalization. max scales each feature so that the minimum is 0 and the maximum is 1, suitable for uniformly distributed data. std scales each feature so that the mean is 0 and the standard deviation is 1, suitable for nonuniformly distributed data. quantile scales each feature so that the 5% quantile is 0 and the 95% quantile is 1, suitable for data containing outliers. None performs no scaling, suitable for pre-normalized data. |
| split_ratio | float (0.0 < x < 1.0) | 0.5 | The ratio of training data against validation data. |
| batch_size | int (x >= 0) | 0 | The size of a minibatch. The default uses all records in a single batch. |
| max_epochs | int (x >= 1) | 1000 | The number of epochs after which training stops. |
| accelerator | "auto", "cpu", "gpu", etc. | "auto" | The accelerator to use. See the PyTorch Lightning documentation. |
| devices | int (x >= -1) | None | The number of accelerators to use. By default, all available devices are used. See the PyTorch Lightning documentation. |
| num_nodes | int (x >= 1) | 1 | The number of compute nodes to use. See the PyTorch Lightning documentation. |
| strategy | "dp", "ddp", "ddp_spawn", etc. | None | The distributed strategy. See the PyTorch Lightning documentation. |
| loglevel | int (0 <= x <= 2) | 2 | Which objects to log. 0: nothing; 1: metrics; 2: metrics and models. |
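
The three normalize modes can be read as simple per-feature rescalings. The sketch below illustrates them on a single column using only the Python standard library. It is an interpretation of the parameter description, not the package's own code: the function names are hypothetical, and the choice of population standard deviation and of the "inclusive" quantile interpolation rule are assumptions.

```python
from statistics import mean, pstdev, quantiles


def scale_max(column):
    """Map the minimum to 0 and the maximum to 1."""
    lo, hi = min(column), max(column)
    return [(v - lo) / (hi - lo) for v in column]


def scale_std(column):
    """Shift to mean 0 and scale to standard deviation 1.

    Uses the population standard deviation (an assumption of this sketch).
    """
    mu, sigma = mean(column), pstdev(column)
    return [(v - mu) / sigma for v in column]


def scale_quantile(column):
    """Map the 5% quantile to 0 and the 95% quantile to 1."""
    # quantiles(n=20) returns 19 cut points; the first is 5%, the last is 95%.
    cuts = quantiles(column, n=20, method="inclusive")
    q05, q95 = cuts[0], cuts[-1]
    return [(v - q05) / (q95 - q05) for v in column]


column = [1.0, 2.0, 3.0, 4.0, 100.0]  # the last value is an outlier
print(scale_max(column))
print(scale_std(column))
print(scale_quantile(column))
```

With the outlier present, the max scaling squeezes the first four values close to 0, while the quantile scaling spreads them out more and lets the outlier land above 1.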

Installation

pip install pytorch-bsf

Fitting via CLI

This package provides a command line interface to train a Bezier simplex with a dataset file.

Execute the torch_bsf module:

python -m torch_bsf \
  --data=data.tsv \
  --label=label.tsv \
  --degree=3
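
The data and label files can be produced by any tool that writes a plain numerical matrix. As one possible way to do it, the sketch below writes a tiny dataset to data.tsv and label.tsv, one record per row with values separated by spaces. It assumes that the data file holds the parameters on the simplex and the label file holds the corresponding values; the helper name `write_matrix`, the scratch directory, and the toy values are all made up for this example, and the layout has not been verified against the package's loader.

```python
import os
import tempfile


def write_matrix(path, rows, delimiter=" "):
    """Write a numerical matrix as text, one record per line."""
    with open(path, "w") as f:
        for row in rows:
            f.write(delimiter.join(f"{v:.6f}" for v in row) + "\n")


# Toy parameters: the three vertices of a triangle and its centroid.
params = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1 / 3, 1 / 3, 1 / 3],
]
# Toy values computed elementwise as x = 1 - t * t.
values = [[1 - t * t for t in row] for row in params]

out_dir = tempfile.mkdtemp()  # write to a scratch directory for this demo
data_path = os.path.join(out_dir, "data.tsv")
label_path = os.path.join(out_dir, "label.tsv")
write_matrix(data_path, params)
write_matrix(label_path, values)
print(data_path, label_path)
```

Both files must have the same number of rows, since each row of the label file corresponds to the same row of the data file.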

Fitting via Script

Train a model with fit(), then call the trained model to predict.

import torch
import torch_bsf

# Prepare training data
ts = torch.tensor(  # parameters on a simplex
    [
        [3/3, 0/3, 0/3],
        [2/3, 1/3, 0/3],
        [2/3, 0/3, 1/3],
        [1/3, 2/3, 0/3],
        [1/3, 1/3, 1/3],
        [1/3, 0/3, 2/3],
        [0/3, 3/3, 0/3],
        [0/3, 2/3, 1/3],
        [0/3, 1/3, 2/3],
        [0/3, 0/3, 3/3],
    ]
)
xs = 1 - ts * ts  # values corresponding to the parameters

# Train a model
bs = torch_bsf.fit(params=ts, values=xs, degree=3)

# Predict by the trained model
t = [[0.2, 0.3, 0.5]]
x = bs(t)
print(f"{t} -> {x}")
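
The hand-written ts above is a regular lattice on the simplex: every triple of non-negative integers summing to 3, divided by 3. For other sizes it may be easier to generate such a lattice programmatically. The following sketch does so in plain Python; `simplex_lattice` is a hypothetical helper written for this example, not a function provided by torch_bsf. Its output could be passed to torch.tensor to build ts.

```python
from itertools import product


def simplex_lattice(n_params, resolution):
    """Return all points k / resolution, where k is a tuple of n_params
    non-negative integers summing to resolution."""
    points = []
    # Iterate the leading coordinates in descending order; the last one is
    # whatever remains, so that each point sums to 1.
    for head in product(range(resolution, -1, -1), repeat=n_params - 1):
        rest = resolution - sum(head)
        if rest >= 0:
            points.append([k / resolution for k in head + (rest,)])
    return points


ts_list = simplex_lattice(n_params=3, resolution=3)
print(len(ts_list))  # 10 points, in the same order as the list above
for point in ts_list:
    print(point)
```

The resolution of the lattice does not have to equal the degree of the Bezier simplex; using the same number here only mirrors the example above.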

Documents

See the documentation for more details: https://rafcc.github.io/pytorch-bsf/

Author

RIKEN AIP-FUJITSU Collaboration Center (RAFCC)

License

MIT
