
Predicting product categories.

Project description

Deep Learning Model for Product Category Prediction

Product category prediction model built with PyTorch Lightning and Hugging Face Transformers, and trained using Amazon product data.

This library supports:

  • Predicting categories using the pretrained model.
  • Training from scratch, with a transformers model as the starting point.
  • Transfer learning from the pretrained model.

Pretrained model

The pretrained model is trained on the product categories and titles in the metadata of the Amazon product data. Each product can have multiple categories. We sampled 500K products (85% for training, 15% for validation) to train the model, which yielded ~1900 categories. We used pytorch-lightning to train a multilabel classification model, starting from the pretrained distilbert-base-cased model from huggingface/transformers. This library supports:

  1. directly using this pretrained model to predict the ~1900 categories from an input product title or description;
  2. using this pretrained model as a starting point to do transfer learning and train a category prediction model on your own categories, as long as you provide training data in the format described below.

You can also train a model from scratch without using this pretrained model, but instead with a transformers model as the starting point.
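Because each product can belong to several categories, a multilabel classifier typically scores every category independently (one sigmoid per category, trained with binary cross-entropy) rather than with a single softmax over all ~1900 categories. That is why the prediction example further down can report several categories above 0.9 at once. The pure-Python sketch below illustrates the idea only: it assumes the standard sigmoid plus binary cross-entropy setup (this page does not state the exact loss the library uses), and the logits and labels are made-up values, not outputs of the pretrained model.

```python
import math


def sigmoid(x: float) -> float:
    """Map a raw logit to an independent probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def multilabel_bce(logits: list[float], targets: list[int]) -> float:
    """Mean binary cross-entropy over categories.

    Each category is its own yes/no decision, so one product can be
    positive for several categories at the same time.
    """
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


# Hypothetical logits for 4 categories and a multi-hot target:
# the product belongs to categories 0 and 2.
logits = [4.0, -3.0, 2.5, -1.0]
targets = [1, 0, 1, 0]

probs = [sigmoid(z) for z in logits]
print([round(p, 3) for p in probs])
print("sum of probabilities:", round(sum(probs), 3))  # not constrained to 1
print("loss:", round(multilabel_bce(logits, targets), 4))
```

With a softmax the probabilities would be forced to sum to 1 and compete with each other; with independent sigmoids, "Sports & Outdoors" and "Yoga" can both score high for the same product.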

Download Pretrained Model

Download the pretrained model to the data folder:

wget https://github.com/yang-zhang/product_category/releases/download/v0.0.1/transformer_20210307D3.ckpt -P data

Installation

pip install product-category

Predict with Pretrained Model

python product_category/predict.py -h
usage: predict.py [-h] -t TEXT [--trained_model_path TRAINED_MODEL_PATH]
                  [--i2cat_path I2CAT_PATH] [--tokenizer_name TOKENIZER_NAME]
                  [--topn TOPN]

optional arguments:
  -h, --help            show this help message and exit
  -t TEXT, --text TEXT  Product info text to predict.
  --trained_model_path TRAINED_MODEL_PATH
                        Model used to predict.
  --i2cat_path I2CAT_PATH
                        File name for the ordered list of categories. Each
                        line for one category.
  --tokenizer_name TOKENIZER_NAME
                        Tokenizer name.
  --topn TOPN           Number of top predicted categories to display.

For example:

python predict.py -t "Lykmera Famous TikTok Leggings, High Waist Yoga Pants for Women, Booty Bubble Butt Lifting Workout Running Tights"

Sports & Outdoors: 0.997
Sports & Fitness: 0.994
Exercise & Fitness: 0.980
Clothing: 0.961
Yoga: 0.905
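The --i2cat_path file is described as an ordered list of categories with one category per line, so line i presumably names the category at model output index i. Under that assumption, turning per-category probabilities into a top-n display like the one above can be sketched as follows. load_i2cat, top_n, and format_predictions are hypothetical helpers, not functions exported by product-category, and the probabilities are made up.

```python
import os
import tempfile
from pathlib import Path


def load_i2cat(path: str) -> list[str]:
    """Read the ordered category list: line i holds the name of category i.

    Blank lines are kept so that line positions stay aligned with indices.
    """
    return Path(path).read_text(encoding="utf-8").splitlines()


def top_n(probs: list[float], i2cat: list[str], n: int = 5) -> list[tuple[str, float]]:
    """Return the n highest-scoring (category, probability) pairs."""
    if len(probs) != len(i2cat):
        raise ValueError("need exactly one probability per category")
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i2cat[i], probs[i]) for i in ranked[:n]]


def format_predictions(pairs: list[tuple[str, float]]) -> str:
    """Render pairs as 'Category: 0.123' lines, one per category."""
    return "\n".join(f"{cat}: {p:.3f}" for cat, p in pairs)


# Write a tiny hypothetical i2cat file, read it back, then rank made-up scores.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("Clothing\nYoga\nSports & Outdoors\nGrocery & Gourmet Food\n")
    path = f.name
i2cat = load_i2cat(path)
os.remove(path)

probs = [0.90, 0.40, 0.95, 0.01]
print(format_predictions(top_n(probs, i2cat, n=3)))
```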

Training Data Format

The training data file should be a CSV with 3 columns: category (categories separated by '|'), title (str), and is_validation (0 or 1), similar to data/example_data.csv.

category,title,is_validation
Sports & Outdoors|Outdoor Recreation|Cycling|Clothing|Men|Shorts,Louis Garneau Men's Neo Power Motion Bike Shorts,1
"Clothing, Shoes & Jewelry|Novelty & More|Clothing|Novelty",Nirvana Men's Short Sleeve Many Smiles T-Shirt Shirt,0
Grocery & Gourmet Food|Snack Foods|Chips & Crisps|Tortilla,Doritos Tapatio Salsa Picante Hot Sauce Flavor Chips 7 5/8 oz Bag (Pack of 1),0
"Clothing, Shoes & Jewelry|Women|Shoes|Boots|Synthetic|Synthetic sole|Vegan Friendly",SODA Womens Dome-H Boot,1
Sports & Outdoors|Outdoor Recreation|Camping & Hiking,Folding Pot Stabilizer,0
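A category value that contains a comma (such as "Clothing, Shoes & Jewelry") has to be quoted, as in the rows above, or the CSV reader will split it into extra columns. A minimal sketch of reading this format with Python's standard csv module, splitting categories on '|' and routing rows by is_validation; read_training_data is a hypothetical helper for illustration, not part of the library:

```python
import csv
import io


def read_training_data(fileobj):
    """Split rows of a category,title,is_validation CSV into train and validation lists.

    Each row becomes (title, [category, ...]); the category column is split on '|'.
    """
    train, val = [], []
    for row in csv.DictReader(fileobj):
        categories = [c.strip() for c in row["category"].split("|") if c.strip()]
        example = (row["title"], categories)
        if int(row["is_validation"]) == 1:
            val.append(example)
        else:
            train.append(example)
    return train, val


# Two rows copied from the example above; for a real file use
# open(path, newline="", encoding="utf-8") instead of StringIO.
sample = io.StringIO(
    "category,title,is_validation\n"
    '"Clothing, Shoes & Jewelry|Women|Shoes|Boots|Synthetic|Synthetic sole|Vegan Friendly",SODA Womens Dome-H Boot,1\n'
    "Sports & Outdoors|Outdoor Recreation|Camping & Hiking,Folding Pot Stabilizer,0\n"
)
train, val = read_training_data(sample)
print(train)
print(val)
```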

Training

Below is a subset of the options for train.py. Run python train.py -h to see the full help list, which includes additional options for PyTorch Lightning functionality.

python train.py -h
usage: train.py [-h] [--model_name_or_path MODEL_NAME_OR_PATH]
                [--transfer_learn] [--trained_model_path TRAINED_MODEL_PATH]
                [--data_file_path DATA_FILE_PATH] [--freeze_bert]
                [--max_seq_length MAX_SEQ_LENGTH]
                [--min_products_for_category MIN_PRODUCTS_FOR_CATEGORY]
                [--train_batch_size TRAIN_BATCH_SIZE]
                [--val_batch_size VAL_BATCH_SIZE]
                [--dataloader_num_workers DATALOADER_NUM_WORKERS]
                [--pin_memory] [--logger [LOGGER]]
                [--learning_rate LEARNING_RATE] 

optional arguments:
  -h, --help            show this help message and exit
  --model_name_or_path MODEL_NAME_OR_PATH
                        Path to pretrained model or model identifier from
                        huggingface.co/models.
  --transfer_learn      Whether to use transfer learning based on a pretrained
                        model.
  --trained_model_path TRAINED_MODEL_PATH
                        Model used to predict.
  --data_file_path DATA_FILE_PATH
                        Path to training data file. Data file should be csv
                        with 3 columns: category (categories separated by
                        '|'),title (str),is_validation (0 or 1). e.g.: Sports
                        & Outdoors|Outdoor
                        Recreation|Cycling|Clothing|Men|Shorts,Louis Garneau
                        Men's Neo Power Motion Bike Shorts,1
  --freeze_bert         Whether to freeze the pretrained model.
  --max_seq_length MAX_SEQ_LENGTH
                        The maximum total input sequence length after
                        tokenization. Sequences longer than this will be
                        truncated, sequences shorter will be padded.
  --min_products_for_category MIN_PRODUCTS_FOR_CATEGORY
                        Minimum number of products for a category to be
                        considered in the model.
  --train_batch_size TRAIN_BATCH_SIZE
                        How many samples per batch to load for train
                        dataloader.
  --val_batch_size VAL_BATCH_SIZE
                        How many samples per batch to load for validation
                        dataloader.
  --dataloader_num_workers DATALOADER_NUM_WORKERS
                        How many subprocesses to use for data loading. 0 means
                        that the data will be loaded in the main process.
  --pin_memory          Whether to use pin_memory in pytorch dataloader. If
                        True, the data loader will copy Tensors into CUDA
                        pinned memory before returning them.
  --learning_rate LEARNING_RATE
                        Learning Rate
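The --min_products_for_category option implies that rare categories are dropped from the label set before training. One plausible way to do that is sketched below, under the assumption that a category is kept when at least that many products carry it and that each product's surviving categories become a multi-hot target vector (train.py may do this differently). build_label_space and multi_hot are hypothetical names, and the product lists are toy data.

```python
from collections import Counter


def build_label_space(category_lists, min_products):
    """Return the sorted categories that appear in at least min_products products."""
    counts = Counter()
    for cats in category_lists:
        counts.update(set(cats))  # count each category once per product
    return sorted(c for c, n in counts.items() if n >= min_products)


def multi_hot(cats, label_space):
    """Encode one product's categories as a 0/1 vector over label_space.

    Categories that were filtered out of label_space are ignored.
    """
    index = {c: i for i, c in enumerate(label_space)}
    vec = [0] * len(label_space)
    for c in cats:
        if c in index:
            vec[index[c]] = 1
    return vec


products = [
    ["Sports & Outdoors", "Cycling"],
    ["Sports & Outdoors", "Camping & Hiking"],
    ["Grocery & Gourmet Food", "Snack Foods"],
    ["Sports & Outdoors", "Cycling", "Clothing"],
]
labels = build_label_space(products, min_products=2)
vectors = [multi_hot(cats, labels) for cats in products]
print(labels)
print(vectors)
```

Raising the threshold shrinks the output layer and removes categories with too few examples to learn from. A product whose categories are all filtered out ends up with an all-zero target, as the third toy product does; whether to keep or drop such rows is a design choice left open here.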

Training from Scratch

Training from scratch, with a transformers model as the starting point.

For example:

python train.py --data_file_path ../data/sample_data.csv

Transfer Learning from Pretrained Model

Transfer learning from the pretrained model.

For example:

python train.py --transfer_learn --data_file_path ../data/sample_data.csv

Useful PyTorch Lightning Options

To run on a GPU:

python train.py --transfer_learn --data_file_path ../data/sample_data.csv --gpus=1

To train only a classification head with the transformer backbone frozen:

python train.py --transfer_learn --data_file_path ../data/sample_data.csv --freeze_bert

To run on a GPU with pin_memory enabled for the dataloader and the maximum number of training epochs limited:

python train.py --transfer_learn --data_file_path ../data/sample_data.csv --gpus=1 --pin_memory --max_epochs=100

Note

The pretrained model was trained on Amazon product data, which is intended for research purposes. Therefore, the pretrained model should also be used for research purposes.

Download files


Source Distribution

product_category-0.0.2.tar.gz (8.8 kB)

Built Distribution

product_category-0.0.2-py2.py3-none-any.whl (8.5 kB)

File details

Details for the file product_category-0.0.2.tar.gz.

File metadata

  • Download URL: product_category-0.0.2.tar.gz
  • Size: 8.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.24.0 setuptools/50.3.1.post20201107 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.7.9

File hashes

Hashes for product_category-0.0.2.tar.gz
Algorithm Hash digest
SHA256 93728d00d53104c44421ea44ed317215af9e197179e762143cc1a5189bac1ce6
MD5 a8075ab37a50b69aadde05348a94bb02
BLAKE2b-256 5e0d96d1e6762b0e8dcb03d7f0a5505238a9574089e60586ad2298607595e690


File details

Details for the file product_category-0.0.2-py2.py3-none-any.whl.

File metadata

  • Download URL: product_category-0.0.2-py2.py3-none-any.whl
  • Size: 8.5 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.24.0 setuptools/50.3.1.post20201107 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.7.9

File hashes

Hashes for product_category-0.0.2-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 bc25dc1f640108f65ae0e5adffa54a0e51ac42abed3505bfbb3ae134c5e6c879
MD5 f9c6d9fb8522d45301e64580e69b94b3
BLAKE2b-256 a8d6fcd2966a917fd1db830ff6c96d590045b72aa173656811beddfe3b049068

