
A framework that makes building training loops easier and faster

Project description

fasttrain

With fasttrain you'll forever forget about ugly and complex training loops in PyTorch!

Warning!

fasttrain is currently under heavy development...
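fasttrain is published on PyPI, so the usual pip invocation should get you going:

pip install fasttrain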

How do we start?

Let's create a simple neural network straight from the PyTorch tutorial:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

learning_rate = 1e-3
batch_size = 64
epochs = 5

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
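Before wiring up training, let's push a dummy batch through the model as a quick sanity check (my own addition, not part of the tutorial):

x = torch.randn(batch_size, 1, 28, 28)  # FashionMNIST images are 1x28x28 grayscale
logits = model(x)
print(logits.shape)  # torch.Size([64, 10]): one logit per class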

Your next move would normally be to write training and testing functions to, of course, train your model and show how effective it is. Let's skip that and get a little help from the Trainer class instead:

from fasttrain import Trainer
from fasttrain.metrics import accuracy

class FashionMNISTTrainer(Trainer):

    # How to make predictions from a batch.
    def predict(self, input_batch):
        (x_batch, _) = input_batch
        return self.model(x_batch)

    # How to compute the loss given a batch and the model's output.
    def compute_loss(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return nn.CrossEntropyLoss()(output_batch, y_batch)

    # Which named metrics to evaluate on each batch.
    def eval_metrics(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return {
            "accuracy": accuracy(output_batch, y_batch, task="multiclass")
        }
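For intuition, the multiclass accuracy used in eval_metrics is just the fraction of argmax predictions that match the labels. A plain-PyTorch sketch of the idea (mine, not fasttrain's actual implementation):

def multiclass_accuracy(logits, targets):
    # Pick the highest-scoring class per sample and compare to the labels.
    predictions = logits.argmax(dim=1)
    return (predictions == targets).float().mean().item()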

With Trainer, all you have to do is specify how your predictions are made, how to compute the loss, and how to evaluate metrics (I hope you've noticed that I've also imported the accuracy metric, isn't it just fancy?). All that's left is to specify the optimizer and call the train() method:

from fasttrain.callbacks import Tqdm

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
trainer = FashionMNISTTrainer(model, optimizer)
history = trainer.train(
    train_dataloader,
    val_data=test_dataloader,
    num_epochs=epochs,
    callbacks=[Tqdm(colab=True)],
)

fasttrain comes with batteries included and offers some useful callbacks. One of them is Tqdm, which shows a pretty-looking progress bar (the colab=True option is there because I built this network in Google Colab; if you're running locally, you don't need to pass it). Let's see how it looks:

[screenshot: training loop with progress bar]

Did you see it? The first line printed tells us that we're using cuda, even though we never asked for it. Trainer is smart enough to use cuda when it's available, and if you want a specific device you can pass it to train() explicitly, for example device='cpu'. train() also returns the history of training: a dict-like object that maps a metric's name to its statistics over the epochs, so you can plot them later with matplotlib.
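For example, a minimal matplotlib sketch (assuming dict-style access like history["loss"] returning per-epoch values; the actual access pattern and key names may differ):

import matplotlib.pyplot as plt

plt.plot(history["loss"], label="train loss")      # assumed access pattern
plt.plot(history["val_loss"], label="val loss")    # assumed key name
plt.xlabel("epoch")
plt.legend()
plt.show()

But fasttrain has a better option: plot them right away!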

history.plot("loss", with_val=True)

[plot: training and validation loss over epochs]

history.plot("accuracy", with_val=True)

[plot: training and validation accuracy over epochs]

Pretty-looking metric graphs. Remember, batteries ARE included!

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fasttrain-0.0.4.tar.gz (19.1 kB)

Uploaded Source

Built Distribution

fasttrain-0.0.4-py3-none-any.whl (20.5 kB)

Uploaded Python 3

File details

Details for the file fasttrain-0.0.4.tar.gz.

File metadata

  • Download URL: fasttrain-0.0.4.tar.gz
  • Upload date:
  • Size: 19.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.1

File hashes

Hashes for fasttrain-0.0.4.tar.gz

  • SHA256: 822153730bd2d879e50576006b9fece22f4ef781803a6d395ce5767df29c07a6
  • MD5: 9f2445c24003ca8f5d73fff5523ec492
  • BLAKE2b-256: 7c00b8a439c370d3e145b03df2f28ebfb4231ca1ff840110ccd5f578d27e59b8

See more details on using hashes here.
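For example, to check a downloaded archive against the SHA256 digest above with a few lines of Python:

import hashlib

with open("fasttrain-0.0.4.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print(digest)  # should match the SHA256 value listed above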

File details

Details for the file fasttrain-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: fasttrain-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 20.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.1

File hashes

Hashes for fasttrain-0.0.4-py3-none-any.whl

  • SHA256: cc8c7a5347861c6ad41a8b7dd648cedac9eb246f7f814dd20bbe8de9e4c68e6a
  • MD5: 24241790c3264dd110d6a00ccaf501cb
  • BLAKE2b-256: 2d914d5d53e187e4e188248ec357d9ea75bb7efba7a7278cb4ba1f9a264f6e1a

See more details on using hashes here.
