Bunruija

A text classification toolkit

Bunruija is a text classification toolkit. It aims to enable pre-processing, training, and evaluation of text classification models with minimal coding effort. Bunruija focuses mainly on Japanese, though it is also applicable to other languages.

See the example to understand how easy bunruija is to use.

Features

  • Minimal coding: bunruija lets users train and evaluate their models from the command line. Because all experimental settings are stored in a YAML file, users do not have to write code.
  • Easy comparison of neural and non-neural models: because bunruija supports models based on scikit-learn and PyTorch in the same framework, users can easily compare the classification accuracy and prediction time of neural and non-neural models.
  • Easy reproduction of model training: because all hyperparameters of a model are stored in a YAML file, reproducing the model is straightforward.

Install

pip install bunruija

Example configs

Example of sklearn.svm.SVC

data:
  label_column: category
  text_column: title
  args:
    path: data/jsonl

output_dir: models/svm-model

pipeline:
  - type: sklearn.feature_extraction.text.TfidfVectorizer
    args:
      tokenizer:
        type: bunruija.tokenizers.mecab_tokenizer.MeCabTokenizer
        args:
          lemmatize: true
          exclude_pos:
            - 助詞
            - 助動詞
      max_features: 10000
      min_df: 3
      ngram_range:
        - 1
        - 3
  - type: sklearn.svm.SVC
    args:
      verbose: false
      C: 10.

Example of BERT

data:
  label_column: category
  text_column: title
  args:
    path: data/jsonl

output_dir: models/transformer-model

pipeline:
  - type: bunruija.feature_extraction.sequence.SequenceVectorizer
    args:
      tokenizer:
        type: transformers.AutoTokenizer
        args:
          pretrained_model_name_or_path: cl-tohoku/bert-base-japanese
  - type: bunruija.classifiers.transformer.TransformerClassifier
    args:
      device: cpu
      pretrained_model_name_or_path: cl-tohoku/bert-base-japanese
      optimizer:
        type: torch.optim.AdamW
        args:
          lr: 3e-5
          weight_decay: 0.01
          betas:
            - 0.9
            - 0.999
      max_epochs: 3

CLI

# Training a classifier
bunruija-train -y config.yaml

# Evaluating the trained classifier
bunruija-evaluate -y config.yaml

Config

data

You can configure data-related settings in the data section.

data:
  label_column: category
  text_column: title
  args:
    # Use local data in `data/jsonl`. This path is assumed to contain data files such as train.jsonl, validation.jsonl, and test.jsonl
    path: data/jsonl

    # If you want to use data on Hugging Face Hub, use the following args instead.
    # Data is from https://huggingface.co/datasets/shunk031/livedoor-news-corpus
    # path: shunk031/livedoor-news-corpus
    # random_state: 0
    # shuffle: true

data is loaded via datasets.load_dataset, so you can load local data as well as data on the Hugging Face Hub. When loading data, args is passed to load_dataset.
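As a conceptual sketch (not bunruija's actual internals), the args mapping is simply forwarded to the loader as keyword arguments. The fake_load_dataset function below is a hypothetical stand-in for datasets.load_dataset, used only to show the call shape without downloading anything:

```python
# Conceptual sketch: bunruija forwards data.args as keyword arguments
# to datasets.load_dataset. `fake_load_dataset` is a stand-in used here
# only to illustrate the call shape.
data_args = {"path": "data/jsonl"}

def fake_load_dataset(path, **kwargs):
    # datasets.load_dataset(path, ...) would return a DatasetDict here.
    return {"path": path, "extra_args": kwargs}

dataset = fake_load_dataset(**data_args)
print(dataset["path"])  # data/jsonl
```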

label_column and text_column are the field names of the label and the text, respectively.

Format of csv:

category,sentence
sports,I like sports!
…

Format of json:

[{"category", "sports", "text": "I like sports!"}]

Format of jsonl:

{"category", "sports", "text": "I like suports!"}

pipeline

You can define your model's pipeline in the pipeline section. It is a list of the components that make up your model.

For each component, type is a module path and args holds the arguments for that module. For instance, when you set the first component as follows, TfidfVectorizer is instantiated with the given arguments and is the first step applied to the data in your model.

  - type: sklearn.feature_extraction.text.TfidfVectorizer
    args:
      tokenizer:
        type: bunruija.tokenizers.mecab_tokenizer.MeCabTokenizer
        args:
          lemmatize: true
          exclude_pos:
            - 助詞
            - 助動詞
      max_features: 10000
      min_df: 3
      ngram_range:
        - 1
        - 3
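The type/args convention can be implemented with a dynamic import. The build_component helper below is a hypothetical sketch of that mechanism, not bunruija's actual loader, demonstrated with a stdlib class so it runs without scikit-learn installed:

```python
import importlib

def build_component(type_path: str, args: dict):
    """Resolve a dotted `type` path and instantiate it with `args`.

    Hypothetical sketch of the config convention, not bunruija's loader.
    """
    module_name, _, class_name = type_path.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**args)

# Demonstrated with collections.Counter instead of TfidfVectorizer:
counter = build_component("collections.Counter", {"a": 2, "b": 1})
print(counter["a"])  # 2
```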

Prediction using the trained classifier in Python code

After you have trained a classification model, you can use it for prediction as follows:

from bunruija import Predictor

predictor = Predictor.from_pretrained("output_dir")
while True:
    text = input("Input:")
    label: list[str] = predictor([text], return_label_type="str")
    print(label[0])

Here, output_dir is the directory specified as output_dir in the config (e.g. models/svm-model in the example above).
