TabAML is a framework for tabular data classification models
Project description
TabAML - 表格数据机器学习框架
TabAML (Tabular Machine Learning) 是一个基于PyTorch的表格数据分类框架,支持二分类和多分类任务。
特性
- 🔥 二分类模型: 支持二分类任务,包含完整的训练和评估流程
- 🌟 多分类模型: 支持多分类任务,自动检测类别数量
- 📊 丰富的评估指标: 支持准确率、F1分数(macro/micro/weighted)等多种评估指标
- ⚡ 早停机制: 防止过拟合,基于验证集性能自动停止训练
- 🛠️ 可配置性: 通过超参数字典灵活配置模型结构和训练参数
- 📈 自动化训练: 包含数据预处理、模型训练、验证和评估的完整流程
安装
确保已安装以下依赖:
pip install torch pandas numpy
快速开始
交互式运行
python main.py
然后选择要运行的模型类型:
- 1: 二分类模型
- 2: 多分类模型
二分类模型
from binary_cls import BinaryClsModel
import pandas as pd
# 加载数据
df_train = pd.read_csv("sample/binary_cls_train.csv")
df_val = pd.read_csv("sample/binary_cls_val.csv")
X_train = df_train.drop(columns=['label'])
y_train = df_train['label']
X_val = df_val.drop(columns=['label'])
y_val = df_val['label']
# 创建模型
model = BinaryClsModel(
hyperparameters={
'learning_rate': 0.01,
'batch_size': 32,
'num_layers': 3,
'hidden_size': 64,
'dropout_prob': 0.2,
'patience': 10,
'use_early_stopping': True
}
)
# 训练模型
model.fit(X=X_train, y=y_train, X_val=X_val, y_val=y_val)
# 评估模型
accuracy = model.score(X_val, y_val, metric='accuracy')
f1 = model.score(X_val, y_val, metric='f1')
print(f"验证集准确率: {accuracy:.4f}")
print(f"验证集F1分数: {f1:.4f}")
多分类模型
from multi_cls import MultiClassifierModel
import pandas as pd
# 加载数据
df_train = pd.read_csv("sample/multi_cls_train.csv")
df_val = pd.read_csv("sample/multi_cls_val.csv")
X_train = df_train.drop(columns=['label'])
y_train = df_train['label']
X_val = df_val.drop(columns=['label'])
y_val = df_val['label']
# 创建模型
model = MultiClassifierModel(
hyperparameters={
'learning_rate': 0.01,
'batch_size': 32,
'num_layers': 3,
'hidden_size': 64,
'dropout_prob': 0.2,
'patience': 10,
'use_early_stopping': True,
'f1_average': 'macro' # F1计算方式: 'macro', 'micro', 'weighted'
}
)
# 训练模型
model.fit(X=X_train, y=y_train, X_val=X_val, y_val=y_val)
# 评估模型
accuracy = model.score(X_val, y_val, metric='accuracy')
f1_macro = model.score(X_val, y_val, metric='f1_macro')
f1_micro = model.score(X_val, y_val, metric='f1_micro')
f1_weighted = model.score(X_val, y_val, metric='f1_weighted')
print(f"验证集准确率: {accuracy:.4f}")
print(f"验证集F1 (macro): {f1_macro:.4f}")
print(f"验证集F1 (micro): {f1_micro:.4f}")
print(f"验证集F1 (weighted): {f1_weighted:.4f}")
模型结构
神经网络架构
- 输入层: 自动适应特征维度
- 隐藏层: 可配置的层数和神经元数量
- 激活函数: ReLU
- 正则化: Dropout
- 输出层:
- 二分类: 2个神经元 + CrossEntropyLoss
- 多分类: 自动检测类别数量 + CrossEntropyLoss
可配置超参数
| 参数 | 默认值 | 说明 |
|---|---|---|
num_epochs |
200 | 最大训练轮数 |
learning_rate |
0.01 | 学习率 |
batch_size |
32 | 批次大小 |
num_layers |
3 | 神经网络层数 |
hidden_size |
128 | 隐藏层神经元数量 |
dropout_prob |
0.1 | Dropout概率 |
patience |
3 | 早停耐心参数 |
use_early_stopping |
True | 是否使用早停 |
f1_average |
'macro' | F1分数计算方式(仅多分类) |
评估指标
二分类
- 准确率 (Accuracy): 正确预测的样本比例
- F1分数: 精确率和召回率的调和平均数
多分类
- 准确率 (Accuracy): 正确预测的样本比例
- F1分数 (macro): 每个类别F1分数的算术平均
- F1分数 (micro): 全局计算的F1分数
- F1分数 (weighted): 按类别支持数加权的F1分数
数据格式
输入数据要求
- 特征: 数值型特征,支持pandas DataFrame或numpy数组
- 标签:
- 二分类: 0和1
- 多分类: 0, 1, 2, ..., n-1 (连续整数)
示例数据
项目包含示例数据:
sample/binary_cls_train.csv/sample/binary_cls_val.csv: 二分类数据sample/multi_cls_train.csv/sample/multi_cls_val.csv: 四分类数据
文件结构
tabaml/
├── main.py # 主入口文件
├── binary_cls.py # 二分类模型实现
├── multi_cls.py # 多分类模型实现
├── metrics.py # 评估指标实现
├── sample/ # 示例数据
│ ├── binary_cls_train.csv
│ ├── binary_cls_val.csv
│ ├── multi_cls_train.csv
│ └── multi_cls_val.csv
└── README.md
注意事项
- GPU支持: 代码自动检测CUDA可用性,优先使用GPU训练
- 数据预处理: 确保数据已经完成标准化、缺失值处理等预处理步骤
- 内存使用: 大数据集建议调整batch_size参数
- 类别平衡: 不平衡数据集可能需要额外的处理策略
许可证
本项目使用 MIT 许可证。
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
tabaml-0.2.0.tar.gz
(18.6 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
tabaml-0.2.0-py3-none-any.whl
(18.8 kB
view details)
File details
Details for the file tabaml-0.2.0.tar.gz.
File metadata
- Download URL: tabaml-0.2.0.tar.gz
- Upload date:
- Size: 18.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.23
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
54c599b24fefc3a093449e332129e738af85ed8f0b1e88d8a48c54d32799519c
|
|
| MD5 |
1fe207954209c0c588cab787f374fcf6
|
|
| BLAKE2b-256 |
0eb6cbf679fba9317c13ec5f10859e3a38b15c872692936c1f5e25c495ce12b2
|
File details
Details for the file tabaml-0.2.0-py3-none-any.whl.
File metadata
- Download URL: tabaml-0.2.0-py3-none-any.whl
- Upload date:
- Size: 18.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.23
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
33bcc486b8ee4bfc09c0907456222182378a5e9435b47e5328e2d08dec887c7c
|
|
| MD5 |
c2ae056dfb832bcb7c67758bdca1167f
|
|
| BLAKE2b-256 |
73e0ff36dedeb927cb4d5f95c67674ac8bd0cc30ee51f05f2c6cf5d62c4a9b58
|