Gated Asymmetric Linear Regression - 特徴量に応じてコスト比を学習する回帰モデル
Project description
Gated Asymmetric Linear Regression (GALR)
「上振れ(過大予測)」と「下振れ(過小予測)」のペナルティを固定せず、特徴量に応じて学習する"ゲート付き非対称線形回帰"パッケージ。
概要
従来の cost-sensitive 回帰(上振れ/下振れの重み固定)ではなく、状況(特徴量 x)に応じてコスト比が変化する現実の意思決定をモデル化します。
- sklearn 互換 API(
BaseEstimator,RegressorMixin) - プロダクション運用(再学習・評価・監視)まで想定
- 最小限の依存関係(numpy, scikit-learn)
インストール
pip install galr
開発版をインストールする場合:
git clone https://github.com/yourusername/galr.git
cd galr
pip install -e .
使用方法
from galr import GALRRegressor
import numpy as np
# データの準備
X = np.random.randn(100, 5)
y = np.random.randn(100)
# モデルの学習
model = GALRRegressor(
gate='linear',
fit_intercept=True,
optimizer='sgd',
lr=0.01,
n_iter=1000,
lambda_beta=0.01,
lambda_gate=0.01,
epsilon=1e-6,
random_state=42
)
model.fit(X, y)
# 予測
y_pred = model.predict(X)
# ゲート関数の値を取得(オプション)
gate_values = model.get_gate_values(X)
モデル詳細
予測器
\hat{y} = x^\top \beta + b
ゲート関数
ゲート関数 $g(x)$ が「この状況では下振れが痛い/上振れが痛い」を学習します。
w_\mathrm{under}(x) = \mathrm{softplus}(g(x)) + \epsilon
w_\mathrm{over}(x) = \mathrm{softplus}(-g(x)) + \epsilon
損失関数
L(\beta, b, \theta) = \frac{1}{n} \sum_i \Big[ \mathbb{1}(e_i > 0) \cdot w_\mathrm{under}(x_i) + \mathbb{1}(e_i < 0) \cdot w_\mathrm{over}(x_i) \Big] e_i^2 + \lambda_\beta \|\beta\|_2^2 + \lambda_g \|\theta\|_2^2
パラメータ
gate:'linear'|'mlp'(現在は'linear'のみ対応)fit_intercept: bool - 切片を学習するかoptimizer:'sgd'|'adam'(現在は'sgd'のみ対応)lr: float - 学習率n_iter: int - イテレーション数tol: float - 収束判定の閾値lambda_beta: float - 回帰係数のL2正則化係数lambda_gate: float - ゲートパラメータのL2正則化係数epsilon: float - softplusの下限値standardize: bool - 内部でStandardScalerを使用するかrandom_state: int - 乱数シード
ライセンス
MIT License
開発状況
現在は MVP(最小実装)段階です。将来的には以下を追加予定:
- MLPゲートの実装
- Adamオプティマイザの実装
- より高度な最適化手法
- 詳細なドキュメントとチュートリアル
CI/CD
このプロジェクトはGitHub ActionsによるCI/CDを採用しています。
継続的インテグレーション (CI)
Pull Request作成時に自動実行:
- テスト: 複数のPythonバージョン(3.8-3.13)でテストを実行
- リント: flake8によるコード品質チェック
- 型チェック: mypyによる型チェック
- ビルド: パッケージのビルドと検証
- バージョンチェック: バージョン形式の検証
継続的デプロイ (CD)
- developブランチからmainへのマージ: 自動的にバージョンをインクリメントしてPyPIに公開(推奨)
- mainブランチへの直接プッシュ: 自動的にバージョンをインクリメントしてPyPIに公開
- 手動実行: GitHub ActionsのUIからバージョンを指定してリリース可能
詳細は RELEASE.md を参照してください。
開発者向け情報
PyPIへの公開手順
-
必要なツールのインストール
pip install build twine
-
パッケージのビルド
python -m build
これにより
dist/ディレクトリに配布用ファイルが生成されます。 -
ビルドの確認(オプション)
# ローカルでテスト pip install dist/galr-*.whl # または、TestPyPIでテスト twine upload --repository testpypi dist/*
-
PyPIへのアップロード
twine upload dist/*
注意: 初回はTestPyPIでテストすることを推奨します:
twine upload --repository testpypi dist/*
-
バージョン更新
pyproject.tomlのversionを更新- 変更をコミット・タグ付け
- 再度ビルド・アップロード
ローカル開発環境のセットアップ
# リポジトリのクローン
git clone https://github.com/yut0takagi/galr.git
cd galr
# 開発環境のセットアップ
pip install -e ".[dev]"
テスト
# テストの実行
pytest
# カバレッジ付きテスト
pytest --cov=galr --cov-report=html
# 特定のテストのみ実行
pytest tests/test_galr.py::TestGALRRegressor::test_fit_predict
コード品質チェック
# リントチェック
flake8 src/ tests/
# コードフォーマットチェック
black --check src/ tests/
# コードフォーマット適用
black src/ tests/
# 型チェック
mypy src/galr
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file galr-0.1.1.tar.gz.
File metadata
- Download URL: galr-0.1.1.tar.gz
- Upload date:
- Size: 11.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
1e7e8ff8c4133ad5b6d8ffee72a7ffa7ce56c9e36bb7ceae0b6f47f30d7d2b8e
|
|
| MD5 |
8a05f01e7bf51e380979e9cec189e2dc
|
|
| BLAKE2b-256 |
3951d1f67544cb326c37ea1fdb1118791207feba28abb7197a57d663fe735fcb
|
File details
Details for the file galr-0.1.1-py3-none-any.whl.
File metadata
- Download URL: galr-0.1.1-py3-none-any.whl
- Upload date:
- Size: 8.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
2078cd404d2a81895d236b476e45faeb74f6b8c74f3d6480a64ba0e84a265ed1
|
|
| MD5 |
a498a99437951bd0333143633e962067
|
|
| BLAKE2b-256 |
db9035ed48ddc4d9c495a20476b4a0a90b7a4c73e25f2289b6e8b2b381fcd05f
|