A Package to Easily Train BERT-Like Models for Text Classification

The Augmented Social Scientist

This package makes it extremely easy to train BERT-like models for text classification.

For more information on the method and for use cases from the social sciences, see "The Augmented Social Scientist: Using Sequential Transfer Learning to Annotate Millions of Texts with Human-Level Accuracy" published in Sociological Methods & Research by Salomé Do, Étienne Ollion and Rubing Shen.

To install the package

pip install augmentedsocialscientist

To train a BERT model

Import

from augmentedsocialscientist.models import Bert

bert = Bert()  #instantiation

Functions

The class Bert provides three main methods:

  • encode() to preprocess the data;
  • run_training() to train, validate and save a model;
  • predict_with_model() to make predictions with a saved model.

Example

import pandas as pd
import numpy as np

from augmentedsocialscientist.models import Bert

bert = Bert() #instantiation


# 1. Training 
## Load training and validation data
train_data = pd.read_csv('https://raw.githubusercontent.com/rubingshen/augmented_tutorial/main/clickbait/clickbait_train.csv')
test_data = pd.read_csv('https://raw.githubusercontent.com/rubingshen/augmented_tutorial/main/clickbait/clickbait_test.csv')

## Preprocess the training and validation data
train_loader = bert.encode(
    train_data.text.values,  #list of texts
    train_data.label.values  #list of labels
    )
test_loader = bert.encode(
    test_data.text.values,   #list of texts
    test_data.label.values   #list of labels
    )

## Train, validate and save a model
scores = bert.run_training(
    train_loader,             #training dataloader
    test_loader,              #test dataloader
    lr=5e-5,                  #learning rate
    n_epochs=3,               #number of epochs
    random_state=42,          #random state (for replicability)
    save_model_as='clickbait' #name of model to save as
)
# this trained model will be saved at ./models/clickbait
# the output "scores" contains precision, recall, f1-score and support for each classification category, assessed against the provided test set

# 2. Prediction on unlabeled data

## Load prediction data
pred_data = pd.read_csv('https://raw.githubusercontent.com/rubingshen/augmented_tutorial/main/clickbait/clickbait_pred.csv')

## Preprocess the prediction data
pred_loader = bert.encode(pred_data.text.values) #input a list of unlabeled texts

## Prediction with the saved trained model
pred = bert.predict_with_model(
    pred_loader, 
    model_path='./models/clickbait'
    )
# the output "pred" is an ndarray containing, for each text (row), the probability of belonging to each category (column)

## Compute the predicted label as the one with the highest probability

pred_data['pred_label'] = np.argmax(pred, axis=1)
pred_data['pred_proba'] = np.max(pred, axis=1)
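To make the final step concrete, here is a self-contained sketch of the argmax/max post-processing, using a small hypothetical probability array in place of real model output:

```python
import numpy as np
import pandas as pd

# Hypothetical model output: one row per text, one column per category
# (standing in for the real output of bert.predict_with_model)
pred = np.array([
    [0.90, 0.10],
    [0.20, 0.80],
    [0.55, 0.45],
])

pred_data = pd.DataFrame({"text": ["title a", "title b", "title c"]})
pred_data["pred_label"] = np.argmax(pred, axis=1)  # most probable category
pred_data["pred_proba"] = np.max(pred, axis=1)     # its probability

# Optionally keep only confident predictions, e.g. probability >= 0.7
confident = pred_data[pred_data["pred_proba"] >= 0.7]
```

A confidence filter like the last line is often useful in practice, since rows whose top probability is close to 1/number-of-categories are essentially unresolved by the model.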

Tutorial

An interactive tutorial is available on Google Colab.

Languages supported

Bert is a language model pre-trained on English. The module augmentedsocialscientist.models also contains models for other languages:

  • ArabicBert for Arabic;
  • Camembert or Flaubert for French;
  • ChineseBert for Chinese;
  • GermanBert for German;
  • HindiBert for Hindi;
  • ItalianBert for Italian;
  • PortugueseBert for Portuguese;
  • RussianBert for Russian;
  • SpanishBert for Spanish;
  • SwedishBert for Swedish;
  • XLMRoberta which is a multi-lingual model supporting 100 languages.

To use them, simply import the corresponding model and instantiate it as in the previous example.

For example, to use the French language model Camembert:

from augmentedsocialscientist.models import Camembert

bert = Camembert()  #instantiation

You can then use the methods bert.encode(), bert.run_training() and bert.predict_with_model() as in the previous example.

To use a custom model from Hugging Face

The package also allows you to use other BERT-like models from Hugging Face, by setting the argument model_name to the desired model name when instantiating the class Bert.

For example, to use the Danish BERT model DJSammy/bert-base-danish-uncased_BotXO-ai from Hugging Face:

from augmentedsocialscientist.models import Bert

bert = Bert(model_name="DJSammy/bert-base-danish-uncased_BotXO-ai")

To use a custom torch.device

By default, the package automatically detects whether a GPU is present and uses it to accelerate computation. You can also set your own device, by providing a torch.device object to the parameter device when instantiating Bert.

from augmentedsocialscientist.models import Bert

bert = Bert(device=...)  #set your own device
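For example, a torch.device object can be built as follows (the final line is left as a comment so the snippet runs without the package installed):

```python
import torch

# Use the first CUDA GPU if one is present, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pass it to the model class, e.g.:
# bert = Bert(device=device)
```

The same device object works with any of the language-specific classes (Camembert, GermanBert, etc.), since they share the same constructor parameter.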
