
A framework that makes building training loops easier and faster


fasttrain

Forget about ugly and complex training loops

Installation

pip install fasttrain

Warning!

fasttrain is currently under heavy development...

How do we start?

Let's create a simple neural network, straight from the PyTorch tutorial:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

learning_rate = 1e-3
batch_size = 64
epochs = 5

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

Your next move would probably be to write some kind of training and testing functions to train your model and show how effective it is. Let's forget about that and get a little help from the Trainer class instead:

from fasttrain import Trainer
from fasttrain.metrics import accuracy

class FashionMNISTTrainer(Trainer):

    def predict(self, input_batch):
        (x_batch, _) = input_batch
        return self.model(x_batch)

    def compute_loss(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return nn.CrossEntropyLoss()(output_batch, y_batch)

    def eval_metrics(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return {
            "accuracy": accuracy(output_batch, y_batch, task="multiclass")
        }

With Trainer, all you have to do is specify how predictions are made, how the loss is computed, and how metrics are evaluated (you may have noticed that we've also imported the accuracy metric - isn't it just fancy?). All that's left is to pass in the model and an optimizer and call the train() function:

from fasttrain.callbacks import Tqdm

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
trainer = FashionMNISTTrainer(model, optimizer)
history = trainer.train(
    train_dataloader,
    val_data=test_dataloader,
    num_epochs=epochs,
    callbacks=[Tqdm(colab=True)],
)

fasttrain comes with batteries included and offers some useful callbacks - one of them is Tqdm, which shows a pretty-looking progress bar (the colab=True option is there because this network was built in Google Colab; if you're running it locally you don't need to specify it). Let's see how it looks:

[screenshot: training loop with the Tqdm progress bar]

Did you see it? The first line printed tells us that we're using cuda - we never asked for that, did we? Trainer is smart enough to use cuda when it's available, but you can also choose the device yourself by passing, for example, device='cpu' to train(). train() also returns the history of training: a dict-like object that maps each metric name to that metric's statistics over the epochs. You could pull those values out later and plot them with matplotlib (there's a sketch of that below), but fasttrain has a better option: plot them right now!

history.plot("loss", with_val=True)

[loss plot]

history.plot("accuracy", with_val=True)

[accuracy plot]

Pretty-looking metric graphs - remember, batteries ARE included!
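
If you'd rather draw the graphs yourself, the pieces above are all you need. Here's a minimal sketch, assuming history supports plain dict-style indexing and that validation series live under "val_"-prefixed keys (those key names are my assumption, not fasttrain's documented API):

import matplotlib.pyplot as plt

# Pin the device explicitly instead of relying on auto-detection
history = trainer.train(
    train_dataloader,
    val_data=test_dataloader,
    num_epochs=epochs,
    device="cpu",
)

# history is dict-like: each metric name maps to its per-epoch values
train_loss = history["loss"]    # "loss" is a known key, since history.plot("loss") works
val_loss = history["val_loss"]  # assumed key name for the validation series

epochs_range = range(1, len(train_loss) + 1)
plt.plot(epochs_range, train_loss, label="train")
plt.plot(epochs_range, val_loss, label="val")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()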

