poutyne-transformers

Train 🤗-transformers models with Poutyne.

Installation

pip install poutyne-transformers

Example

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch import optim
from poutyne import Model, Accuracy
from poutyne_transformers import (
    TransformerCollator,
    model_loss,
    ModelWrapper,
    MetricWrapper,
)

print("Loading model & tokenizer.")
transformer = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased", num_labels=2, return_dict=True
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

print("Loading & preparing dataset.")
dataset = load_dataset("imdb")
dataset = dataset.map(
    lambda entry: tokenizer(
        entry["text"], add_special_tokens=True, padding="max_length", truncation=True
    ),
    batched=True,
)
dataset = dataset.remove_columns(["text"])
dataset = dataset.shuffle()
dataset.set_format("torch")

collate_fn = TransformerCollator(y_keys="labels")
train_dataloader = DataLoader(dataset["train"], batch_size=16, collate_fn=collate_fn)
test_dataloader = DataLoader(dataset["test"], batch_size=16, collate_fn=collate_fn)

print("Preparing training.")
wrapped_transformer = ModelWrapper(transformer)
optimizer = optim.AdamW(wrapped_transformer.parameters(), lr=5e-5)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
accuracy = MetricWrapper(Accuracy(), pred_key="logits")
model = Model(
    wrapped_transformer,
    optimizer,
    loss_function=model_loss,
    batch_metrics=[accuracy],
    device=device,
)

print("Starting training.")
# Note: the IMDB test split is passed as the validation data here.
model.fit_generator(train_dataloader, test_dataloader, epochs=1)
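
The example above relies on the transformer returning a dict-like output that already contains the loss (when labels are passed) and the logits, while Poutyne calls loss functions and metrics with a prediction and a target. The following is only a conceptual sketch of that idea in plain PyTorch, not the source of this package: `TinyClassifier`, `pick_loss` and `accuracy_from_logits` are hypothetical stand-ins, and how `model_loss` and `MetricWrapper` are actually implemented is an assumption.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a classifier that returns a dict of outputs."""

    def __init__(self, hidden_size=8, num_labels=2):
        super().__init__()
        self.linear = nn.Linear(hidden_size, num_labels)

    def forward(self, features, labels=None):
        logits = self.linear(features)
        outputs = {"logits": logits}
        # Like Hugging Face models, compute the loss only when labels are given.
        if labels is not None:
            outputs["loss"] = nn.functional.cross_entropy(logits, labels)
        return outputs


def pick_loss(outputs, y_true):
    # Assumed behaviour: the loss was computed in forward(), so just select it.
    return outputs["loss"]


def accuracy_from_logits(outputs, y_true, pred_key="logits"):
    # Select the prediction tensor by key, then compare argmax with the targets.
    predictions = outputs[pred_key].argmax(dim=-1)
    return (predictions == y_true).float().mean().item()


torch.manual_seed(0)
features = torch.randn(4, 8)
labels = torch.tensor([0, 1, 1, 0])

outputs = TinyClassifier()(features, labels=labels)
loss = pick_loss(outputs, labels)
accuracy = accuracy_from_logits(outputs, labels)
print(f"loss={loss.item():.4f} accuracy={accuracy:.2f}")
```

Selecting tensors by key keeps the loss and metric functions independent of which other fields the model output happens to carry.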

You can also build models with a custom architecture using the torch.nn.Sequential class:

from torch import nn
from transformers import AutoModel
from poutyne import Lambda
from poutyne_transformers import ModelWrapper

...

transformer = AutoModel.from_pretrained(
    "distilbert-base-cased", output_hidden_states=True
)

custom_model = nn.Sequential(
    ModelWrapper(transformer),
    # Use DistilBERT's [CLS] token (first position) for classification.
    Lambda(lambda outputs: outputs["last_hidden_state"][:, 0, :]),
    nn.Linear(in_features=transformer.config.hidden_size, out_features=1),
    Lambda(lambda out: out.reshape(-1)),
)

...
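
To make the shapes in the custom architecture concrete, here is a minimal sketch that runs without downloading a model. It assumes nothing about this package's internals: `DictEncoder` is a hypothetical stand-in for the wrapped transformer (it takes a plain tensor of token ids, which may differ from how `ModelWrapper` receives its batch), and `LambdaLayer` is a local stand-in for `poutyne.Lambda`. The head emits one logit per example, so a binary loss such as `nn.BCEWithLogitsLoss` with float labels is a natural fit.

```python
import torch
from torch import nn


class DictEncoder(nn.Module):
    """Hypothetical stand-in for the wrapped transformer; returns a dict."""

    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        # Shape: (batch, sequence_length, hidden_size)
        return {"last_hidden_state": self.embedding(input_ids)}


class LambdaLayer(nn.Module):
    """Local stand-in for poutyne.Lambda: wraps a function as a module."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args):
        return self.func(*args)


hidden_size = 16
custom_model = nn.Sequential(
    DictEncoder(hidden_size=hidden_size),
    # Keep only the first token's vector: (batch, hidden_size)
    LambdaLayer(lambda outputs: outputs["last_hidden_state"][:, 0, :]),
    # One logit per example: (batch, 1)
    nn.Linear(in_features=hidden_size, out_features=1),
    # Flatten to (batch,) so it lines up with a 1-D label tensor.
    LambdaLayer(lambda out: out.reshape(-1)),
)

input_ids = torch.randint(0, 100, (4, 12))
labels = torch.tensor([0.0, 1.0, 1.0, 0.0])

logits = custom_model(input_ids)
loss = nn.BCEWithLogitsLoss()(logits, labels)
print(logits.shape, loss.item())
```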
