Skip to main content

Gated Asymmetric Linear Regression - 特徴量に応じてコスト比を学習する回帰モデル

Project description

Gated Asymmetric Linear Regression (GALR)

「上振れ(過大予測)」と「下振れ(過小予測)」のペナルティを固定せず、特徴量に応じて学習する"ゲート付き非対称線形回帰"パッケージ。

概要

従来の cost-sensitive 回帰(上振れ/下振れの重み固定)ではなく、状況(特徴量 x)に応じてコスト比が変化する現実の意思決定をモデル化します。

  • sklearn 互換 API(BaseEstimator, RegressorMixin
  • プロダクション運用(再学習・評価・監視)まで想定
  • 最小限の依存関係(numpy, scikit-learn)

インストール

pip install galr

開発版をインストールする場合:

git clone https://github.com/yourusername/galr.git
cd galr
pip install -e .

使用方法

from galr import GALRRegressor
import numpy as np

# データの準備
X = np.random.randn(100, 5)
y = np.random.randn(100)

# モデルの学習
model = GALRRegressor(
    gate='linear',
    fit_intercept=True,
    optimizer='sgd',
    lr=0.01,
    n_iter=1000,
    lambda_beta=0.01,
    lambda_gate=0.01,
    epsilon=1e-6,
    random_state=42
)
model.fit(X, y)

# 予測
y_pred = model.predict(X)

# ゲート関数の値を取得(オプション)
gate_values = model.get_gate_values(X)

モデル詳細

予測器

\hat{y} = x^\top \beta + b

ゲート関数

ゲート関数 $g(x)$ が「この状況では下振れが痛い/上振れが痛い」を学習します。

w_\mathrm{under}(x) = \mathrm{softplus}(g(x)) + \epsilon
w_\mathrm{over}(x) = \mathrm{softplus}(-g(x)) + \epsilon

損失関数

L(\beta, b, \theta) = \frac{1}{n} \sum_i \Big[ \mathbb{1}(e_i > 0) \cdot w_\mathrm{under}(x_i) + \mathbb{1}(e_i < 0) \cdot w_\mathrm{over}(x_i) \Big] e_i^2 + \lambda_\beta \|\beta\|_2^2 + \lambda_g \|\theta\|_2^2

パラメータ

  • gate: 'linear' | 'mlp'(現在は 'linear' のみ対応)
  • fit_intercept: bool - 切片を学習するか
  • optimizer: 'sgd' | 'adam'(現在は 'sgd' のみ対応)
  • lr: float - 学習率
  • n_iter: int - イテレーション数
  • tol: float - 収束判定の閾値
  • lambda_beta: float - 回帰係数のL2正則化係数
  • lambda_gate: float - ゲートパラメータのL2正則化係数
  • epsilon: float - softplusの下限値
  • standardize: bool - 内部でStandardScalerを使用するか
  • random_state: int - 乱数シード

ライセンス

MIT License

開発状況

現在は MVP(最小実装)段階です。将来的には以下を追加予定:

  • MLPゲートの実装
  • Adamオプティマイザの実装
  • より高度な最適化手法
  • 詳細なドキュメントとチュートリアル

開発者向け情報

PyPIへの公開手順

  1. 必要なツールのインストール

    pip install build twine
    
  2. パッケージのビルド

    python -m build
    

    これにより dist/ ディレクトリに配布用ファイルが生成されます。

  3. ビルドの確認(オプション)

    # ローカルでテスト
    pip install dist/galr-*.whl
    
    # または、TestPyPIでテスト
    twine upload --repository testpypi dist/*
    
  4. PyPIへのアップロード

    twine upload dist/*
    

    注意: 初回はTestPyPIでテストすることを推奨します:

    twine upload --repository testpypi dist/*
    
  5. バージョン更新

    • pyproject.tomlversion を更新
    • 変更をコミット・タグ付け
    • 再度ビルド・アップロード

ローカル開発環境のセットアップ

# リポジトリのクローン
git clone https://github.com/yut0takagi/galr.git
cd galr

# 開発環境のセットアップ
pip install -e ".[dev]"

テスト

# テストの実行(pytestが必要)
pytest

# カバレッジ付きテスト
pytest --cov=galr --cov-report=html

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

galr-0.1.0.tar.gz (9.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

galr-0.1.0-py3-none-any.whl (7.7 kB view details)

Uploaded Python 3

File details

Details for the file galr-0.1.0.tar.gz.

File metadata

  • Download URL: galr-0.1.0.tar.gz
  • Upload date:
  • Size: 9.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for galr-0.1.0.tar.gz
Algorithm Hash digest
SHA256 96f7ebe7c44e48594d24c242a46d5dcbd8fa91ebd27b96d638f32456eb65ee99
MD5 6b7d4f53313272068eef93b5f92f03a1
BLAKE2b-256 e489f8332bd95c58e0cfb42788e2ccf2c6b47347d46295dc4005be215c9410c3

See more details on using hashes here.

File details

Details for the file galr-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: galr-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 7.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for galr-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5e4de50b9d8c32114c299cb0badaffcf28ef28970f8e8fe142e699fbfa9d491e
MD5 94957235aa19b6e6b806314147bddf5e
BLAKE2b-256 373181cd060b0bb372d2f6266c974962c60f653860611e0d8f452a1bc7c6e12e

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page