
A framework for building training loops more easily and quickly


fasttrain

fasttrain is a lightweight framework for building training loops for neural networks as quickly as possible. It is designed to take care of the tedious parts of writing training loops in PyTorch, so you don't have to think about how to pretty-print the loss and metrics or how to compute them correctly.

Installation

$ pip install fasttrain

How do we start?

Let's use a neural network to classify images in the FashionMNIST dataset:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

learning_rate = 1e-3
batch_size = 64
epochs = 5

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Reuse the batch_size defined above instead of repeating the literal 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

Then we define a trainer:

from fasttrain import Trainer
from fasttrain.metrics import accuracy

class MyTrainer(Trainer):

    # Define how we compute the loss
    def compute_loss(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return nn.CrossEntropyLoss()(output_batch, y_batch)

    # Define how we compute metrics
    def eval_metrics(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return {
            "accuracy": accuracy(output_batch, y_batch, task="multiclass")
        }
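To make the metric concrete, here is a minimal sketch of what a multiclass accuracy over logits typically computes in plain PyTorch. The helper name `multiclass_accuracy` is hypothetical and this is not fasttrain's implementation of `accuracy`. It only assumes that the model output holds raw logits of shape (batch, classes) and that the targets are integer class labels.

```python
import torch

def multiclass_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of rows whose highest-scoring class equals the target label.

    Hypothetical helper for illustration only, not part of fasttrain.
    """
    predictions = logits.argmax(dim=1)  # index of the largest logit per sample
    return (predictions == targets).float().mean().item()

# Four samples, three classes; the third sample is misclassified.
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 3.0],
                       [0.5, 1.5, 0.1],
                       [1.0, 0.9, 0.8]])
targets = torch.tensor([0, 2, 0, 0])
acc = multiclass_accuracy(logits, targets)
print(acc)  # 0.75
```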

Finally, let's train our model:

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
trainer = MyTrainer(model, optimizer)
history = trainer.train(train_dataloader, val_data=test_dataloader, num_epochs=epochs)
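For comparison, the following is a sketch of the kind of hand-written PyTorch loop that a trainer class replaces. It is not fasttrain's internal implementation. It uses a small synthetic dataset and a hypothetical `run_epoch` helper so that it runs on its own, and all names in it are illustrative.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for a real dataset: 256 samples, 20 features, 3 classes.
features = torch.randn(256, 20)
labels = torch.randint(0, 3, (256,))
loader = DataLoader(TensorDataset(features, labels), batch_size=64)

net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
sgd = torch.optim.SGD(net.parameters(), lr=1e-1)
loss_fn = nn.CrossEntropyLoss()

def run_epoch(train: bool) -> dict:
    """Run one pass over the loader; return sample-weighted mean loss and accuracy."""
    net.train(train)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(train):
        for x_batch, y_batch in loader:
            logits = net(x_batch)
            loss = loss_fn(logits, y_batch)
            if train:
                sgd.zero_grad()
                loss.backward()
                sgd.step()
            # Weight by batch size so a smaller final batch is averaged correctly.
            total_loss += loss.item() * len(y_batch)
            correct += (logits.argmax(dim=1) == y_batch).sum().item()
            seen += len(y_batch)
    return {"loss": total_loss / seen, "accuracy": correct / seen}

manual_history = {"loss": [], "accuracy": []}
for epoch in range(5):
    metrics = run_epoch(train=True)
    for name, value in metrics.items():
        manual_history[name].append(value)
    print(f"epoch {epoch + 1}: loss={metrics['loss']:.4f} accuracy={metrics['accuracy']:.3f}")
```

The bookkeeping for averaging, metric computation, and printing is the part that a trainer abstraction is meant to hide.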

fasttrain offers some useful callbacks. One of them is Tqdm, which shows a progress bar during training.

Trainer.train() returns the training history. It holds a dict that stores the metrics for each epoch, and it can plot them:

history.plot("loss", with_val=True)


history.plot("accuracy", with_val=True)

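The idea of a per-epoch metric history can be sketched with a plain dictionary of lists. The `SimpleHistory` class below is hypothetical and purely illustrative, as is the `val_` prefix for validation metrics. fasttrain's own history object may store and expose its data differently.

```python
from collections import defaultdict

class SimpleHistory:
    """Hypothetical container: one list of per-epoch values per metric name."""

    def __init__(self):
        self.metrics = defaultdict(list)

    def log(self, name: str, value: float) -> None:
        """Append the value of one metric for the current epoch."""
        self.metrics[name].append(value)

    def best(self, name: str, mode: str = "min") -> tuple:
        """Return (epoch, value) of the best entry, with epochs counted from 1."""
        values = self.metrics[name]
        pick = min if mode == "min" else max
        index = pick(range(len(values)), key=values.__getitem__)
        return index + 1, values[index]

h = SimpleHistory()
for loss, val_loss in [(0.9, 1.0), (0.6, 0.8), (0.5, 0.85)]:
    h.log("loss", loss)
    h.log("val_loss", val_loss)

best_epoch, best_val = h.best("val_loss")
print(best_epoch, best_val)  # 2 0.8
```

Keeping training and validation values side by side like this is what makes it possible to draw both curves on one plot.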
