
PyTorch Lightning Transformers

Project description

Flexible components pairing 🤗 Transformers with PyTorch Lightning


Docs · Community


Installation

pip install lightning-transformers
From Source
git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .
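
A quick way to confirm the install succeeded (a minimal check; the __version__ attribute is assumed to follow the usual packaging convention):

import lightning_transformers

# Should print the installed version, e.g. 0.2.0
print(lightning_transformers.__version__)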

What is Lightning Transformers?

Lightning Transformers provides LightningModules, LightningDataModules, and Strategies to use 🤗 Transformers with the PyTorch Lightning Trainer.
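
These modules drop straight into the standard PyTorch Lightning Trainer, so scaling to multiple devices uses plain Lightning options (a minimal sketch with illustrative settings, not a lightning-transformers specific API):

import pytorch_lightning as pl

# Any lightning-transformers task module and datamodule can be trained with a
# Trainer configured through standard Lightning scaling options.
trainer = pl.Trainer(
    accelerator="gpu",  # or "cpu" / "auto"
    devices=2,          # illustrative device count
    strategy="ddp",     # standard Lightning distributed strategy
    max_epochs=1,
)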

Quick Recipes

Train bert-base-cased on the CARER emotion dataset using the Text Classification task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

# Tokenizer matching the pretrained backbone
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
# DataModule that downloads and tokenizes the CARER "emotion" dataset
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="emotion",
    max_length=512,
    tokenizer=tokenizer,
)
# Classification head on top of the pretrained backbone, sized to the dataset's classes
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
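
After fitting, the fine-tuned module can be evaluated with the standard PyTorch Lightning API (a short follow-up sketch; it uses plain Trainer.validate rather than any lightning-transformers specific helper):

# Run the validation loop on the emotion validation split
trainer.validate(model, datamodule=dm)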

Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)

# Tokenizer matching the mt5-base backbone
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
# Translation task module: n_gram/smooth configure the BLEU metric,
# val_target_max_length and num_beams control generation during validation
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    n_gram=4,
    smooth=False,
    val_target_max_length=142,
    num_beams=None,
    compute_generate_metrics=True,
)
dm = WMT16TranslationDataModule(
    # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
    dataset_config_name="ro-en",
    source_language="en",
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
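
To reuse the fine-tuned weights later, the standard PyTorch Lightning checkpoint API applies (a sketch with an illustrative filename; it assumes the task module records its hyperparameters via save_hyperparameters(), the usual Lightning pattern, otherwise pass the original constructor arguments to load_from_checkpoint):

# Save the fine-tuned translation model (illustrative filename)
trainer.save_checkpoint("mt5_wmt16_ro_en.ckpt")

# Restore it later for further training or evaluation
model = TranslationTransformer.load_from_checkpoint("mt5_wmt16_ro_en.ckpt")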

Lightning Transformers supports a wide range of 🤗 Transformers tasks and datasets. See the documentation for the full list.

Contribute

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

Community

For help or questions, join our huge community on Slack!

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

lightning-transformers-0.2.0.tar.gz (40.4 kB)

Uploaded Source

Built Distribution

lightning_transformers-0.2.0-py3-none-any.whl (74.8 kB)

Uploaded Python 3

File details

Details for the file lightning-transformers-0.2.0.tar.gz.

File metadata

File hashes

Hashes for lightning-transformers-0.2.0.tar.gz:
SHA256:      f847f1cd54c611bd637d07b14f0325248eeaa07da0574ee9f5a592842a63913f
MD5:         57d7a04cfcea516cf26f83c6b5a2198c
BLAKE2b-256: b547e586520f32aa9cf6eae144b00837d90fddfd388d305ae1ffd836ed3bf0a9


File details

Details for the file lightning_transformers-0.2.0-py3-none-any.whl.

File metadata

File hashes

Hashes for lightning_transformers-0.2.0-py3-none-any.whl:
SHA256:      503d213334dc12bbc1305de16acc111ecedcd26f74682ff57df9c3c247d616ed
MD5:         697dc9015aa10f083352c581eb4f1845
BLAKE2b-256: b62d1de59ff34da3708e550d3f9b1b5df1959c0118256a7dff215d032bb488fa

