Skip to main content

AutoML SDK for Tencent Cloud WeData using FLAML with MLflow integration.

Project description

WeData AutoML

腾讯云 WeData 平台的 AutoML SDK,基于 FLAML 构建,集成 MLflow 进行实验追踪和模型注册。

✨ 功能特性

  • 多任务支持:分类(Classification)、回归(Regression)、时序预测(Forecast)
  • FLAML 驱动:高效的 AutoML 超参数搜索,支持 LightGBM、XGBoost、RandomForest 等估计器
  • MLflow 集成:自动实验追踪、模型日志记录、模型注册
  • Spark 支持:支持 Spark DataFrame 输入,可配合 DLC 使用
  • 特征工程集成:与 WeData 特征工程 SDK 无缝对接
  • Notebook 生成:自动生成可复现的 Jupyter Notebook(分类/回归任务)
  • 并发训练:支持多 Trial 并发执行

📦 安装

# 基础安装
pip install tencent-wedata-auto-ml

# 可选依赖
pip install "tencent-wedata-auto-ml[xgboost]"     # XGBoost 支持
pip install "tencent-wedata-auto-ml[lightgbm]"   # LightGBM 支持
pip install "tencent-wedata-auto-ml[full]"       # 完整安装(推荐)

🚀 快速开始

便捷函数 API

from wedata_automl import classify, regress, forecast

# 分类任务
summary = classify(
    dataset=spark.table("demo.wine_quality"),
    target_col="quality",
    timeout_minutes=10,
    max_trials=100,
    metric="accuracy",
    project_id="your_project_id",
    experiment_name="wine_classification",
    register_model=True,
    model_name="wine_model"
)

# 回归任务
summary = regress(
    dataset=df,
    target_col="price",
    timeout_minutes=10,
    metric="r2",
    project_id="your_project_id"
)

# 时序预测任务
summary = forecast(
    dataset=spark.table("demo.sales_data"),
    target_col="sales",
    time_col="date",
    horizon=30,
    frequency="D",
    timeout_minutes=60,
    project_id="your_project_id"
)

任务类 API

from wedata_automl import Classifier, Regressor, Forecast

# 使用 Classifier 类
classifier = Classifier()
summary = classifier.fit(
    dataset=df,
    target_col="label",
    timeout_minutes=10,
    project_id="your_project_id"
)

# 使用 Regressor 类
regressor = Regressor()
summary = regressor.fit(
    dataset=df,
    target_col="target",
    timeout_minutes=10,
    project_id="your_project_id"
)

查看结果

print(summary)
# AutoMLSummary:
#   Experiment ID: 42
#   Run ID: abc123...
#   Best Trial Run ID: def456...
#   Model URI: runs:/abc123.../model
#   Best Estimator: lgbm
#   Metrics:
#     accuracy: 0.9500
#     f1: 0.9400

# 生成可复现 Notebook(仅分类/回归任务)
summary.generate_notebook("best_model.ipynb")

# 保存 Notebook 到 WeData 平台
summary.save_notebook_to_wedata()

📋 主要参数

参数 说明 默认值
dataset 数据集(Pandas/Spark DataFrame 或表名) 必填
target_col 目标列名 必填
project_id WeData 项目 ID 必填
timeout_minutes 超时时间(分钟) 5
max_trials 最大试验次数 100
metric 评估指标 auto
estimator_list 估计器列表 None(使用全部)
register_model 是否注册模型 True
model_name 注册模型名称 None
experiment_name MLflow 实验名称 None
custom_hp 自定义超参数搜索空间 None

评估指标

分类任务accuracy, f1, log_loss, roc_auc, precision, recall

回归任务r2, mse, rmse, mae, mape

时序预测smape, mse, rmse, mae, mdape

估计器列表

分类/回归lgbm, xgboost, rf, extra_tree, lrl1(仅分类)

时序预测prophet, arima, sarimax

⚙️ 环境配置

# 必需:项目 ID
export WEDATA_PROJECT_ID="your_project_id"

# 必需:MLflow Tracking URI
export MLFLOW_TRACKING_URI="http://your-mlflow-server:5000"

# 可选:腾讯云密钥(用于保存 Notebook 到 WeData)
export TENCENTCLOUD_SECRET_ID="your_secret_id"
export TENCENTCLOUD_SECRET_KEY="your_secret_key"

📁 项目结构

wedata-automl/
├── src/wedata_automl/
│   ├── api.py              # 便捷函数 (classify, regress, forecast)
│   ├── summary.py          # AutoMLSummary 结果对象
│   ├── driver.py           # AutoML 驱动程序
│   ├── tasks/              # 任务类
│   │   ├── classifier.py   # Classifier 类
│   │   ├── regressor.py    # Regressor 类
│   │   └── forecast.py     # Forecast 类
│   ├── engines/            # 训练引擎
│   │   ├── flaml_trainer.py # FLAML 训练器
│   │   └── trial_hook.py   # Trial 日志钩子
│   ├── notebook_generator/ # Notebook 生成器
│   └── utils/              # 工具函数
├── templates/              # Driver 模板
│   ├── classification_driver_template.py
│   └── forecast_driver_template.py
├── docs/                   # 文档
└── examples/               # 示例代码

📚 文档

使用指南

技术参考

⚠️ 注意事项

  • Python >= 3.9
  • Project ID 必填:通过 project_id 参数或 WEDATA_PROJECT_ID 环境变量配置
  • MLflow Tracking URI 必须正确配置

📄 License

MIT

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tencent_wedata_auto_ml_dev-0.2.76.tar.gz (89.0 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tencent_wedata_auto_ml_dev-0.2.76-py3-none-any.whl (106.3 kB view details)

Uploaded Python 3

File details

Details for the file tencent_wedata_auto_ml_dev-0.2.76.tar.gz.

File metadata

File hashes

Hashes for tencent_wedata_auto_ml_dev-0.2.76.tar.gz
Algorithm Hash digest
SHA256 941e4f3c5afee50dfc201696bd7e754cb6b2b605bff3c9b61282d93019eea768
MD5 29353a803db3d8bea1b67a05f47db73f
BLAKE2b-256 c5759116158dc501f16a9fb29b961b5b9350d6443305fbcb2edbf2e4370ab457

See more details on using hashes here.

File details

Details for the file tencent_wedata_auto_ml_dev-0.2.76-py3-none-any.whl.

File metadata

File hashes

Hashes for tencent_wedata_auto_ml_dev-0.2.76-py3-none-any.whl
Algorithm Hash digest
SHA256 f35b3fcb1b4292e61863dde267569ff91563764e427891b5c612362e812547e1
MD5 f877d5ef1d3a679bb6371f8a1f1c92ac
BLAKE2b-256 3e178016df660eca2cda251c75fe9ba4859609888a57c206b5e9267c184f0445

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page