ezSASRec

easy SASRec: a lightweight library for training and serving SASRec (Self-Attentive Sequential Recommendation) models.
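Install from PyPI:

pip install ezSASRec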

Documentation

https://ezsasrec.netlify.app

References

Repositories

  1. kang205/SASRec
  2. nnkkmto/SASRec-tf2
  3. microsoft/recommenders

Papers

  1. Kang & McAuley, "Self-Attentive Sequential Recommendation" (ICDM 2018)
  2. Dallmann et al., "A Case Study on Sampling Strategies for Evaluating Neural Sequential Item Recommendation Models" (RecSys 2021)

QuickStart

Example data source: https://www.kaggle.com/datasets/rounakbanik/the-movies-dataset (this walkthrough uses its ratings.csv file)

import pandas as pd 
import pickle
from sasrec.util import filter_k_core, SASRecDataSet, load_model
from sasrec.model import SASREC
from sasrec.sampler import WarpSampler
import multiprocessing

Preprocessing

path = 'your path'  # directory where model checkpoints will be saved
df = pd.read_csv('ratings.csv')
df = df.rename({'userId':'userID','movieId':'itemID','timestamp':'time'},axis=1)\
       .sort_values(by=['userID','time'])\
       .drop(['rating','time'],axis=1)\
       .reset_index(drop=True)
df.head()

   userID  itemID
0       1    2762
1       1   54503
2       1  112552
3       1   96821
4       1    5577
# filter data
# after 7-core filtering, every user and item appears at least 7 times in filtered_df

filtered_df = filter_k_core(df, 7)
Original: 270896 users and 45115 items
Final: 243377 users and 24068 items
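For intuition, a k-core filter repeatedly drops users and items with fewer than k interactions until none remain. A minimal sketch of the idea (not necessarily the library's exact implementation):

def k_core(df, k):
    # each removal pass can push other users/items below the
    # threshold, so iterate until the frame is stable
    while True:
        user_ok = df['userID'].map(df['userID'].value_counts()) >= k
        item_ok = df['itemID'].map(df['itemID'].value_counts()) >= k
        keep = user_ok & item_ok
        if keep.all():
            return df
        df = df[keep]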
# make maps (encoders)
# IDs are encoded starting from 1, leaving 0 free for sequence padding

user_set = set(filtered_df['userID'].unique())
item_set = set(filtered_df['itemID'].unique())
user_map = {user: i + 1 for i, user in enumerate(user_set)}
item_map = {item: i + 1 for i, item in enumerate(item_set)}

maps = (user_map, item_map)
# Encode filtered_df with the maps

filtered_df["userID"] = filtered_df["userID"].map(user_map)
filtered_df["itemID"] = filtered_df["itemID"].map(item_map)
# save data and maps

# save sasrec data    
filtered_df.to_csv('sasrec_data.txt', sep="\t", header=False, index=False)

# save maps
with open('maps.pkl','wb') as f:
    pickle.dump(maps, f)
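The encoders can be restored later with the mirror of the dump above:

# reload the saved encoders
with open('maps.pkl', 'rb') as f:
    user_map, item_map = pickle.load(f)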

Load data and train model

# load data

data = SASRecDataSet('sasrec_data.txt')
data.split()  # train/val/test split:
              # the last interaction of each user is held out for testing,
              # the second-to-last for validation,
              # and the rest are used for training
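Concretely, for one user the leave-one-out split looks like this (hypothetical history; user_train is used below, and the user_valid/user_test attribute names follow the kang205-style convention, so they may differ slightly):

# history for user u: [i1, i2, i3, i4, i5]
# data.user_train[u] -> [i1, i2, i3]
# data.user_valid[u] -> [i4]
# data.user_test[u]  -> [i5]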
# make the model and a WarpSampler for batch training

max_len = 100
hidden_units = 128
batch_size = 2048

model = SASREC(
    item_num=data.itemnum,
    seq_max_len=max_len,
    num_blocks=2,
    embedding_dim=hidden_units,
    attention_dim=hidden_units,
    attention_num_heads=2,
    dropout_rate=0.2,
    conv_dims=[hidden_units, hidden_units],
    l2_reg=0.00001,
)

sampler = WarpSampler(
    data.user_train,
    data.usernum,
    data.itemnum,
    batch_size=batch_size,
    maxlen=max_len,
    n_workers=multiprocessing.cpu_count(),
)
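WarpSampler keeps background worker processes filling a queue with ready-made training batches. Assuming the kang205-style interface (next_batch is an assumption, not confirmed by this package's docs), one batch can be inspected like this:

# peek at one training batch (next_batch assumed from the kang205 API)
users, input_seqs, pos_targets, neg_samples = sampler.next_batch()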
# train model

model.train(
    data,
    sampler,
    num_epochs=3,
    batch_size=batch_size,
    lr=0.001,
    val_epoch=1,             # run evaluation every epoch
    val_target_user_n=1000,  # number of users sampled for evaluation
    target_item_n=-1,        # -1: score against the full item set
    auto_save=True,
    path=path,
    exp_name='exp_example',
)
epoch 1 / 3 -----------------------------

Evaluating...    

epoch: 1, test (NDCG@10: 0.04607630127474612, HR@10: 0.097)
best score model updated and saved


epoch 2 / 3 -----------------------------

Evaluating...    

epoch: 2, test (NDCG@10: 0.060855185638025944, HR@10: 0.118)
best score model updated and saved


epoch 3 / 3 -----------------------------

Evaluating...   

epoch: 3, test (NDCG@10: 0.07027207563856912, HR@10: 0.139)
best score model updated and saved
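For reference, HR@10 is the fraction of evaluated users whose held-out test item ranks in the top 10, and NDCG@10 additionally rewards placing it higher. With leave-one-out evaluation there is a single relevant item per user, so the standard per-user metrics reduce to:

import numpy as np

# rank: 0-based position of the held-out item in the ranked candidates
def hr_at_k(rank, k=10):
    return 1.0 if rank < k else 0.0

def ndcg_at_k(rank, k=10):
    # one relevant item: DCG = 1/log2(rank + 2), ideal DCG = 1
    return 1.0 / np.log2(rank + 2) if rank < k else 0.0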

Predict

# load trained model

model = load_model(path,'exp_example')

Get score

# get user-item scores

# invert the user encoder to recover raw user IDs
inv_user_map = {v: k for k, v in user_map.items()}

# sample target users
model.sample_val_users(data, 100)
encoded_users = model.val_users

# get scores
score = model.get_user_item_score(
    data,
    [inv_user_map[u] for u in encoded_users],  # user_list of raw (not encoded) userIDs
    [1, 2, 3],                                 # item_list of raw (not encoded) itemIDs
    user_map,
    item_map,
    batch_size=10,
)
100%|██████████| 10/10 [00:00<00:00, 29.67batch/s]
score.head()

   user_id         1         2         3
0     1525  5.596944  4.241653  3.804743
1     1756  4.535607  2.694459  0.858440
2     2408  5.883061  4.655960  4.691791
3     2462  5.084695  2.942075  2.773376
4     3341  5.532438  4.348150  4.073740
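The frame has one row per user and one column per requested item, so follow-up queries are plain pandas; for example, picking each user's best of the three items (assuming the layout shown above):

# highest-scoring requested item per user
best_item = score.set_index('user_id').idxmax(axis=1)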

Get recommendations

# get top-N recommendations

reco = model.recommend_item(
    data,
    user_map,
    [inv_user_map[u] for u in encoded_users],
    is_test=True,
    top_n=5,
)
100%|██████████| 100/100 [00:04<00:00, 21.10it/s]
# the returned dict maps each user to a list of top-N (item, score) pairs

reco
{1525: [(456, 6.0680223),
  (355, 6.033769),
  (379, 5.9833336),
  (591, 5.9718275),
  (776, 5.8978705)],
 1756: [(7088, 5.735977),
  (15544, 5.5946136),
  (5904, 5.500249),
  (355, 5.492655),
  (22149, 5.4117346)],
 2408: [(456, 5.976555),
  (328, 5.8824606),
  (588, 5.8614006),
  (264, 5.7114534),
  (299, 5.649914)],
 2462: [(259, 6.3445344),
  (591, 6.2664876),
  (295, 6.105361),
  (355, 6.0698805),
  (1201, 5.8477645)],
 3341: [(110, 5.510764),
  (1, 5.4927354),
  (259, 5.4851904),
  (161, 5.467624),
  (208, 5.2486935)], ...}
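For further processing, the dict flattens naturally into a long-format frame (a usage sketch based on the structure shown above):

# one row per (user, recommended item, score)
reco_df = pd.DataFrame(
    [(u, i, s) for u, items in reco.items() for i, s in items],
    columns=['userID', 'itemID', 'score'],
)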

