Skip to main content

Tool to easily perform quantile regression using deep learning (pytorch).

Project description

deep-q-reg

下の方に日本語の説明があります

Overview

  • Tool to easily perform quantile regression using deep learning (pytorch).
  • Automatically compensates for quantile order swapping defects in predicted data.
  • Customizable granularity, from "hyper-parameter unspecified mode" to "detailed parameter settings".
  • The number of data dimensions is set automatically at the first training with define-by-run.

Example usage

import deep_q_reg

# Prepare data
train_x = load_x_data()    # [[0.538], [0.469], ...]
train_y = load_y_data()    # [24.0, 21.6, ...]

# Parameters
params = {}    # This can also be omitted. See details below for specifying parameters

# Deep quantile regression [deep_q_reg]
dqr = deep_q_reg.Deep_Q_Reg(params)
# Training [deep_q_reg]
dqr.train(train_x, train_y)
# Inference [deep_q_reg]
pred_y = dqr.predict(test_x)

Details of specifying parameters

  • Parameters are specified as follows. Omitted specifications are automatically filled in with default values.
params = {
    'normalize_x': True,    # Automatic normalization of x
    'quant_ls': [0.25, 0.5, 0.75],    # List of prediction target quantiles
    # Layer structure
    'layers': [
        {
            'activation': 'ReLU',    # Activation function name (Specify names under torch.nn such as Tanh, ReLU, Sigmoid)
            'out_n': 32    # Output dimension of the layer (Input dimension is automatically determined from training data or previous layer settings)
        },
        {'activation': 'ReLU', 'out_n': 32}
    ],
    # Training parameters
    'mini_batch_n': 10000,    # Number of iterations for training (mini-batch training)
    'mini_batch_size': 512    # Mini-batch size
}

概要

  • 深層学習(pytorch)による分位点回帰を簡単に実施できるツール
  • 推論データにおける分位点順序の入れ替わり不具合を自動的に補正する
  • 「パラメータ等指定無し」から「詳細なパラメータ設定」まで自由なカスタマイズ粒度で扱える
  • データ次元数の設定がdefine-by-runで初回学習時に自動で設定される

使用例

import deep_q_reg

# データ準備
train_x = load_x_data()	# [[0.538], [0.469], ...]
train_y = load_y_data()	# [24.0, 21.6, ...]

# パラメータ
params = {}	# このように省略してもよい。詳細な指定の仕方は後述

# 深層分位点回帰 [deep_q_reg]
dqr = deep_q_reg.Deep_Q_Reg(params)
# 学習 [deep_q_reg]
dqr.train(train_x, train_y)
# 推論 [deep_q_reg]
pred_y = dqr.predict(test_x)

paramsの指定詳細

  • paramsは下記のように指定します。省略された指定値は自動的にdefault値が補完されます。
params = {
    'normalize_x': True,    # xの自動正規化
    'quant_ls': [0.25, 0.5, 0.75],    # 予測対象分位点一覧
    # 層構成
    'layers': [
        {
            'activation': 'ReLU',    # 活性化関数名 (torch.nn配下の名前を指定する。Tanh, ReLU, Sigmoid など)
            'out_n': 32    # 層の出力次元数 (入力次元数は学習データや前層設定から自動的に判断される)
        },
        {'activation': 'ReLU', 'out_n': 32}
    ],
    # 学習パラメータ
    'mini_batch_n': 10000,    # 繰り返し学習回数 (ミニバッチ学習)
    'mini_batch_size': 512    # ミニバッチのサイズ
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

deep-q-reg-0.0.3.tar.gz (6.3 kB view details)

Uploaded Source

Built Distribution

deep_q_reg-0.0.3-py3-none-any.whl (7.4 kB view details)

Uploaded Python 3

File details

Details for the file deep-q-reg-0.0.3.tar.gz.

File metadata

  • Download URL: deep-q-reg-0.0.3.tar.gz
  • Upload date:
  • Size: 6.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.22.0 requests-toolbelt/0.9.1 tqdm/4.64.1 CPython/3.8.8

File hashes

Hashes for deep-q-reg-0.0.3.tar.gz
Algorithm Hash digest
SHA256 531f16d0d998846c79019bdf58e2854f03b57a4b68c465ed595e516f68a497b7
MD5 2ec7ff5e143456d24ca087e6e4ad050a
BLAKE2b-256 095ee5f6e95f770c4980a0750771f7ff22a9ff54086840ccf4cc5729d18d7347

See more details on using hashes here.

File details

Details for the file deep_q_reg-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: deep_q_reg-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 7.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.22.0 requests-toolbelt/0.9.1 tqdm/4.64.1 CPython/3.8.8

File hashes

Hashes for deep_q_reg-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 6bd789df0754f3490712654312070e19eea06ec1ac4e5e6f92950e60dabe6f1d
MD5 c6c7ee512b81670e471e81e8e9ae5b86
BLAKE2b-256 ebb88de933e8a8f02ddc3eed25cbfa01ad48c5fa9a7b18250e1326b793c9986b

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page