Lightning Transformers

Project description

Flexible components pairing 🤗 Transformers with PyTorch Lightning


Docs | Community


Installation

pip install lightning-transformers
From Source
git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .
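
To verify the installation, import the package and print its version (a minimal check; assumes the package exposes __version__, as Lightning-ecosystem packages typically do):

python -c "import lightning_transformers; print(lightning_transformers.__version__)"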

What is Lightning-Transformers

Lightning Transformers provides LightningModules, LightningDataModules and Strategies to use 🤗 Transformers with the PyTorch Lightning Trainer.

Quick Recipes

Train bert-base-cased on the CARER emotion dataset using the Text Classification task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="emotion",  # CARER emotion dataset on the Hugging Face Hub
    max_length=512,
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
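
After fitting, the same Trainer can evaluate the model with the standard Lightning methods (a minimal follow-up sketch; assumes the data module exposes validation and test splits, which the emotion dataset provides):

trainer.validate(model, dm)
trainer.test(model, dm)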

Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    n_gram=4,  # n-gram order for the BLEU metric
    smooth=False,  # disable BLEU smoothing
    val_target_max_length=142,  # cap on generated sequence length during validation
    num_beams=None,  # use the backbone's default beam search setting
    compute_generate_metrics=True,  # generate during validation to compute BLEU
)
dm = WMT16TranslationDataModule(
    # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
    dataset_config_name="ro-en",
    source_language="en",
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
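
The other WMT16 pairs listed in the comment above work with the same recipe; for example, switching to German-English is only a config change (a sketch, reusing the tokenizer and model setup from above):

dm = WMT16TranslationDataModule(
    dataset_config_name="de-en",
    source_language="de",
    target_language="en",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",
    tokenizer=tokenizer,
)
trainer.fit(model, dm)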

Lightning Transformers supports many 🤗 Transformers tasks and datasets; see the documentation.

Billion Parameter Model Support

Big Model Inference

It's easy to enable large-model support for the pre-built 🤗 LightningModule tasks.

Below is an example of enabling automatic model partitioning (across CPU and GPU memory, even spilling to disk) to run text generation with a 6B-parameter model.

import torch
from accelerate import init_empty_weights
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingTransformer,
)

with init_empty_weights():  # instantiate on the "meta" device, without allocating weight memory
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
        low_cpu_mem_usage=True,  # load weights incrementally to limit peak CPU RAM
        device_map="auto",  # automatically partitions the model based on the available hardware.
    )

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))
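
To cut memory further, the backbone can also be loaded in half precision. A sketch, assuming extra keyword arguments such as torch_dtype are forwarded to the underlying transformers from_pretrained call (as low_cpu_mem_usage and device_map are above):

with init_empty_weights():
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
        low_cpu_mem_usage=True,
        device_map="auto",
        torch_dtype=torch.float16,  # assumption: forwarded to from_pretrained to load fp16 weights
    )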

For more information see Big Transformers Model Inference.

Big Model Training with DeepSpeed

Below is an example of how you can train a 6B parameter transformer model using Lightning Transformers and DeepSpeed.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="EleutherAI/gpt-j-6B"
)

model = LanguageModelingTransformer(
    pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
    tokenizer=tokenizer,  # share one tokenizer between the model and the data module
    deepspeed_sharding=True,  # defer model initialization to shard/load pre-trained weights
)

dm = LanguageModelingDataModule(
    batch_size=1,
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",
    strategy="deepspeed_stage_3",
    precision=16,
    max_epochs=1,
)

trainer.fit(model, dm)
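
If the model still does not fit in GPU memory, ZeRO Stage 3 offloading can be enabled by passing Lightning's DeepSpeedStrategy in place of the "deepspeed_stage_3" alias (a sketch using pytorch_lightning.strategies.DeepSpeedStrategy; the offload flags move optimizer state and parameters to CPU RAM):

from pytorch_lightning.strategies import DeepSpeedStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,  # keep optimizer state in CPU memory
        offload_parameters=True,  # page parameters to CPU memory
    ),
    precision=16,
    max_epochs=1,
)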

For more information see DeepSpeed Training with Big Transformers Models or the Model Parallelism documentation.

Contribute

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

Community

For help or questions, join our huge community on Slack!

Download files

Download the file for your platform.

Source distribution
  lightning-transformers-0.2.3.tar.gz (41.5 kB)

Built distribution
  lightning_transformers-0.2.3-py3-none-any.whl (76.5 kB, Python 3)

File hashes

lightning-transformers-0.2.3.tar.gz
  SHA256       a7bb1d61e5ccbfa0683f1f406b43363cf71676ac4b72b50bb444a203ad564a67
  MD5          e70cf61769dd1dc96c8075d73ce415a8
  BLAKE2b-256  cebe3953551e634b8928328d289b1b4f5c313e0542f8d2aa54a412d0cec8a499

lightning_transformers-0.2.3-py3-none-any.whl
  SHA256       a1e17945373fc5e3b2abeb7838f647c5608ac3a42f639a36631fec07b336972f
  MD5          f876806526a18dc066ae1e9121ed325b
  BLAKE2b-256  1a053dc795494d7c762d286351dbe23d8457fed4b7591264b7cd62452bcf5e35
