Skip to main content

AutoML SDK for Tencent Cloud WeData using FLAML with MLflow integration.

Project description

WeData AutoML

腾讯云 WeData 平台的 AutoML SDK,基于 FLAML 构建,集成 MLflow 进行实验追踪和模型注册。

✨ 功能特性

  • 多任务支持:分类(Classification)、回归(Regression)、时序预测(Forecast)
  • FLAML 驱动:高效的 AutoML 超参数搜索,支持 LightGBM、XGBoost、RandomForest 等估计器
  • MLflow 集成:自动实验追踪、模型日志记录、模型注册
  • Spark 支持:支持 Spark DataFrame 输入,可配合 DLC 使用
  • 特征工程集成:与 WeData 特征工程 SDK 无缝对接
  • Notebook 生成:自动生成可复现的 Jupyter Notebook(分类/回归任务)
  • 并发训练:支持多 Trial 并发执行

📦 安装

# 基础安装
pip install tencent-wedata-auto-ml

# 可选依赖
pip install "tencent-wedata-auto-ml[xgboost]"     # XGBoost 支持
pip install "tencent-wedata-auto-ml[lightgbm]"   # LightGBM 支持
pip install "tencent-wedata-auto-ml[full]"       # 完整安装(推荐)

🚀 快速开始

便捷函数 API

from wedata_automl import classify, regress, forecast

# 分类任务
summary = classify(
    dataset=spark.table("demo.wine_quality"),
    target_col="quality",
    timeout_minutes=10,
    max_trials=100,
    metric="accuracy",
    project_id="your_project_id",
    experiment_name="wine_classification",
    register_model=True,
    model_name="wine_model"
)

# 回归任务
summary = regress(
    dataset=df,
    target_col="price",
    timeout_minutes=10,
    metric="r2",
    project_id="your_project_id"
)

# 时序预测任务
summary = forecast(
    dataset=spark.table("demo.sales_data"),
    target_col="sales",
    time_col="date",
    horizon=30,
    frequency="D",
    timeout_minutes=60,
    project_id="your_project_id"
)

任务类 API

from wedata_automl import Classifier, Regressor, Forecast

# 使用 Classifier 类
classifier = Classifier()
summary = classifier.fit(
    dataset=df,
    target_col="label",
    timeout_minutes=10,
    project_id="your_project_id"
)

# 使用 Regressor 类
regressor = Regressor()
summary = regressor.fit(
    dataset=df,
    target_col="target",
    timeout_minutes=10,
    project_id="your_project_id"
)

查看结果

print(summary)
# AutoMLSummary:
#   Experiment ID: 42
#   Run ID: abc123...
#   Best Trial Run ID: def456...
#   Model URI: runs:/abc123.../model
#   Best Estimator: lgbm
#   Metrics:
#     accuracy: 0.9500
#     f1: 0.9400

# 生成可复现 Notebook(仅分类/回归任务)
summary.generate_notebook("best_model.ipynb")

# 保存 Notebook 到 WeData 平台
summary.save_notebook_to_wedata()

📋 主要参数

参数 说明 默认值
dataset 数据集(Pandas/Spark DataFrame 或表名) 必填
target_col 目标列名 必填
project_id WeData 项目 ID 必填
timeout_minutes 超时时间(分钟) 5
max_trials 最大试验次数 100
metric 评估指标 auto
estimator_list 估计器列表 None(使用全部)
register_model 是否注册模型 True
model_name 注册模型名称 None
experiment_name MLflow 实验名称 None
custom_hp 自定义超参数搜索空间 None

评估指标

分类任务accuracy, f1, log_loss, roc_auc, precision, recall

回归任务r2, mse, rmse, mae, mape

时序预测smape, mse, rmse, mae, mdape

估计器列表

分类/回归lgbm, xgboost, rf, extra_tree, lrl1(仅分类)

时序预测prophet, arima, sarimax

⚙️ 环境配置

# 必需:项目 ID
export WEDATA_PROJECT_ID="your_project_id"

# 必需:MLflow Tracking URI
export MLFLOW_TRACKING_URI="http://your-mlflow-server:5000"

# 可选:腾讯云密钥(用于保存 Notebook 到 WeData)
export TENCENTCLOUD_SECRET_ID="your_secret_id"
export TENCENTCLOUD_SECRET_KEY="your_secret_key"

📁 项目结构

wedata-automl/
├── src/wedata_automl/
│   ├── api.py              # 便捷函数 (classify, regress, forecast)
│   ├── summary.py          # AutoMLSummary 结果对象
│   ├── driver.py           # AutoML 驱动程序
│   ├── tasks/              # 任务类
│   │   ├── classifier.py   # Classifier 类
│   │   ├── regressor.py    # Regressor 类
│   │   └── forecast.py     # Forecast 类
│   ├── engines/            # 训练引擎
│   │   ├── flaml_trainer.py # FLAML 训练器
│   │   └── trial_hook.py   # Trial 日志钩子
│   ├── notebook_generator/ # Notebook 生成器
│   └── utils/              # 工具函数
├── templates/              # Driver 模板
│   ├── classification_driver_template.py
│   └── forecast_driver_template.py
├── docs/                   # 文档
└── examples/               # 示例代码

📚 文档

使用指南

技术参考

⚠️ 注意事项

  • Python >= 3.9
  • Project ID 必填:通过 project_id 参数或 WEDATA_PROJECT_ID 环境变量配置
  • MLflow Tracking URI 必须正确配置

📄 License

MIT

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tencent_wedata_auto_ml-0.2.65.tar.gz (77.0 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tencent_wedata_auto_ml-0.2.65-py3-none-any.whl (93.3 kB view details)

Uploaded Python 3

File details

Details for the file tencent_wedata_auto_ml-0.2.65.tar.gz.

File metadata

  • Download URL: tencent_wedata_auto_ml-0.2.65.tar.gz
  • Upload date:
  • Size: 77.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for tencent_wedata_auto_ml-0.2.65.tar.gz
Algorithm Hash digest
SHA256 b98abc0f83d150777ff9599739867d0c9d0aa0f93a4b8da39bc18fe855fc3e03
MD5 492f8110403d8fc359c9b584a2318385
BLAKE2b-256 4681043cd4b1273b4e5a2b0903a3e97a00abf892c4906352ecc255cbfe9f68b6

See more details on using hashes here.

File details

Details for the file tencent_wedata_auto_ml-0.2.65-py3-none-any.whl.

File metadata

File hashes

Hashes for tencent_wedata_auto_ml-0.2.65-py3-none-any.whl
Algorithm Hash digest
SHA256 6ec0df696befbc51419fb48841b1565b2d83462b02b8887522dc5e361a97ae49
MD5 2ef45ae9532e5aed32b172adfffee5a9
BLAKE2b-256 82f7d0242a1640d34085eb0bb8953e059a1d41ec98fbb24657c3020995a459e4

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page