
PyTorch Lightning Transformers.


Flexible components pairing 🤗 Transformers with PyTorch Lightning




Installation

pip install lightning-transformers

From Source

git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .
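
To confirm that either install route worked, one option is to query the installed distribution's metadata from Python. This is a minimal sketch that uses only the standard library; the helper name `installed_version` is hypothetical and is not part of Lightning Transformers.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # The distribution is not installed in the current environment.
        return None


version = installed_version("lightning-transformers")
print(version if version is not None else "lightning-transformers is not installed")
```

If the second message is printed, the package was most likely installed into a different Python environment than the one running the script.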

What is Lightning Transformers?

Lightning Transformers provides LightningModules, LightningDataModules and Strategies to use 🤗 Transformers with the PyTorch Lightning Trainer.

Quick Recipes

Train bert-base-cased on the CARER emotion dataset using the Text Classification task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="emotion",
    max_length=512,
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)

Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    n_gram=4,
    smooth=False,
    val_target_max_length=142,
    num_beams=None,
    compute_generate_metrics=True,
)
dm = WMT16TranslationDataModule(
    # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
    dataset_config_name="ro-en",
    source_language="en",
    target_language="ro",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
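
The `n_gram=4` and `smooth=False` arguments above read like the settings of a BLEU-style translation metric. That reading is an assumption; the text here does not say how they are used. As a rough illustration of what such a metric computes, here is a self-contained, unsmoothed sentence-level BLEU sketch in plain Python. The function names are hypothetical, and this is not the library's implementation.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngram_counts(tokens: Sequence[str], n: int) -> Counter:
    """Count every contiguous n-gram in a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(candidate: List[str], reference: List[str], n_gram: int = 4) -> float:
    """Unsmoothed BLEU for one candidate against one reference."""
    if not candidate:
        return 0.0
    log_precision_sum = 0.0
    for n in range(1, n_gram + 1):
        cand = ngram_counts(candidate, n)
        ref = ngram_counts(reference, n)
        # Clip each candidate n-gram count by its count in the reference.
        overlap = sum(min(count, ref[gram]) for gram, count in cand.items())
        total = sum(cand.values())
        if overlap == 0 or total == 0:
            # Without smoothing, a single zero precision makes the score zero.
            return 0.0
        log_precision_sum += math.log(overlap / total)
    # Brevity penalty: penalise candidates shorter than the reference.
    if len(candidate) >= len(reference):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1.0 - len(reference) / len(candidate))
    # Geometric mean of the n-gram precisions, scaled by the brevity penalty.
    return brevity_penalty * math.exp(log_precision_sum / n_gram)


reference = "the cat sat on the mat".split()
print(sentence_bleu("the cat sat on the mat".split(), reference))
print(sentence_bleu("a dog ran".split(), reference))
```

Because no smoothing is applied, a candidate with no matching 4-gram scores zero even if many shorter n-grams match, which is the case a smoothing option is usually meant to soften.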

Lightning Transformers supports many 🤗 tasks and datasets. See the documentation.

Billion Parameter Model Support

Big Model Inference

Large model support can be enabled for the pre-built LightningModule 🤗 tasks with a few extra arguments.

The example below enables automatic model partitioning (across CPU and GPU, and even spilling to disk) to run text generation using a 6B parameter model.

import torch
from accelerate import init_empty_weights
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingTransformer,
)

with init_empty_weights():
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
        low_cpu_mem_usage=True,
        device_map="auto",  # automatically partitions the model based on the available hardware.
    )

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))
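
To make the idea of partitioning across GPU, CPU and disk concrete, the sketch below assigns layers to devices greedily, in order, under per-device memory budgets. It is a simplified illustration under stated assumptions, not the algorithm behind `device_map="auto"`: the function `greedy_device_map`, the layer sizes, and the budgets are all hypothetical.

```python
from typing import Dict, List, Tuple


def greedy_device_map(
    layer_sizes_gb: List[Tuple[str, float]],
    budgets_gb: List[Tuple[str, float]],
) -> Dict[str, str]:
    """Assign layers, in order, to the current device until it is full.

    Devices are listed fastest first (e.g. GPU, then CPU, then disk). Once a
    layer does not fit, it and all later layers spill to the next device.
    """
    device_map: Dict[str, str] = {}
    device_index = 0
    remaining = budgets_gb[0][1] if budgets_gb else 0.0
    for name, size in layer_sizes_gb:
        # Move on to slower devices until the layer fits.
        while device_index < len(budgets_gb) and size > remaining:
            device_index += 1
            if device_index < len(budgets_gb):
                remaining = budgets_gb[device_index][1]
        if device_index >= len(budgets_gb):
            raise MemoryError(f"no device has room for layer {name!r}")
        device_map[name] = budgets_gb[device_index][0]
        remaining -= size
    return device_map


# Hypothetical model: 6 blocks of 2 GB each, with made-up memory budgets.
layers = [(f"block.{i}", 2.0) for i in range(6)]
budgets = [("cuda:0", 5.0), ("cpu", 4.0), ("disk", 100.0)]
print(greedy_device_map(layers, budgets))
```

With these made-up numbers, the first two blocks land on the GPU, the next two on the CPU, and the last two on disk. Keeping layers in order means each device holds a contiguous slice of the model.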

For more information, see Big Transformers Model Inference.

Big Model Training with DeepSpeed

The example below trains a 6B parameter transformer model using Lightning Transformers and DeepSpeed.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)

# Load the tokenizer once and share it between the model and the data module.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="EleutherAI/gpt-j-6B"
)

model = LanguageModelingTransformer(
    pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
    tokenizer=tokenizer,
    deepspeed_sharding=True,  # defer model initialization so pre-trained weights can be sharded/loaded
)

dm = LanguageModelingDataModule(
    batch_size=1,
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    tokenizer=tokenizer,
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",
    strategy="deepspeed_stage_3",
    precision=16,
    max_epochs=1,
)

trainer.fit(model, dm)
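
A back-of-the-envelope estimate shows why sharding matters at this scale. The sketch assumes mixed-precision Adam, which is commonly costed at about 16 bytes of model state per parameter (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 optimizer state), and assumes that DeepSpeed stage 3 splits all of that state evenly across devices. Activations and temporary buffers are ignored, the 6 billion parameter count is rounded, and the helper name is hypothetical.

```python
BYTES_PER_PARAM_FP16_WEIGHTS = 2
BYTES_PER_PARAM_FP16_GRADS = 2
BYTES_PER_PARAM_FP32_ADAM_STATE = 12  # fp32 master weights, momentum and variance


def model_state_gb_per_device(num_params: int, num_devices: int, sharded: bool = True) -> float:
    """Estimate model-state memory per device in GB (1 GB = 1e9 bytes)."""
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    bytes_per_param = (
        BYTES_PER_PARAM_FP16_WEIGHTS
        + BYTES_PER_PARAM_FP16_GRADS
        + BYTES_PER_PARAM_FP32_ADAM_STATE
    )
    total_bytes = float(num_params * bytes_per_param)
    if sharded:
        # Assumed stage-3 behaviour: every device holds an equal shard.
        total_bytes /= num_devices
    return total_bytes / 1e9


num_params = 6_000_000_000  # rounded parameter count, an assumption
for devices in (1, 8):
    print(devices, "device(s):", model_state_gb_per_device(num_params, devices), "GB each")
```

Under these assumptions the model state alone needs about 96 GB on a single device, but about 12 GB per device when sharded across eight.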

For more information, see DeepSpeed Training with Big Transformers Models or the Model Parallelism documentation.

Contribute

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

Community

For help or questions, join our huge community on Slack!
