

Project description

plamtral

PLaMTraL - A transfer learning library for pre-trained transformers.

Installation

Install plamtral with pip:

pip install plamtral

Features

Fine-tuning

Fine-tuning large pretrained language models on downstream tasks remains the de facto learning paradigm in NLP. However, several fine-tuning approaches exist beyond the usual vanilla variant, and they can be more effective or more efficient. The fine-tuning techniques provided in this package are:

  • BitFit - a sparse fine-tuning method in which only the bias terms of the model (or a subset of them) are modified (see the sketch after this list). Reference: https://arxiv.org/pdf/2106.10199.pdf.
  • Chain-thaw - an approach that sequentially unfreezes and fine-tunes one layer at a time. Reference: https://arxiv.org/pdf/1708.00524.pdf.
  • ULMFiT - an effective transfer learning method that introduces techniques (slanted triangular learning rates, discriminative fine-tuning, and gradual unfreezing) that are key for fine-tuning a language model. Reference: https://arxiv.org/pdf/1801.06146.pdf.
  • Vanilla fine-tuning - the standard fine-tuning approach (fine-tune the whole model, the last n layers, or a specific layer).
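
As an illustration of the first of these, here is a minimal, library-agnostic sketch of the BitFit idea using plain PyTorch and Hugging Face transformers (this shows the underlying technique, not plamtral's own API): freeze every parameter except the bias terms, then train as usual.

from transformers import GPT2LMHeadModel

# Sketch only: load a pretrained GPT-2 and keep just the bias terms trainable.
model = GPT2LMHeadModel.from_pretrained('gpt2')
for name, param in model.named_parameters():
    param.requires_grad = name.endswith('bias')

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {trainable}')  # a small fraction of the full model

Because only the bias vectors receive gradients, the optimizer state and the per-task checkpoint stay small, which is the main appeal of the method.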

Parameter-efficient approaches

Since conventional fine-tuning approaches can become expensive, as they often require storing a large number of parameters per task, recent work has proposed a variety of parameter-efficient transfer learning methods that fine-tune only a small number of (extra) parameters while attaining strong performance. This package provides such parameter-efficient techniques, including the parallel adapters used in the example below.
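
To make the adapter idea concrete, here is a minimal sketch of a parallel adapter in plain PyTorch (again, an illustration of the technique rather than plamtral's internal implementation): a small bottleneck network runs alongside a frozen pretrained sublayer, and only the adapter's parameters are trained.

import torch
import torch.nn as nn

class ParallelAdapter(nn.Module):
    def __init__(self, sublayer: nn.Module, hidden_dim: int, bottleneck_dim: int = 64):
        super().__init__()
        self.sublayer = sublayer  # pretrained sublayer, kept frozen
        for p in self.sublayer.parameters():
            p.requires_grad = False
        self.down = nn.Linear(hidden_dim, bottleneck_dim)  # down-projection
        self.up = nn.Linear(bottleneck_dim, hidden_dim)    # up-projection

    def forward(self, x):
        # The adapter branch is added to the frozen sublayer's output in parallel.
        return self.sublayer(x) + self.up(torch.relu(self.down(x)))

# Hypothetical usage: wrap the MLP of one GPT-2 block (hidden size 768 for gpt2-small).
# block.mlp = ParallelAdapter(block.mlp, hidden_dim=768)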

Usage/Examples

To use a GPT2 model with parallel adapters (for example):

from parameter_efficient.adapter import Model_with_parallel_adapter
from tl_lib.utils import load_dataloaders
from tl_lib.tl_train import train

# Load the GPT2 model with Parallel Adapters
model_obj = Model_with_parallel_adapter('GPT2')
# Create the train, validation and test dataloaders from the dataset file
train_loader, val_loader, test_loader = load_dataloaders('GPT2', dataset_path='path/to/dataset_file')
# Train the model
train(model_obj, train_loader, val_loader, verbose=True, model_save_name='path/to/model')

Requirements

  • torch 1.12.1
  • tqdm 4.64.1
  • transformers 4.24.0
  • nltk 3.7
  • torchmetrics

Authors

@Vibhu04

License

MIT



Download files


Source Distribution

plamtral-0.0.9.tar.gz (15.2 kB)


Built Distribution


plamtral-0.0.9-py3-none-any.whl (23.8 kB)


File details

Details for the file plamtral-0.0.9.tar.gz.

File metadata

  • Download URL: plamtral-0.0.9.tar.gz
  • Upload date:
  • Size: 15.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.15

File hashes

Hashes for plamtral-0.0.9.tar.gz:

  • SHA256: 42c7f199047fdf6e9276e5e878a40a2e0c6fd92fd1468c6dc22fdc11bb4ec6ce
  • MD5: f4c8f8905c01142e1cf4cba111be65b6
  • BLAKE2b-256: 28c77442f1b113adb898dea5bbc08098f321c32bf6b739c85cf91c8b4aa4083e
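
To check a downloaded archive against the published SHA256 digest above, a locally computed hash can be compared to it, e.g. with Python's hashlib (a minimal sketch; it assumes the archive sits in the current directory):

import hashlib

expected = '42c7f199047fdf6e9276e5e878a40a2e0c6fd92fd1468c6dc22fdc11bb4ec6ce'
with open('plamtral-0.0.9.tar.gz', 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print('OK' if digest == expected else 'Hash mismatch!')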


File details

Details for the file plamtral-0.0.9-py3-none-any.whl.

File metadata

  • Download URL: plamtral-0.0.9-py3-none-any.whl
  • Upload date:
  • Size: 23.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.15

File hashes

Hashes for plamtral-0.0.9-py3-none-any.whl:

  • SHA256: 2f7190ec54a15f004248e1b3af47c06d5484b7d88473bdf71c5c0b2d2b6a6d2d
  • MD5: f73b100da1d888044b1cd9621dcacea9
  • BLAKE2b-256: 085fb31f41757dcba8697023f60cc2ff5e2714c73ecefcec2c9b423b7c8dec89

