Lightning Transformers
Installation
pip install lightning-transformers
From Source
git clone https://github.com/PyTorchLightning/lightning-transformers.git
cd lightning-transformers
pip install .
What is Lightning-Transformers?
Lightning Transformers provides LightningModules, LightningDataModules, and Strategies to use Hugging Face Transformers with the PyTorch Lightning Trainer.
Quick Recipes
Train bert-base-cased on the CARER emotion dataset using the Text Classification task.
import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
    TextClassificationDataConfig,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-cased"
)
dm = TextClassificationDataModule(
    cfg=TextClassificationDataConfig(
        batch_size=1,
        dataset_name="emotion",
        max_length=512,
    ),
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(
    pretrained_model_name_or_path="bert-base-cased", num_labels=dm.num_classes
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
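For intuition about what max_length=512 controls: each example is tokenized, then truncated or padded to a fixed length before batching, with an attention mask marking real tokens. The sketch below illustrates that step with a toy whitespace tokenizer (the encode function, vocab, and pad/unk ids are illustrative assumptions, not the actual bert-base-cased tokenizer):

```python
def encode(text, vocab, max_length, pad_id=0, unk_id=1):
    """Toy version of tokenize-then-truncate/pad, mimicking max_length handling.

    Returns (input_ids, attention_mask), both of length <= max_length.
    """
    # Whitespace "tokenizer": unknown words map to unk_id.
    ids = [vocab.get(tok, unk_id) for tok in text.split()][:max_length]  # truncate
    mask = [1] * len(ids) + [0] * (max_length - len(ids))  # 1 = real token, 0 = padding
    ids = ids + [pad_id] * (max_length - len(ids))         # pad to fixed length
    return ids, mask
```

A real tokenizer additionally handles subwords and special tokens like [CLS] and [SEP], but the fixed-length output shape is the same idea.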
Train a pre-trained mt5-base backbone on the WMT16 dataset using the Translation task.
import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
    TranslationConfig,
    TranslationDataConfig,
)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)
model = TranslationTransformer(
    pretrained_model_name_or_path="google/mt5-base",
    cfg=TranslationConfig(
        n_gram=4,
        smooth=False,
        val_target_max_length=142,
        num_beams=None,
        compute_generate_metrics=True,
    ),
)
dm = WMT16TranslationDataModule(
    cfg=TranslationDataConfig(
        dataset_name="wmt16",
        # WMT translation datasets: ['cs-en', 'de-en', 'fi-en', 'ro-en', 'ru-en', 'tr-en']
        dataset_config_name="ro-en",
        source_language="en",
        target_language="ro",
        max_source_length=128,
        max_target_length=128,
    ),
    tokenizer=tokenizer,
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
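The n_gram=4 and smooth=False settings in TranslationConfig refer to BLEU-style generation metrics, which average n-gram precisions between a generated translation and a reference. As a rough illustration of one such precision, here is a small pure-Python sketch (a deliberate simplification: single reference, no brevity penalty, no smoothing; ngram_precision is a hypothetical helper, not the library's metric):

```python
from collections import Counter

def ngram_precision(candidate, reference, n):
    """Fraction of candidate n-grams that also appear in the reference,
    with counts clipped to the reference counts (one BLEU component)."""
    cand = [tuple(candidate[i:i + n]) for i in range(len(candidate) - n + 1)]
    ref = Counter(tuple(reference[i:i + n]) for i in range(len(reference) - n + 1))
    if not cand:
        return 0.0
    cand_counts = Counter(cand)
    # Clipping stops a repeated candidate n-gram from being rewarded
    # more times than it occurs in the reference.
    clipped = sum(min(count, ref[gram]) for gram, count in cand_counts.items())
    return clipped / len(cand)
```

BLEU-4 (n_gram=4) combines the precisions for n = 1..4; smoothing, when enabled, avoids zeroing the score out when a higher-order n-gram has no match.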
Lightning Transformers supports a range of Hugging Face tasks and datasets. See the documentation.
Contribute
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.
Community
For help or questions, join our community on Slack!