A framework for building training loops more easily and quickly
Project description
fasttrain
fasttrain is a lightweight framework for building training loops for neural networks as quickly as possible. It is designed to strip away the tedious details of writing training loops in PyTorch, so you don't have to think about how to pretty-print losses and metrics, or how to compute them correctly.
Installation
$ pip install fasttrain
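After installing, a quick import confirms the package is available. This check is an addition to the page; it relies only on the standard package attribute __file__:

import fasttrain  # raises ImportError if the install failed
print(fasttrain.__file__)  # location of the installed package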
How do we start?
Let's use a neural network to classify images in the FashionMNIST dataset:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

learning_rate = 1e-3
batch_size = 64
epochs = 5

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
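To make the data concrete, we can peek at one batch before defining the model (this inspection snippet is an addition to the original walkthrough):

# Peek at one batch: images are 1x28x28 tensors, labels are class indices 0-9.
X, y = next(iter(train_dataloader))
print(X.shape)  # torch.Size([64, 1, 28, 28])
print(y.shape)  # torch.Size([64])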
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
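Before wiring up a trainer, a quick forward pass is a handy sanity check (again, an addition, not part of the original example):

# Untrained forward pass: a batch of images in, one logit per class out.
X, y = next(iter(train_dataloader))
logits = model(X)
print(logits.shape)  # torch.Size([64, 10])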
Next, we set up a trainer:
from fasttrain import Trainer
from fasttrain.metrics import accuracy

class MyTrainer(Trainer):

    # Define how we compute the loss
    def compute_loss(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return nn.CrossEntropyLoss()(output_batch, y_batch)

    # Define how we compute metrics
    def eval_metrics(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return {
            "accuracy": accuracy(output_batch, y_batch, task="multiclass")
        }
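For the multiclass case, the accuracy helper boils down to comparing the arg-max of the logits with the labels. A hand-rolled equivalent (shown only to make the metric concrete; this is not fasttrain's actual implementation) looks like this:

# Multiclass accuracy by hand: the fraction of samples whose
# highest-scoring logit matches the true label.
def multiclass_accuracy(output_batch, y_batch):
    predictions = output_batch.argmax(dim=1)
    return (predictions == y_batch).float().mean().item()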
Finally, let's train our model:
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
trainer = MyTrainer(model, optimizer)
history = trainer.train(train_dataloader, val_data=test_dataloader, num_epochs=epochs)
fasttrain offers some useful callbacks. One of them is Tqdm, which shows a pretty progress bar while the model trains.
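This page doesn't show how the callback is attached. A plausible sketch, assuming the callback lives in fasttrain.callbacks and that Trainer.train() accepts a callbacks argument (both are assumptions, so check the project's documentation), would be:

# Hypothetical wiring: the import path and `callbacks=` parameter are assumptions.
from fasttrain.callbacks import Tqdm

history = trainer.train(
    train_dataloader,
    val_data=test_dataloader,
    num_epochs=epochs,
    callbacks=[Tqdm()],
)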
Trainer.train() returns the training history. It contains a dict that stores metrics over epochs and can plot them:
history.plot("loss", with_val=True)
history.plot("accuracy", with_val=True)
Project details
Download files
Download the file for your platform.
Source Distribution: fasttrain-0.0.7.tar.gz
Built Distribution: fasttrain-0.0.7-py3-none-any.whl
File details
Details for the file fasttrain-0.0.7.tar.gz.
File metadata
- Download URL: fasttrain-0.0.7.tar.gz
- Upload date:
- Size: 18.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6a7248f803c50941e41a1e60b455efa729d73739a44ac51fc9f2f822c545dc70
MD5 | 5686cf05c4dcf9be491287896c4e2df1
BLAKE2b-256 | 4a4626cd6a54a32fe7078e979fe741e24d05fec245ddd9c11416b1fc36385d8b
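To verify a downloaded archive against the published SHA256 digest, a few lines of standard-library Python suffice (the local filename is assumed to match the download above):

import hashlib

# Compare the downloaded sdist against the SHA256 digest published above.
expected = "6a7248f803c50941e41a1e60b455efa729d73739a44ac51fc9f2f822c545dc70"
with open("fasttrain-0.0.7.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "MISMATCH")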
File details
Details for the file fasttrain-0.0.7-py3-none-any.whl.
File metadata
- Download URL: fasttrain-0.0.7-py3-none-any.whl
- Upload date:
- Size: 20.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | d20b636ebb11a6ec6e40bc714904d27bd94c7491c19e9b5bdd92c9f5f9b8a5e2
MD5 | b7202c0d7da2d099c306fd722ac945b3
BLAKE2b-256 | 861c904a6a41cea00ce0b2369be4a1bcf7bf12d51581a5b134f0ec65bc9c32fb