PyTorch Lightning Transformers.

Flexible components pairing 🤗 Transformers with PyTorch Lightning



Installation

pip install lightning-transformers

From Source

git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .
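
After either install route, it can help to confirm which version ended up in the active environment. The following is a minimal sketch using only the standard library; the helper name `installed_version` is an illustration and not part of lightning-transformers.

```python
from importlib import metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    `installed_version` is a hypothetical helper written for this example.
    """
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("lightning-transformers")
    if version is None:
        print("lightning-transformers is not installed in this environment")
    else:
        print(f"lightning-transformers {version} is installed")
```

Running the script in the environment you installed into prints either the detected version or a note that the package is missing.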

What is Lightning-Transformers

Lightning Transformers provides LightningModules, LightningDataModules and Strategies to use 🤗 Transformers with the PyTorch Lightning Trainer.

Quick Recipes

Train bert-base-cased on the CARER emotion dataset using the Text Classification task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
    TextClassificationDataConfig,
)

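# Load the tokenizer that matches the pretrained backbone.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
# The data module prepares the "emotion" dataset with the tokenizer above.
dm = TextClassificationDataModule(
    cfg=TextClassificationDataConfig(
        batch_size=1,
        dataset_name="emotion",
        max_length=512,
    ),
    tokenizer=tokenizer,
)
# num_labels is taken from the data module so the model matches the dataset.
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

# Let Lightning select the accelerator and devices; train for one epoch.
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)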

Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
    TranslationConfig,
    TranslationDataConfig,
)

# Load the tokenizer that matches the pretrained backbone.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
# Task configuration for the translation model.
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    cfg=TranslationConfig(
        n_gram=4,
        smooth=False,
        val_target_max_length=142,
        num_beams=None,
        compute_generate_metrics=True,
    ),
)
# Data configuration: English to Romanian from the WMT16 "ro-en" subset.
dm = WMT16TranslationDataModule(
    cfg=TranslationDataConfig(
        dataset_name="wmt16",
        # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
        dataset_config_name="ro-en",
        source_language="en",
        target_language="ro",
        max_source_length=128,
        max_target_length=128,
    ),
    tokenizer=tokenizer,
)

# Let Lightning select the accelerator and devices; train for one epoch.
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
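
The recipe above pairs `dataset_config_name="ro-en"` with `source_language="en"` and `target_language="ro"`. A mismatched pair is an easy mistake when switching to another WMT16 subset, so the sketch below checks the three values before any data is loaded. It assumes that the source and target must be the two language codes named in the config string; `check_language_pair` is a hypothetical helper, not part of lightning-transformers, and the list of configs is copied from the comment in the recipe.

```python
from typing import Tuple

# Config names copied from the comment in the recipe above.
WMT16_CONFIGS = ("cs-en", "de-en", "fi-en", "ro-en", "ru-en", "tr-en")


def check_language_pair(config_name: str, source: str, target: str) -> Tuple[str, str]:
    """Validate a (source, target) pair against a WMT16 config name.

    Hypothetical helper for illustration. Assumes the two languages must be
    exactly the two codes in the config name, in either direction.
    """
    if config_name not in WMT16_CONFIGS:
        raise ValueError(f"unknown WMT16 config: {config_name!r}")
    languages = set(config_name.split("-"))
    if source == target or {source, target} != languages:
        raise ValueError(
            f"languages ({source!r}, {target!r}) do not match config {config_name!r}"
        )
    return source, target


if __name__ == "__main__":
    # The pair used in the recipe: English to Romanian.
    print(check_language_pair("ro-en", "en", "ro"))
```

Calling the helper before building `TranslationDataConfig` surfaces a typo in the language codes as an immediate `ValueError` instead of a failure partway through data preparation.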

Lightning Transformers supports a range of 🤗 tasks and datasets. See the documentation for the full list.

Contribute

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

Community

For help or questions, join our huge community on Slack!
