Some Rank/Multi-task model implemented by Pytorch
Project description
Rec PanGu
1.开源定位
- 使用pytorch对经典的rank/多任务模型进行实现,并且对外提供统一调用的API接口,极大的降低了使用Rank/多任务模型的时间成本
- 该项目使用了pytorch来实现我们的各种模型,以便于初学推荐系统的人可以更好的理解算法的核心思想
- 由于已经有了很多类似的优秀的开源,我们这里对那些十分通用的模块参考了已有的开源,十分感谢这些开源贡献者的贡献
2.Rank模型
这里目前支持以下Rank模型
WDL
DeepFM
NFM
FiBiNet
AFM
AFN
AOANet
AutoInt
CCPM
LR
FM
xDeepFM
3.多任务模型
目前支持以下多任务模型
AITM
ShareBottom
ESSM
MMOE
OMOE
MLMMOE
4.Demo
我们的Rank和多任务模型所对外暴露的接口十分相似,我们下面会分别给出Rank和多任务模型的demo
4.1 Rank Demo
#声明数据schema
import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.ranking import WDL, DeepFM, NFM, FiBiNet, AFM, AFN, AOANet, AutoInt, CCPM, LR, FM, xDeepFM
from rec_pangu.trainer import RankTraniner
import pandas as pd
if __name__=='__main__':
df = pd.read_csv('sample_data/ranking_sample_data.csv')
print(df.head())
#声明数据schema
schema={
"sparse_cols":['user_id','item_id','item_type','dayofweek','is_workday','city','county',
'town','village','lbs_city','lbs_district','hardware_platform','hardware_ischarging',
'os_type','network_type','position'],
"dense_cols" : ['item_expo_1d','item_expo_7d','item_expo_14d','item_expo_30d','item_clk_1d',
'item_clk_7d','item_clk_14d','item_clk_30d','use_duration'],
"label_col":'click',
}
#准备数据,这里只选择了100条数据,所以没有切分数据集
train_df = df
valid_df = df
test_df = df
#声明使用的device
device = torch.device('cpu')
#获取dataloader
train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema)
#声明模型,排序模型目前支持:WDL, DeepFM, NFM, FiBiNet, AFM, AFN, AOANet, AutoInt, CCPM, LR, FM, xDeepFM
model = xDeepFM(enc_dict=enc_dict)
#声明Trainer
trainer = RankTraniner(num_task=1)
#训练模型
trainer.fit(model, train_loader, valid_loader, epoch=5, lr=1e-3, device=device)
#保存模型权重
trainer.save_model(model, './model_ckpt')
#模型验证
test_metric = trainer.evaluate_model(model, test_loader, device=device)
print('Test metric:{}'.format(test_metric))
这里的schema主要记录数据集的信息,主要包括离散特征的列表('sparse_cols'),连续特征列表('dense_cols'),标签列('label_cols')
4.2 多任务模型Demo
import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.multi_task import AITM,ShareBottom,ESSM,MMOE,OMOE,MLMMOE
from rec_pangu.trainer import RankTraniner
import pandas as pd
if __name__=='__main__':
df = pd.read_csv('sample_data/multi_task_sample_data.csv')
print(df.head())
#声明数据schema
schema={
"sparse_cols":['user_id','item_id','item_type','dayofweek','is_workday','city','county',
'town','village','lbs_city','lbs_district','hardware_platform','hardware_ischarging',
'os_type','network_type','position'],
"dense_cols" : ['item_expo_1d','item_expo_7d','item_expo_14d','item_expo_30d','item_clk_1d',
'item_clk_7d','item_clk_14d','item_clk_30d','use_duration'],
"label_col":['click','scroll'],
}
#准备数据,这里只选择了100条数据,所以没有切分数据集
train_df = df
valid_df = df
test_df = df
#声明使用的device
device = torch.device('cpu')
#获取dataloader
train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema)
#声明模型,多任务模型目前支持:AITM,ShareBottom,ESSM,MMOE,OMOE,MLMMOE
model = AITM(enc_dict=enc_dict)
#声明Trainer
trainer = RankTraniner(num_task=2)
#训练模型
trainer.fit(model, train_loader, valid_loader, epoch=5, lr=1e-3, device=device)
#保存模型权重
trainer.save_model(model, './model_ckpt')
#模型验证
test_metric = trainer.evaluate_model(model, test_loader, device=device)
print('Test metric:{}'.format(test_metric))
这里的schema主要记录数据集的信息,主要包括离散特征的列表('sparse_cols'),连续特征列表('dense_cols'),标签列('label_cols'),注意在多任务模型中,标签列的值为每一个任务的列名
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
rec_pangu-0.0.1.tar.gz
(3.8 kB
view hashes)
Built Distribution
Close
Hashes for rec_pangu-0.0.1-py2.py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 9f8c9a86c8ae50a26d717edde7b24236daad39e816be3f3cc1143ce1fc74e747 |
|
MD5 | af85aee39c5b787aeb8cd985ab6797f8 |
|
BLAKE2b-256 | 44a4803695cc89561d776b393049f3476b56735bf29293d87ec42fd84b2cdc14 |