AutoML SDK for Tencent Cloud WeData using FLAML with MLflow integration.

WeData AutoML

An AutoML SDK for the Tencent Cloud WeData platform, built on FLAML, with MLflow integration for experiment tracking and model registration.

✨ Features

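  • Multi-task support: classification, regression, and time-series forecasting
  • FLAML-driven: efficient AutoML hyperparameter search with estimators such as LightGBM, XGBoost, and RandomForest
  • MLflow integration: automatic experiment tracking, model logging, and model registration
  • Spark support: accepts Spark DataFrame input and can be used with DLC
  • Feature engineering integration: works seamlessly with the WeData feature engineering SDK
  • Notebook generation: automatically generates reproducible Jupyter Notebooks (classification and regression tasks)
  • Concurrent training: runs multiple trials concurrently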

📦 Installation

# Basic installation
pip install tencent-wedata3-auto-ml
pip install mlflow==3.1.0

🚀 Quick start

Convenience function API

from wedata_automl import classify, regress, forecast

# Classification task
summary = classify(
    dataset=spark.table("demo.wine_quality"),
    target_col="quality",
    timeout_minutes=10,
    max_trials=100,
    metric="accuracy",
    workspace_id="your_workspace_id",
    experiment_name="wine_classification",
    register_model=True,
    model_name="wine_model"
)

# Regression task
summary = regress(
    dataset=df,
    target_col="price",
    timeout_minutes=10,
    metric="r2",
    workspace_id="your_workspace_id"
)

# Time-series forecasting task
summary = forecast(
    dataset=spark.table("demo.sales_data"),
    target_col="sales",
    time_col="date",
    horizon=30,
    frequency="D",
    timeout_minutes=60,
    workspace_id="your_workspace_id"
)

Task class API

from wedata_automl import Classifier, Regressor, Forecast

# Using the Classifier class
classifier = Classifier()
summary = classifier.fit(
    dataset=df,
    target_col="label",
    timeout_minutes=10,
    workspace_id="your_workspace_id"
)

# Using the Regressor class
regressor = Regressor()
summary = regressor.fit(
    dataset=df,
    target_col="target",
    timeout_minutes=10,
    workspace_id="your_workspace_id"
)

Viewing results

print(summary)
# AutoMLSummary:
#   Experiment ID: 42
#   Run ID: abc123...
#   Best Trial Run ID: def456...
#   Model URI: runs:/abc123.../model
#   Best Estimator: lgbm
#   Metrics:
#     accuracy: 0.9500
#     f1: 0.9400

# Generate a reproducible Notebook (classification and regression tasks only)
summary.generate_notebook("best_model.ipynb")

# Save the Notebook to the WeData platform
summary.save_notebook_to_wedata()

📋 Main parameters

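| Parameter | Description | Default |
|---|---|---|
| dataset | Dataset (Pandas/Spark DataFrame or table name) | Required |
| target_col | Target column name | Required |
| workspace_id | WeData workspace ID | Required |
| timeout_minutes | Timeout (minutes) | 5 |
| max_trials | Maximum number of trials | 100 |
| metric | Evaluation metric | auto |
| estimator_list | List of estimators | None (use all) |
| register_model | Whether to register the model | True |
| model_name | Registered model name | None |
| experiment_name | MLflow experiment name | None |
| custom_hp | Custom hyperparameter search space | None |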
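
To make the defaults in the table concrete, here is a minimal sketch that mirrors them in a plain dataclass. `AutoMLParams` and `to_kwargs` are hypothetical names used only for illustration and are not part of the SDK. The types of `estimator_list` and `custom_hp` are not documented above, so they are left as `Any`.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class AutoMLParams:
    """Hypothetical container mirroring the parameter table; not an SDK class."""

    # Required parameters (no default in the table).
    dataset: Any
    target_col: str
    workspace_id: str
    # Optional parameters with the defaults listed in the table.
    timeout_minutes: int = 5
    max_trials: int = 100
    metric: str = "auto"
    estimator_list: Any = None   # None means "use all estimators"
    register_model: bool = True
    model_name: Optional[str] = None
    experiment_name: Optional[str] = None
    custom_hp: Any = None        # format not documented here; assumption

    def to_kwargs(self) -> dict:
        # Shallow conversion so a DataFrame in `dataset` is not copied.
        return {f.name: getattr(self, f.name) for f in fields(self)}


params = AutoMLParams(
    dataset="demo.wine_quality",
    target_col="quality",
    workspace_id="your_workspace_id",
    timeout_minutes=10,
)
kwargs = params.to_kwargs()
```

Assuming the convenience functions accept these keyword names, as the quick-start calls suggest, the result could be passed on as `classify(**kwargs)`.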

Evaluation metrics

Classification: accuracy, f1, log_loss, roc_auc, precision, recall

Regression: r2, mse, rmse, mae, mape

Time-series forecasting: smape, mse, rmse, mae, mdape

Estimators

Classification/regression: lgbm, xgboost, rf, extra_tree, lrl1 (classification only)

Time-series forecasting: prophet, arima, sarimax
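
Because a mistyped metric or estimator name may only surface after a job has started, a small pre-flight check can help. The sketch below encodes the metric and estimator lists above and reports any mismatch. `check_config` and the task keys `"classification"`, `"regression"`, and `"forecast"` are hypothetical names chosen for this example, and it assumes the lists above are exhaustive.

```python
# Names copied from the metric and estimator lists in this README.
SUPPORTED_METRICS = {
    "classification": {"accuracy", "f1", "log_loss", "roc_auc", "precision", "recall"},
    "regression": {"r2", "mse", "rmse", "mae", "mape"},
    "forecast": {"smape", "mse", "rmse", "mae", "mdape"},
}
SUPPORTED_ESTIMATORS = {
    "classification": {"lgbm", "xgboost", "rf", "extra_tree", "lrl1"},
    "regression": {"lgbm", "xgboost", "rf", "extra_tree"},  # lrl1 is classification only
    "forecast": {"prophet", "arima", "sarimax"},
}


def check_config(task, metric="auto", estimator_list=None):
    """Return a list of problems; an empty list means the names look valid."""
    if task not in SUPPORTED_METRICS:
        return [f"unknown task: {task!r}"]
    problems = []
    # "auto" is the documented default, so it is always accepted.
    if metric != "auto" and metric not in SUPPORTED_METRICS[task]:
        problems.append(f"metric {metric!r} is not listed for {task}")
    # None means "use all estimators", so there is nothing to check.
    for name in estimator_list or []:
        if name not in SUPPORTED_ESTIMATORS[task]:
            problems.append(f"estimator {name!r} is not listed for {task}")
    return problems
```

For example, `check_config("regression", estimator_list=["lrl1"])` returns one problem, since `lrl1` is listed for classification only.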

⚙️ Environment configuration

# Required: workspace (project) ID
export WEDATA_WORKSPACE_ID="your_workspace_id"

# Required: MLflow Tracking URI
export MLFLOW_TRACKING_URI="http://your-mlflow-server:5000"

# Optional: Tencent Cloud credentials (used to save Notebooks to WeData)
export TENCENTCLOUD_SECRET_ID="your_secret_id"
export TENCENTCLOUD_SECRET_KEY="your_secret_key"
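
The same settings can be verified from Python before a run starts. This is a minimal sketch using only the standard library. `missing_vars` and `ensure_environment` are hypothetical helpers, not SDK functions, and the required/optional split follows the comments above.

```python
import os

REQUIRED_VARS = ("WEDATA_WORKSPACE_ID", "MLFLOW_TRACKING_URI")
OPTIONAL_VARS = ("TENCENTCLOUD_SECRET_ID", "TENCENTCLOUD_SECRET_KEY")


def missing_vars(names, env=None):
    """Return the names that are unset or blank in the given environment."""
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name, "").strip()]


def ensure_environment(env=None):
    """Raise if a required variable is missing; return missing optional ones."""
    missing = missing_vars(REQUIRED_VARS, env)
    if missing:
        raise RuntimeError("missing required variables: " + ", ".join(missing))
    # Optional credentials only matter for saving Notebooks to WeData.
    return missing_vars(OPTIONAL_VARS, env)
```

Calling `ensure_environment()` at the top of a script fails fast on a missing required variable, and the returned list indicates whether saving Notebooks to WeData is likely to work.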

📁 Project structure

wedata-automl/
├── src/wedata_automl/
│   ├── api.py              # Convenience functions (classify, regress, forecast)
│   ├── summary.py          # AutoMLSummary result object
│   ├── driver.py           # AutoML driver
│   ├── tasks/              # Task classes
│   │   ├── classifier.py   # Classifier class
│   │   ├── regressor.py    # Regressor class
│   │   └── forecast.py     # Forecast class
│   ├── engines/            # Training engines
│   │   ├── flaml_trainer.py # FLAML trainer
│   │   └── trial_hook.py   # Trial logging hook
│   ├── notebook_generator/ # Notebook generator
│   └── utils/              # Utility functions
├── templates/              # Driver templates
│   ├── classification_driver_template.py
│   └── forecast_driver_template.py
├── docs/                   # Documentation
└── examples/               # Example code

📚 Documentation

User guide

Technical reference

⚠️ Notes

  • Python >= 3.9
  • Workspace ID is required: set it via the workspace_id parameter or the WEDATA_WORKSPACE_ID environment variable
  • The MLflow Tracking URI must be configured correctly
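
The first two notes can be sketched as a short check. `resolve_workspace_id` and `python_ok` are hypothetical helpers, and the precedence shown (explicit argument first, then the environment variable) is an assumption; the SDK's actual lookup order is not documented here.

```python
import os
import sys


def resolve_workspace_id(workspace_id=None, env=None):
    """Pick the workspace ID from the argument, else WEDATA_WORKSPACE_ID."""
    env = os.environ if env is None else env
    # Assumption: an explicit argument wins over the environment variable.
    value = workspace_id or env.get("WEDATA_WORKSPACE_ID")
    if not value:
        raise ValueError(
            "workspace ID not set: pass workspace_id or export WEDATA_WORKSPACE_ID"
        )
    return value


def python_ok(version_info=sys.version_info):
    """True when the interpreter meets the Python >= 3.9 requirement."""
    return tuple(version_info[:2]) >= (3, 9)
```

A script could call both before importing the SDK, so that a configuration problem is reported with a clear message and not from deep inside a training run.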

📄 License

MIT
