Skip to main content

AutoML SDK for Tencent Cloud WeData using FLAML with MLflow integration.

Project description

WeData AutoML

腾讯云 WeData 平台的 AutoML SDK,基于 FLAML 构建,集成 MLflow 进行实验追踪和模型注册。

✨ 功能特性

  • 多任务支持:分类(Classification)、回归(Regression)、时序预测(Forecast)
  • FLAML 驱动:高效的 AutoML 超参数搜索,支持 LightGBM、XGBoost、RandomForest 等估计器
  • MLflow 集成:自动实验追踪、模型日志记录、模型注册
  • Spark 支持:支持 Spark DataFrame 输入,可配合 DLC 使用
  • 特征工程集成:与 WeData 特征工程 SDK 无缝对接
  • Notebook 生成:自动生成可复现的 Jupyter Notebook(分类/回归任务)
  • 并发训练:支持多 Trial 并发执行

📦 安装

# 基础安装
pip install tencent-wedata-auto-ml

# 可选依赖
pip install "tencent-wedata-auto-ml[xgboost]"     # XGBoost 支持
pip install "tencent-wedata-auto-ml[lightgbm]"   # LightGBM 支持
pip install "tencent-wedata-auto-ml[full]"       # 完整安装(推荐)

🚀 快速开始

便捷函数 API

from wedata_automl import classify, regress, forecast

# 分类任务
summary = classify(
    dataset=spark.table("demo.wine_quality"),
    target_col="quality",
    timeout_minutes=10,
    max_trials=100,
    metric="accuracy",
    project_id="your_project_id",
    experiment_name="wine_classification",
    register_model=True,
    model_name="wine_model"
)

# 回归任务
summary = regress(
    dataset=df,
    target_col="price",
    timeout_minutes=10,
    metric="r2",
    project_id="your_project_id"
)

# 时序预测任务
summary = forecast(
    dataset=spark.table("demo.sales_data"),
    target_col="sales",
    time_col="date",
    horizon=30,
    frequency="D",
    timeout_minutes=60,
    project_id="your_project_id"
)

任务类 API

from wedata_automl import Classifier, Regressor, Forecast

# 使用 Classifier 类
classifier = Classifier()
summary = classifier.fit(
    dataset=df,
    target_col="label",
    timeout_minutes=10,
    project_id="your_project_id"
)

# 使用 Regressor 类
regressor = Regressor()
summary = regressor.fit(
    dataset=df,
    target_col="target",
    timeout_minutes=10,
    project_id="your_project_id"
)

查看结果

print(summary)
# AutoMLSummary:
#   Experiment ID: 42
#   Run ID: abc123...
#   Best Trial Run ID: def456...
#   Model URI: runs:/abc123.../model
#   Best Estimator: lgbm
#   Metrics:
#     accuracy: 0.9500
#     f1: 0.9400

# 生成可复现 Notebook(仅分类/回归任务)
summary.generate_notebook("best_model.ipynb")

# 保存 Notebook 到 WeData 平台
summary.save_notebook_to_wedata()

📋 主要参数

参数 说明 默认值
dataset 数据集(Pandas/Spark DataFrame 或表名) 必填
target_col 目标列名 必填
project_id WeData 项目 ID 必填
timeout_minutes 超时时间(分钟) 5
max_trials 最大试验次数 100
metric 评估指标 auto
estimator_list 估计器列表 None(使用全部)
register_model 是否注册模型 True
model_name 注册模型名称 None
experiment_name MLflow 实验名称 None
custom_hp 自定义超参数搜索空间 None

评估指标

分类任务accuracy, f1, log_loss, roc_auc, precision, recall

回归任务r2, mse, rmse, mae, mape

时序预测smape, mse, rmse, mae, mdape

估计器列表

分类/回归lgbm, xgboost, rf, extra_tree, lrl1(仅分类)

时序预测prophet, arima, sarimax

⚙️ 环境配置

# 必需:项目 ID
export WEDATA_PROJECT_ID="your_project_id"

# 必需:MLflow Tracking URI
export MLFLOW_TRACKING_URI="http://your-mlflow-server:5000"

# 可选:腾讯云密钥(用于保存 Notebook 到 WeData)
export TENCENTCLOUD_SECRET_ID="your_secret_id"
export TENCENTCLOUD_SECRET_KEY="your_secret_key"

📁 项目结构

wedata-automl/
├── src/wedata_automl/
│   ├── api.py              # 便捷函数 (classify, regress, forecast)
│   ├── summary.py          # AutoMLSummary 结果对象
│   ├── driver.py           # AutoML 驱动程序
│   ├── tasks/              # 任务类
│   │   ├── classifier.py   # Classifier 类
│   │   ├── regressor.py    # Regressor 类
│   │   └── forecast.py     # Forecast 类
│   ├── engines/            # 训练引擎
│   │   ├── flaml_trainer.py # FLAML 训练器
│   │   └── trial_hook.py   # Trial 日志钩子
│   ├── notebook_generator/ # Notebook 生成器
│   └── utils/              # 工具函数
├── templates/              # Driver 模板
│   ├── classification_driver_template.py
│   └── forecast_driver_template.py
├── docs/                   # 文档
└── examples/               # 示例代码

📚 文档

使用指南

技术参考

⚠️ 注意事项

  • Python >= 3.9
  • Project ID 必填:通过 project_id 参数或 WEDATA_PROJECT_ID 环境变量配置
  • MLflow Tracking URI 必须正确配置

📄 License

MIT

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tencent_wedata_auto_ml_dev-0.2.77.tar.gz (88.8 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tencent_wedata_auto_ml_dev-0.2.77-py3-none-any.whl (106.1 kB view details)

Uploaded Python 3

File details

Details for the file tencent_wedata_auto_ml_dev-0.2.77.tar.gz.

File metadata

File hashes

Hashes for tencent_wedata_auto_ml_dev-0.2.77.tar.gz
Algorithm Hash digest
SHA256 360a64f0f500ed3f7c7834636e4a38a32e4d1def5b48de5cea276eac45cf4439
MD5 2c7b41b62256cde5a1de97f7474e7a1a
BLAKE2b-256 67eab9ea7888f2bce50fa2c300f9f4740ff45369da5f66ed56db6860d2a4799f

See more details on using hashes here.

File details

Details for the file tencent_wedata_auto_ml_dev-0.2.77-py3-none-any.whl.

File metadata

File hashes

Hashes for tencent_wedata_auto_ml_dev-0.2.77-py3-none-any.whl
Algorithm Hash digest
SHA256 652da8e49372598d7a44940b5c7ad51c7587b4cddf61ec98aa0a5b91dbb62df7
MD5 b11cd809b3fb8daec378afc9934cb62b
BLAKE2b-256 aa40adb27fc8d2343b2e0f8865366674566b67b1e8848ed3469634a6f298f477

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page