Ranking and multi-task models implemented in PyTorch

Rec PanGu

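1. Project goals

  • Implements classic ranking and multi-task models in PyTorch and exposes them through a unified API, which greatly reduces the time cost of using these models.
  • The models are written in PyTorch so that newcomers to recommender systems can more easily understand the core idea of each algorithm.
  • Many excellent open-source projects of this kind already exist, so the most general-purpose modules here are adapted from them. Many thanks to those contributors.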

2. Installation

# Latest version (from source)
git clone https://github.com/HaSai666/rec_pangu.git
cd rec_pangu
pip install -e . --verbose

# Stable release
pip install rec_pangu --upgrade
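
To confirm which release ended up in the environment after either install path, a small check along the following lines may help. This is only a sketch: it relies on the standard-library `importlib.metadata` module (Python 3.8+), and the helper name `installed_version` is made up for this example rather than being part of rec_pangu.

```python
from importlib import metadata  # standard library, Python 3.8+


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("rec_pangu")
    if version is None:
        print("rec_pangu is not installed in this environment")
    else:
        print("rec_pangu version:", version)
```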

3. Ranking models

| Model | Paper | Year |
| --- | --- | --- |
| WDL | Wide & Deep Learning for Recommender Systems | 2016 |
| DeepFM | DeepFM: A Factorization-Machine based Neural Network for CTR Prediction | 2017 |
| NFM | Neural Factorization Machines for Sparse Predictive Analytics | 2017 |
| FiBiNet | FiBiNET: Combining Feature Importance and Bilinear Feature Interaction for Click-Through Rate Prediction | 2019 |
| AFM | Attentional Factorization Machines | 2017 |
| AutoInt | AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks | 2018 |
| CCPM | A Convolutional Click Prediction Model | 2015 |
| LR | / | 2019 |
| FM | / | 2019 |
| xDeepFM | xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems | 2018 |
| DCN | Deep & Cross Network for Ad Click Predictions | 2017 |
| MaskNet | MaskNet: Introducing Feature-Wise Multiplication to CTR Ranking Models by Instance-Guided Mask | 2021 |

4. Multi-task models

| Model | Paper | Year |
| --- | --- | --- |
| MMOE | Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts | 2018 |
| ShareBottom | Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts | 2018 |
| ESSM | Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate | 2018 |
| OMOE | Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts | 2018 |
| MLMMOE | / | / |
| AITM | Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising | 2021 |

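5. Sequential recall models

The following types of sequential recall models are currently supported:

  • Classic sequential recall models
  • Graph-based sequential recall models
  • Multi-interest sequential recall models

| Model | Type | Paper | Year |
| --- | --- | --- | --- |
| YotubeDNN | Classic sequential recall | Deep Neural Networks for YouTube Recommendations | 2016 |
| Gru4Rec | Classic sequential recall | Session-based Recommendations with Recurrent Neural Networks | 2015 |
| Narm | Classic sequential recall | Neural Attentive Session-based Recommendation | 2017 |
| NextItNet | Classic sequential recall | A Simple Convolutional Generative Network for Next Item Recommendation | 2019 |
| ComirecSA | Multi-interest recall | Controllable Multi-Interest Framework for Recommendation | 2020 |
| ComirecDR | Multi-interest recall | Controllable Multi-Interest Framework for Recommendation | 2020 |
| Mind | Multi-interest recall | Multi-Interest Network with Dynamic Routing for Recommendation at Tmall | 2019 |
| Re4 | Multi-interest recall | Re4: Learning to Re-contrast, Re-attend, Re-construct for Multi-interest Recommendation | 2022 |
| CMI | Multi-interest recall | Improving Micro-video Recommendation via Contrastive Multiple Interests | 2022 |
| SRGNN | Graph-based sequential recall | Session-based Recommendation with Graph Neural Networks | 2019 |
| GC-SAN | Graph-based sequential recall | Graph Contextualized Self-Attention Network for Session-based Recommendation | 2019 |
| NISER | Graph-based sequential recall | NISER: Normalized Item and Session Representations to Handle Popularity Bias | 2019 |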

6. Graph collaborative filtering models

TODO

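7. Demo

The ranking and multi-task models expose very similar interfaces, and training metrics can also be monitored in real time with wandb. The demos below cover ranking, multi-task, and sequential recall models in turn.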

7.1 Ranking demo

import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.ranking import WDL, DeepFM, NFM, FiBiNet, AFM, AFN, AOANet, AutoInt, CCPM, LR, FM, xDeepFM
from rec_pangu.trainer import RankTrainer
import pandas as pd

if __name__ == '__main__':
    df = pd.read_csv('sample_data/ranking_sample_data.csv')
    print(df.head())
    # Declare the data schema
    schema = {
        "sparse_cols": ['user_id', 'item_id', 'item_type', 'dayofweek', 'is_workday', 'city', 'county',
                        'town', 'village', 'lbs_city', 'lbs_district', 'hardware_platform', 'hardware_ischarging',
                        'os_type', 'network_type', 'position'],
        "dense_cols": ['item_expo_1d', 'item_expo_7d', 'item_expo_14d', 'item_expo_30d', 'item_clk_1d',
                       'item_clk_7d', 'item_clk_14d', 'item_clk_30d', 'use_duration'],
        "label_col": 'click',
    }
    # Prepare the data. The sample has only 100 rows, so the same frame is
    # reused for train/valid/test instead of being split.
    train_df = df
    valid_df = df
    test_df = df

    # Choose the device
    device = torch.device('cpu')
    # Build the dataloaders
    train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema)
    # Build the model. Ranking models currently supported:
    # WDL, DeepFM, NFM, FiBiNet, AFM, AFN, AOANet, AutoInt, CCPM, LR, FM, xDeepFM
    model = xDeepFM(enc_dict=enc_dict)
    # Create the trainer
    trainer = RankTrainer(num_task=1)
    # Train the model
    trainer.fit(model, train_loader, valid_loader, epoch=5, lr=1e-3, device=device)
    # Save the model weights
    trainer.save_model(model, './model_ckpt')
    # Evaluate on the test set
    test_metric = trainer.evaluate_model(model, test_loader, device=device)
    print('Test metric:{}'.format(test_metric))
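
The demo reuses one 100-row frame for training, validation, and testing. With a larger dataset the three frames would normally be disjoint. The sketch below shows one simple way to produce such a split using only the standard library. The helper `split_indices` and its default fractions are assumptions made for this example and are not part of rec_pangu.

```python
import random


def split_indices(n_rows, valid_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle row positions 0..n_rows-1 and cut them into train/valid/test lists."""
    if valid_frac < 0 or test_frac < 0 or valid_frac + test_frac >= 1:
        raise ValueError("fractions must be non-negative and sum to less than 1")
    positions = list(range(n_rows))
    # A local RNG keeps the split reproducible without touching global random state.
    random.Random(seed).shuffle(positions)
    n_valid = int(n_rows * valid_frac)
    n_test = int(n_rows * test_frac)
    valid_idx = positions[:n_valid]
    test_idx = positions[n_valid:n_valid + n_test]
    train_idx = positions[n_valid + n_test:]
    return train_idx, valid_idx, test_idx


if __name__ == "__main__":
    train_idx, valid_idx, test_idx = split_indices(1000)
    print(len(train_idx), len(valid_idx), len(test_idx))
    # With a pandas DataFrame `df`, the positions could then be applied as:
    #   train_df = df.iloc[train_idx]
    #   valid_df = df.iloc[valid_idx]
    #   test_df = df.iloc[test_idx]
```

A random split is only one option. For click logs with a time column, splitting by time is often a better match for how the model is used in production.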

7.2 Multi-task demo

import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.multi_task import AITM, ShareBottom, ESSM, MMOE, OMOE, MLMMOE
from rec_pangu.trainer import RankTrainer
import pandas as pd

if __name__ == '__main__':
    df = pd.read_csv('sample_data/multi_task_sample_data.csv')
    print(df.head())
    # Declare the data schema; label_col lists one label column per task
    schema = {
        "sparse_cols": ['user_id', 'item_id', 'item_type', 'dayofweek', 'is_workday', 'city', 'county',
                        'town', 'village', 'lbs_city', 'lbs_district', 'hardware_platform', 'hardware_ischarging',
                        'os_type', 'network_type', 'position'],
        "dense_cols": ['item_expo_1d', 'item_expo_7d', 'item_expo_14d', 'item_expo_30d', 'item_clk_1d',
                       'item_clk_7d', 'item_clk_14d', 'item_clk_30d', 'use_duration'],
        "label_col": ['click', 'scroll'],
    }
    # Prepare the data. The sample has only 100 rows, so the same frame is
    # reused for train/valid/test instead of being split.
    train_df = df
    valid_df = df
    test_df = df

    # Choose the device
    device = torch.device('cpu')
    # Build the dataloaders
    train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema)
    # Build the model. Multi-task models currently supported:
    # AITM, ShareBottom, ESSM, MMOE, OMOE, MLMMOE
    model = AITM(enc_dict=enc_dict)
    # Create the trainer; num_task matches the two label columns above
    trainer = RankTrainer(num_task=2)
    # Train the model
    trainer.fit(model, train_loader, valid_loader, epoch=5, lr=1e-3, device=device)
    # Save the model weights
    trainer.save_model(model, './model_ckpt')
    # Evaluate on the test set
    test_metric = trainer.evaluate_model(model, test_loader, device=device)
    print('Test metric:{}'.format(test_metric))
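
In both demos the schema must name columns that exist in the CSV, and in the multi-task demo `num_task` has to agree with the number of entries in `label_col`. A small pre-flight check along these lines can catch a mismatch before any dataloader is built. This is a hypothetical helper written for illustration. `check_schema` is not a rec_pangu function, and it only assumes the `sparse_cols` / `dense_cols` / `label_col` keys shown in the demos.

```python
def check_schema(columns, schema):
    """Return (missing_columns, num_task) for a schema dict shaped like the demos'."""
    labels = schema["label_col"]
    if isinstance(labels, str):  # the single-task demo passes one column name
        labels = [labels]
    wanted = (
        list(schema.get("sparse_cols", []))
        + list(schema.get("dense_cols", []))
        + list(labels)
    )
    present = set(columns)
    missing = [col for col in wanted if col not in present]
    return missing, len(labels)


if __name__ == "__main__":
    demo_schema = {
        "sparse_cols": ["user_id", "item_id"],
        "dense_cols": ["use_duration"],
        "label_col": ["click", "scroll"],
    }
    # Stand-in for `df.columns`; the 'scroll' label is deliberately left out.
    demo_columns = ["user_id", "item_id", "use_duration", "click"]
    missing, num_task = check_schema(demo_columns, demo_schema)
    print("missing columns:", missing)
    print("num_task to pass to the trainer:", num_task)
```

With a real frame, `check_schema(df.columns, schema)` could be called right after `pd.read_csv`, and the returned `num_task` passed to `RankTrainer(num_task=...)`.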

7.3 Sequential recall demo

import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.sequence import ComirecSA, ComirecDR, MIND, CMI, Re4, NARM, YotubeDNN, SRGNN
from rec_pangu.trainer import SequenceTrainer
from rec_pangu.utils import set_device
import pandas as pd

if __name__ == '__main__':
    # Declare the data schema
    schema = {
        'user_col': 'user_id',
        'item_col': 'item_id',
        'cate_cols': ['genre'],
        'max_length': 20,
        'time_col': 'timestamp',
        'task_type': 'sequence'
    }
    # Model configuration
    config = {
        'embedding_dim': 64,
        'lr': 0.001,
        'K': 1,
        'device': -1,
    }
    config['device'] = set_device(config['device'])
    config.update(schema)

    # Sample data
    train_df = pd.read_csv('./sample_data/sample_train.csv')
    valid_df = pd.read_csv('./sample_data/sample_valid.csv')
    test_df = pd.read_csv('./sample_data/sample_test.csv')

    # Choose the device
    device = torch.device('cpu')
    # Build the dataloaders
    train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema, batch_size=50)
    # Build the model. Sequential recall models currently supported:
    # ComirecSA, ComirecDR, MIND, CMI, Re4, NARM, YotubeDNN, SRGNN
    model = ComirecSA(enc_dict=enc_dict, config=config)
    # Create the trainer
    trainer = SequenceTrainer(model_ckpt_dir='./model_ckpt')
    # Train the model with early stopping on validation recall@20
    trainer.fit(model, train_loader, valid_loader, epoch=500, lr=1e-3, device=device, log_rounds=10,
                use_earlystoping=True, max_patience=5, monitor_metric='recall@20')
    # Save the model weights and enc_dict
    trainer.save_all(model, enc_dict, './model_ckpt')
    # Evaluate on the test set
    test_metric = trainer.evaluate_model(model, test_loader, device=device)
    print('Test metric:{}'.format(test_metric))
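
The `fit` call above monitors `recall@20` and stops early after `max_patience=5`. As a rough illustration of what those two settings usually mean, the sketch below computes recall@k for one user and tracks patience across evaluations. It is written from the common textbook definitions, not taken from rec_pangu, so the library's own metric and stopping rule may differ in detail. `recall_at_k` and `PatienceStopper` are names invented for this example.

```python
def recall_at_k(ranked_items, relevant_items, k=20):
    """Fraction of the relevant items that appear in the top-k of a ranked list."""
    relevant = set(relevant_items)
    if not relevant:
        return 0.0
    hits = len(relevant.intersection(ranked_items[:k]))
    return hits / len(relevant)


class PatienceStopper:
    """Signal a stop after `max_patience` consecutive evaluations without improvement."""

    def __init__(self, max_patience=5):
        self.max_patience = max_patience
        self.best = float("-inf")
        self.bad_rounds = 0

    def update(self, metric):
        """Record one validation metric; return True when training should stop."""
        if metric > self.best:
            self.best = metric
            self.bad_rounds = 0
        else:
            self.bad_rounds += 1
        return self.bad_rounds >= self.max_patience


if __name__ == "__main__":
    # One user: the model ranked items [3, 1, 2]; the held-out items are {1, 9}.
    print("recall@2:", recall_at_k([3, 1, 2], [1, 9], k=2))

    # Made-up validation scores, only to show when the stopper fires.
    stopper = PatienceStopper(max_patience=3)
    for epoch, score in enumerate([0.10, 0.12, 0.12, 0.11, 0.12], start=1):
        if stopper.update(score):
            print("stop after epoch", epoch, "best score", stopper.best)
            break
```

Here a tie with the best score counts as no improvement. Whether ties reset the patience counter is a design choice, and the trainer in this project may handle it differently.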

7.4 Graph collaborative filtering demo

TODO
