CNN+BiLSTM hybrid architecture for product classification

CNN-BiLSTM Classifier

A hybrid CNN+BiLSTM architecture for product classification with an attention mechanism, implemented in PyTorch.

Features

  • Multi-kernel CNN: Captures local n-gram features with different kernel sizes (2, 3, 4, 5)
  • Bidirectional LSTM: Captures sequential dependencies, with a soft attention mechanism
  • Ad-hoc Features: Extracts statistical and linguistic features from text
  • Word2Vec Embeddings: Pre-trained word embeddings for better text representation
  • Easy Integration: Simple API for training and inference
  • HuggingFace Hub Support: Load models from and save models to the HuggingFace Hub

Installation

Install from PyPI:

pip install product-classifier

Or install from source:

git clone https://github.com/turgutguvercin/cnn-bilstm-classifier.git
cd cnn-bilstm-classifier
pip install -e .

Quick Start

Loading a Pre-trained Model

from product_classifier import CNNBiLSTMInference

# Load from HuggingFace Hub
model = CNNBiLSTMInference.from_pretrained("turgutguvercin/product-classifier-v1")

# Make predictions
predictions = model.predict([
    "Yataş Bedding BAMBU Yorgan (%20 Bambu) 300 Gr.",
    "Arji Ev ve Ofis Çalışma Sandalyesi Bilgisayar Koltuğu"
], top_k=3)

for i, pred in enumerate(predictions):
    print(f"Text {i+1} predictions:")
    for label, score in pred:
        print(f"  → {label}: {score:.4f}")

Training a New Model

from product_classifier import CNNBiLSTMTrainer
import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")

# Initialize trainer
trainer = CNNBiLSTMTrainer()

# Train the model
results = trainer.train(
    df,
    text_column="product_name",
    label_column="category",
    config={
        'batch_size': 128,
        'epochs': 20,
        'embedding_dim': 512,
        'cnn_filters': 128,
        'lstm_hidden': 256
    }
)

# Save the model
trainer.save_model("./my_model")

Command Line Interface

Train a model:

cnn-bilstm-train --data data.csv --text-column name --label-column category --output-dir ./model

Make predictions:

cnn-bilstm-predict --model-path ./model --text "Product name to classify"

Model Architecture

The CNN+BiLSTM classifier combines three main components:

  1. Multi-kernel CNN Branch:

    • Convolutional layers with kernel sizes [2, 3, 4, 5]
    • Max pooling over the sequence dimension
    • Captures local n-gram patterns
  2. BiLSTM Branch:

    • Bidirectional LSTM layers
    • Soft attention mechanism
    • Captures long-range dependencies
  3. Ad-hoc Features Branch:

    • Statistical features (length, character ratios)
    • Linguistic features (word count, symbol count)
    • Word length histograms

The outputs of all three branches are concatenated and fed to a final classification layer.
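To make the concatenation concrete, here is a NumPy sketch of how the three branch outputs might be pooled and joined for one input, using the default configuration values (`cnn_filters=128`, `lstm_hidden=256`, `max_length=50`). It is not the package's implementation: the convolution and LSTM outputs are stubbed with random arrays, the attention is assumed to be a single learned scoring vector followed by a softmax over time steps, and the ad-hoc branch is assumed to contribute 15 values (the ten items listed under AdHocFeatureExtractor, with the space count/rate pair giving two values and the histogram five). The helper names `soft_attention_pool` and `max_pool_over_time` are hypothetical.

```python
import numpy as np


def soft_attention_pool(h, w):
    """Attention-weighted sum over time steps (one simple soft-attention variant).

    h: (seq_len, 2 * hidden) BiLSTM outputs.
    w: (2 * hidden,) scoring vector; learned in a real model, random here.
    """
    scores = h @ w                      # one score per time step, shape (seq_len,)
    scores = scores - scores.max()      # subtract max for numerical stability
    alpha = np.exp(scores) / np.exp(scores).sum()  # softmax over time steps
    return alpha @ h, alpha             # context vector (2 * hidden,), weights (seq_len,)


def max_pool_over_time(feature_map):
    """Collapse a (positions, filters) conv output to one value per filter."""
    return feature_map.max(axis=0)


rng = np.random.default_rng(0)
seq_len, hidden, filters = 50, 256, 128
kernel_sizes = [2, 3, 4, 5]

# CNN branch: stubbed 'valid' conv outputs, one pooled vector per kernel size
cnn_parts = [
    max_pool_over_time(rng.standard_normal((seq_len - k + 1, filters)))
    for k in kernel_sizes
]

# BiLSTM branch: forward and backward states concatenated -> 2 * hidden per step
h = rng.standard_normal((seq_len, 2 * hidden))
lstm_vec, alpha = soft_attention_pool(h, rng.standard_normal(2 * hidden))

# Ad-hoc branch: assumed 15 hand-crafted feature values
adhoc = rng.standard_normal(15)

combined = np.concatenate(cnn_parts + [lstm_vec, adhoc])
print(combined.shape)  # (1039,)
```

Under these assumptions the classification layer would see a 1039-dimensional vector; with other settings the first two terms scale as `4 * cnn_filters` and `2 * lstm_hidden`.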

API Reference

CNNBiLSTMInference

Main class for model inference.

Methods

  • from_pretrained(repo_id): Load a model from the HuggingFace Hub
  • from_local(model_dir): Load a model from a local directory
  • predict(texts, top_k=1): Predict the top_k categories for one text or a list of texts
  • get_model_info(): Get model architecture information

Example

from product_classifier import CNNBiLSTMInference

# Load model
model = CNNBiLSTMInference.from_pretrained("username/model-name")

# Single prediction
prediction = model.predict("Product title", top_k=3)

# Batch prediction
predictions = model.predict(["Title 1", "Title 2"], top_k=1)

# Model info
info = model.get_model_info()
print(f"Vocabulary size: {info['vocab_size']}")
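The `(label, score)` pairs returned by `predict(..., top_k=k)` can be read as the k best entries of a probability distribution over categories. The sketch below assumes the scores are softmax probabilities over the classifier's output logits, which the README does not state; `top_k_labels`, the label names and the logit values are all hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw logits into probabilities that sum to 1."""
    m = max(logits)                          # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(logits: Sequence[float], labels: Sequence[str], k: int = 1) -> List[Tuple[str, float]]:
    """Return the k highest-probability (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


labels = ["Bedding", "Furniture", "Electronics"]  # hypothetical label set
logits = [2.0, 0.5, -1.0]                         # hypothetical model outputs
for label, score in top_k_labels(logits, labels, k=2):
    print(f"  → {label}: {score:.4f}")
```

With `top_k=1` this reduces to the single most likely category, matching the default in the method list above.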

CNNBiLSTMTrainer

Class for training new models.

Methods

  • train(df, text_column, label_column, config): Train a new model
  • save_model(output_dir): Save trained model
  • save_to_hub(repo_name): Upload model to HuggingFace Hub

AdHocFeatureExtractor

Extracts statistical and linguistic features from text.

Features Extracted

  1. Title length
  2. Uppercase character rate
  3. Alphabetic character rate
  4. Digit character rate
  5. Space count and rate
  6. Word count
  7. Maximum word length
  8. Unique word rate
  9. Symbol count
  10. Word length histogram (5 bins)
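A pure-Python sketch of these features for a single title follows. The exact definitions used by `AdHocFeatureExtractor` are not documented here, so several choices are assumptions: words are whitespace-separated tokens, rates are divided by the character count, a symbol is any character that is neither alphanumeric nor whitespace, and the five histogram bins cover word lengths 1-2, 3-4, 5-6, 7-8 and 9+. The names `extract_adhoc_features` and `HIST_UPPER_BOUNDS` are hypothetical.

```python
from typing import List

# Hypothetical bin upper bounds: 1-2, 3-4, 5-6, 7-8 characters; a fifth bin catches 9+
HIST_UPPER_BOUNDS = [2, 4, 6, 8]


def extract_adhoc_features(text: str) -> List[float]:
    """Compute 15 assumed statistical/linguistic feature values for one title."""
    n_chars = len(text)
    denom = max(n_chars, 1)  # avoid division by zero on empty input
    words = text.split()
    n_words = len(words)

    upper = sum(c.isupper() for c in text)
    alpha = sum(c.isalpha() for c in text)
    digits = sum(c.isdigit() for c in text)
    spaces = text.count(" ")
    symbols = sum(not c.isalnum() and not c.isspace() for c in text)

    # Word length histogram: each word goes into the first bin whose bound fits
    hist = [0] * (len(HIST_UPPER_BOUNDS) + 1)
    for w in words:
        for i, bound in enumerate(HIST_UPPER_BOUNDS):
            if len(w) <= bound:
                hist[i] += 1
                break
        else:
            hist[-1] += 1  # longer than every bound

    return [
        float(n_chars),                                   # 1. title length
        upper / denom,                                    # 2. uppercase rate
        alpha / denom,                                    # 3. alphabetic rate
        digits / denom,                                   # 4. digit rate
        float(spaces), spaces / denom,                    # 5. space count and rate
        float(n_words),                                   # 6. word count
        float(max((len(w) for w in words), default=0)),   # 7. max word length
        len(set(words)) / max(n_words, 1),                # 8. unique word rate
        float(symbols),                                   # 9. symbol count
        *map(float, hist),                                # 10. histogram (5 bins)
    ]
```

For example, `extract_adhoc_features("Arji Ev ve Ofis")` yields 15 values, starting with a length of 15.0 and an uppercase rate of 0.2, and ending with the histogram `[2.0, 2.0, 0.0, 0.0, 0.0]`.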

Configuration

Default training configuration:

config = {
    'batch_size': 128,
    'epochs': 20,
    'learning_rate': 0.001,
    'embedding_dim': 512,
    'cnn_filters': 128,
    'lstm_hidden': 256,
    'max_length': 50
}
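The training example earlier passes only some of these keys, so it is reasonable to picture the trainer filling in the rest from these defaults, and `max_length` as the number of tokens each title is cut or padded to. Both behaviors are assumptions rather than documented API; the sketch uses hypothetical helpers `resolve_config` and `pad_or_truncate`, and assumes a padding id of 0 appended on the right.

```python
from typing import Any, Dict, List

DEFAULT_CONFIG: Dict[str, Any] = {
    'batch_size': 128,
    'epochs': 20,
    'learning_rate': 0.001,
    'embedding_dim': 512,
    'cnn_filters': 128,
    'lstm_hidden': 256,
    'max_length': 50,
}


def resolve_config(overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: user values win, missing keys fall back to defaults."""
    unknown = set(overrides) - set(DEFAULT_CONFIG)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**DEFAULT_CONFIG, **overrides}


def pad_or_truncate(token_ids: List[int], max_length: int, pad_id: int = 0) -> List[int]:
    """Cut long sequences and right-pad short ones to exactly max_length ids."""
    return token_ids[:max_length] + [pad_id] * max(0, max_length - len(token_ids))


cfg = resolve_config({'epochs': 10, 'batch_size': 64})
print(cfg['epochs'], cfg['learning_rate'], cfg['max_length'])  # 10 0.001 50
```

Rejecting unknown keys is a design choice in this sketch that catches typos such as `'lstm_hiden'` early; the package itself may silently ignore them.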

Requirements

  • Python >= 3.8
  • PyTorch >= 1.9.0
  • NumPy >= 1.21.0
  • SafeTensors >= 0.3.0
  • HuggingFace Hub >= 0.15.0
  • Gensim >= 4.0.0
  • scikit-learn >= 1.0.0
  • tqdm >= 4.60.0

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/AmazingFeature)
  3. Commit your changes (git commit -m 'Add some AmazingFeature')
  4. Push to the branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this work in your research, please cite:

@software{product_classifier,
    title={CNN-BiLSTM Classifier: Hybrid Architecture for Product Classification},
    author={Turgut Guvercin},
    year={2025},
    url={https://github.com/turgutguvercin/cnn-bilstm-classifier},
    note={Implementation based on the methodology from Suzuki et al. (2018)}
}

@inproceedings{suzuki2018cnn,
    title={Convolutional Neural Network and Bidirectional LSTM Based Taxonomy Classification Using External Dataset at SIGIR eCom Data Challenge},
    author={Suzuki, Shogo D. and Iseki, Yohei and Shiino, Hiroaki and Zhang, Hongwei and Iwamoto, Aya and Takahashi, Fumihiko},
    booktitle={Proceedings of ACM SIGIR Workshop on eCommerce (SIGIR 2018 eCom Data Challenge)},
    year={2018},
    month={July},
    address={Ann Arbor, Michigan, USA},
    publisher={ACM},
    pages={1--5},
    url={https://sigir-ecom.github.io/ecom2018/ecom18DCPapers/ecom18DC_paper_1.pdf},
    note={Original paper describing CNN+BiLSTM architecture for product taxonomy classification}
}

Acknowledgments

  • Based on "Convolutional Neural Network and Bidirectional LSTM Based Taxonomy Classification Using External Dataset at SIGIR eCom Data Challenge" (Suzuki et al., 2018)
  • Built with PyTorch and HuggingFace ecosystem
  • Special thanks to the open-source community

Download files

Source Distribution

product_classifier-0.1.0.tar.gz (33.7 MB)

Uploaded Source

Built Distribution

product_classifier-0.1.0-py3-none-any.whl (38.8 MB)

Uploaded Python 3

File details

Details for the file product_classifier-0.1.0.tar.gz.

File metadata

  • Download URL: product_classifier-0.1.0.tar.gz
  • Upload date:
  • Size: 33.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.11

File hashes

Hashes for product_classifier-0.1.0.tar.gz
Algorithm Hash digest
SHA256 68a60d4ba7f263c0f9dcf4f3c4fffb88fffe2ece7b7a6c3e196437960997c75e
MD5 1dd9968debdeed22971520edcd5a34f6
BLAKE2b-256 566b14a7ca1cac4e83f6ce48a9770a6a1047b57a7f3f42dc09ad2ee7314bd9b2

File details

Details for the file product_classifier-0.1.0-py3-none-any.whl.

File hashes

Hashes for product_classifier-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c76ddca2bff3960e62129ff35e017e565eab94b932ce7307a3cb278b59e1aec8
MD5 43eaab2cb6eee0fdba4673253b111fc3
BLAKE2b-256 0dcfbcf20752781015b2f4ca304af537f6ad865f935dfd75ac7366dbcebab7a7
