
An easy-to-use wrapper library for the Transformers library.

Project description

Simple Transformers

This library is based on the Pytorch-Transformers library by HuggingFace. It lets you quickly train and evaluate Transformer models.

Visit the Github page for documentation and code.

Please refer to this Medium article for further information on how this project works.

Please note that the documentation is still being written.

Table of contents

  • Setup
  • Usage
  • Current Pretrained Models
  • Acknowledgements

Setup

With Conda

  1. Install Anaconda or Miniconda Package Manager from here

  2. Create a new virtual environment and install packages.
    conda create -n transformers python pandas tqdm
    conda activate transformers
    If using CUDA:
        conda install pytorch cudatoolkit=10.0 -c pytorch
    Otherwise (CPU only):
        conda install pytorch cpuonly -c pytorch
    conda install -c anaconda scipy
    conda install -c anaconda scikit-learn
    pip install transformers
    pip install tensorboardx

  3. Install simpletransformers.
    pip install simpletransformers
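
Once the environment is set up, a quick sanity check (not part of the setup steps above, just a suggestion) confirms that the packages import and whether PyTorch can see a GPU for the CUDA install path:

# Optional sanity check for the environment.
import torch
import simpletransformers

print(torch.__version__)           # installed PyTorch version
print(torch.cuda.is_available())   # True if the CUDA build of PyTorch found a GPU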

Usage

Minimal Start

from simpletransformers.model import TransformerModel
import pandas as pd


# Train and evaluation data need to be in a Pandas DataFrame with two columns.
# The first column is the text (type str) and the second column is the label (type int).
train_data = [['Example sentence belonging to class 1', 1], ['Example sentence belonging to class 0', 0]]
train_df = pd.DataFrame(train_data)

eval_data = [['Example eval sentence belonging to class 1', 1], ['Example eval sentence belonging to class 0', 0]]
eval_df = pd.DataFrame(eval_data)

# Create a TransformerModel
model = TransformerModel('roberta', 'roberta-base')

# Train the model
model.train_model(train_df)

# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
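
As a rough guide (this is an assumption about the return values rather than documented behaviour), result is a dict of evaluation metrics, model_outputs holds the raw per-example outputs, and wrong_predictions lists the misclassified examples. The raw outputs can be turned into predicted labels with an argmax:

import numpy as np

# Assumes model_outputs is an array-like of raw scores with one row per
# evaluation example; the argmax of each row is the predicted class index.
predictions = np.argmax(model_outputs, axis=1)
print(result)        # evaluation metrics
print(predictions)   # e.g. [1 0] for the two eval sentences above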

The default args used are given below. Any of these can be overridden by passing a dict containing the corresponding key: value pairs to the init method of TransformerModel.

self.args = {
   'data_dir': 'data/',
   'model_type':  'roberta',
   'model_name': 'roberta-base',
   'output_dir': 'outputs/',

   'fp16': True,
   'fp16_opt_level': 'O1',
   'max_seq_length': 128,
   'train_batch_size': 8,
   'eval_batch_size': 8,
   'gradient_accumulation_steps': 1,
   'num_train_epochs': 1,
   'weight_decay': 0,
   'learning_rate': 4e-5,
   'adam_epsilon': 1e-8,
   'warmup_ratio': 0.06,
   'warmup_steps': 0,
   'max_grad_norm': 1.0,

   'logging_steps': 50,
   'evaluate_during_training': False,
   'save_steps': 2000,
   'eval_all_checkpoints': True,
   'use_tensorboard': True,

   'overwrite_output_dir': False,
   'reprocess_input_data': False,
}
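
For example, to train for more epochs with a larger batch size, the relevant keys can be passed at construction time. This is a minimal sketch: the args keyword name is an assumption, since the text above only states that a dict is passed to the init method.

# Override a few defaults; unspecified keys keep the values listed above.
# The 'args' keyword name is assumed from the description above.
model = TransformerModel('roberta', 'roberta-base', args={
    'num_train_epochs': 3,
    'train_batch_size': 16,
    'overwrite_output_dir': True,
})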

An explanation of each parameter will be added to the docs soon.

Current Pretrained Models

The table below shows the currently available model types and their models. You can use any of these by setting the model_type and model_name in the args dictionary. For more information about pretrained models, see HuggingFace docs.

Architecture | Model Type | Model Name | Details
BERT | bert | bert-base-uncased | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased English text.
BERT | bert | bert-large-uncased | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text.
BERT | bert | bert-base-cased | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased English text.
BERT | bert | bert-large-cased | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on cased English text.
BERT | bert | bert-base-multilingual-uncased | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased text in the top 102 languages with the largest Wikipedias.
BERT | bert | bert-base-multilingual-cased | (New, recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased text in the top 104 languages with the largest Wikipedias.
BERT | bert | bert-base-chinese | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased Chinese Simplified and Traditional text.
BERT | bert | bert-base-german-cased | 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased German text by Deepset.ai.
BERT | bert | bert-large-uncased-whole-word-masking | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text using Whole-Word-Masking.
BERT | bert | bert-large-cased-whole-word-masking | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on cased English text using Whole-Word-Masking.
BERT | bert | bert-large-uncased-whole-word-masking-finetuned-squad | 24-layer, 1024-hidden, 16-heads, 340M parameters. The bert-large-uncased-whole-word-masking model fine-tuned on SQuAD.
BERT | bert | bert-large-cased-whole-word-masking-finetuned-squad | 24-layer, 1024-hidden, 16-heads, 340M parameters. The bert-large-cased-whole-word-masking model fine-tuned on SQuAD.
BERT | bert | bert-base-cased-finetuned-mrpc | 12-layer, 768-hidden, 12-heads, 110M parameters. The bert-base-cased model fine-tuned on MRPC.
XLNet | xlnet | xlnet-base-cased | 12-layer, 768-hidden, 12-heads, 110M parameters. XLNet English model.
XLNet | xlnet | xlnet-large-cased | 24-layer, 1024-hidden, 16-heads, 340M parameters. XLNet Large English model.
XLM | xlm | xlm-mlm-en-2048 | 12-layer, 2048-hidden, 16-heads. XLM English model.
XLM | xlm | xlm-mlm-ende-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-German multi-language model.
XLM | xlm | xlm-mlm-enfr-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-French multi-language model.
XLM | xlm | xlm-mlm-enro-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-Romanian multi-language model.
XLM | xlm | xlm-mlm-xnli15-1024 | 12-layer, 1024-hidden, 8-heads. XLM model pre-trained with MLM on the 15 XNLI languages.
XLM | xlm | xlm-mlm-tlm-xnli15-1024 | 12-layer, 1024-hidden, 8-heads. XLM model pre-trained with MLM + TLM on the 15 XNLI languages.
XLM | xlm | xlm-clm-enfr-1024 | 12-layer, 1024-hidden, 8-heads. XLM English model trained with CLM (Causal Language Modeling).
XLM | xlm | xlm-clm-ende-1024 | 6-layer, 1024-hidden, 8-heads. XLM English-German multi-language model trained with CLM (Causal Language Modeling).
RoBERTa | roberta | roberta-base | 125M parameters. RoBERTa using the BERT-base architecture.
RoBERTa | roberta | roberta-large | 24-layer, 1024-hidden, 16-heads, 355M parameters. RoBERTa using the BERT-large architecture.
RoBERTa | roberta | roberta-large-mnli | 24-layer, 1024-hidden, 16-heads, 355M parameters. roberta-large fine-tuned on MNLI.
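
For instance, to fine-tune XLNet instead of RoBERTa, only the model type and name from the table change; the rest of the Minimal Start workflow stays the same:

# Same workflow as the Minimal Start, but with an XLNet model from the table above.
model = TransformerModel('xlnet', 'xlnet-base-cased')
model.train_model(train_df)
result, model_outputs, wrong_predictions = model.eval_model(eval_df)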

Acknowledgements

None of this would have been possible without the hard work by the HuggingFace team in developing the Pytorch-Transformers library.


