A framework that makes building training loops easier and faster
Project description
fasttrain
With `fasttrain` you'll forever forget about ugly and complex training loops in PyTorch!
Warning! `fasttrain` is currently under heavy development...
How do we start?
Let's create a simple fully connected network, taken straight from the PyTorch tutorial:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Hyperparameters
learning_rate = 1e-3
batch_size = 64
epochs = 5

# FashionMNIST: 28x28 grayscale images from 10 clothing classes
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Fully connected stack: 784 -> 512 -> 512 -> 10 logits
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)  # (N, 1, 28, 28) -> (N, 784)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
```
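Before handing the model to a trainer, it's worth checking what a batch looks like and that it flows through the network. The sketch below is plain PyTorch, independent of fasttrain; to avoid a download it swaps FashionMNIST for a hypothetical stand-in dataset of random tensors with the same shapes.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)

model = NeuralNetwork()

# Stand-in for FashionMNIST: 256 random 1x28x28 "images" with labels in [0, 10)
fake_data = TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,)))
loader = DataLoader(fake_data, batch_size=64)

# Each batch is an (images, labels) tuple
x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

# Forward pass without tracking gradients: one logit per class
with torch.no_grad():
    logits = model(x_batch)
print(logits.shape)  # torch.Size([64, 10])

# Trainable parameter count of the 784 -> 512 -> 512 -> 10 stack
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 669706
```

The `(images, labels)` tuple structure is what the trainer hooks unpack as `input_batch`.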
Your next move would probably be to write training and testing functions that train the model and show how well it performs. Let's skip all that and get a little help from the `Trainer` class instead:
```python
from fasttrain import Trainer
from fasttrain.metrics import accuracy

class FashionMNISTTrainer(Trainer):
    def predict(self, input_batch):
        # Each batch is an (images, labels) pair; only the images go through the model
        (x_batch, _) = input_batch
        return self.model(x_batch)

    def compute_loss(self, input_batch, output_batch):
        (_, y_batch) = input_batch
        return nn.CrossEntropyLoss()(output_batch, y_batch)

    def eval_metrics(self, input_batch, output_batch):
        # Return a dict mapping metric names to their values for this batch
        (_, y_batch) = input_batch
        return {
            "accuracy": accuracy(output_batch, y_batch, task="multiclass")
        }
```
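The `eval_metrics` hook returns a dict mapping metric names to values for the current batch. For intuition, multiclass accuracy on logits boils down to taking the highest-scoring class and comparing it with the labels. The helper below, `batch_accuracy`, is hypothetical and not part of fasttrain; it is not necessarily how `fasttrain.metrics.accuracy` is implemented.

```python
import torch

def batch_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: fraction of samples whose top class equals the target."""
    predicted = logits.argmax(dim=1)  # (N, C) logits -> (N,) predicted class indices
    return (predicted == targets).float().mean()

logits = torch.tensor([[2.0, 0.1, 0.3],   # predicts class 0
                       [0.2, 1.5, 0.1],   # predicts class 1
                       [0.9, 0.8, 0.1],   # predicts class 0
                       [0.1, 0.2, 3.0]])  # predicts class 2
targets = torch.tensor([0, 1, 2, 2])
print(batch_accuracy(logits, targets).item())  # 0.75 (3 of 4 correct)
```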
With `Trainer`, all you have to do is specify how predictions are made, how to compute the loss and how to evaluate metrics (did you notice that I've also imported the `accuracy` metric? Isn't that fancy?). All that's left is to create an optimizer for the model and call the `train` method:
```python
from fasttrain.callbacks import Tqdm

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
trainer = FashionMNISTTrainer(model, optimizer)
history = trainer.train(train_dataloader, val_data=test_dataloader,
                        num_epochs=epochs, callbacks=[Tqdm(colab=True)])
```
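To see what `Trainer` takes off your hands, here is a minimal sketch of the kind of epoch loop that hooks like `predict`, `compute_loss` and `eval_metrics` plug into. It is an assumption about the general pattern, not fasttrain's actual implementation: the `SimpleTrainer` and `TinyTrainer` classes, the `on_epoch_end` callback method, the `PrintLogs` callback and the `val_`-prefixed history keys are all hypothetical. It trains a tiny linear model on a toy dataset so it runs in seconds.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class SimpleTrainer:
    """Hypothetical hook-based trainer, NOT fasttrain's implementation."""
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    # Hooks a subclass overrides, like FashionMNISTTrainer in the text
    def predict(self, input_batch):
        raise NotImplementedError

    def compute_loss(self, input_batch, output_batch):
        raise NotImplementedError

    def eval_metrics(self, input_batch, output_batch):
        return {}

    def _run_epoch(self, loader, device, training):
        self.model.train(training)  # switches between train and eval mode
        totals, num_batches = {}, 0
        with torch.set_grad_enabled(training):
            for batch in loader:
                batch = tuple(t.to(device) for t in batch)
                output = self.predict(batch)
                loss = self.compute_loss(batch, output)
                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                stats = {"loss": loss.item()}
                for name, value in self.eval_metrics(batch, output).items():
                    stats[name] = float(value)
                for name, value in stats.items():
                    totals[name] = totals.get(name, 0.0) + value
                num_batches += 1
        # Average every metric over the epoch's batches
        return {name: total / num_batches for name, total in totals.items()}

    def train(self, train_data, val_data=None, num_epochs=1, callbacks=(), device=None):
        # Use CUDA automatically when it is available, unless a device is passed in
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        history = {}
        for epoch in range(num_epochs):
            logs = self._run_epoch(train_data, device, training=True)
            if val_data is not None:
                val_logs = self._run_epoch(val_data, device, training=False)
                logs.update({f"val_{name}": v for name, v in val_logs.items()})
            for name, value in logs.items():
                history.setdefault(name, []).append(value)
            for callback in callbacks:
                callback.on_epoch_end(epoch, logs)
        return history

class TinyTrainer(SimpleTrainer):
    def predict(self, input_batch):
        x, _ = input_batch
        return self.model(x)

    def compute_loss(self, input_batch, output_batch):
        _, y = input_batch
        return nn.functional.cross_entropy(output_batch, y)

    def eval_metrics(self, input_batch, output_batch):
        _, y = input_batch
        return {"accuracy": (output_batch.argmax(dim=1) == y).float().mean()}

class PrintLogs:
    """Hypothetical callback that prints each epoch's averaged metrics."""
    def on_epoch_end(self, epoch, logs):
        print(epoch, {name: round(v, 3) for name, v in logs.items()})

torch.manual_seed(0)
x = torch.randn(128, 4)
y = (x[:, 0] > 0).long()  # toy binary labels a linear model can learn
loader = DataLoader(TensorDataset(x, y), batch_size=32)
model = nn.Linear(4, 2)
trainer = TinyTrainer(model, torch.optim.SGD(model.parameters(), lr=0.1))
history = trainer.train(loader, val_data=loader, num_epochs=3, callbacks=[PrintLogs()])
```

Keeping the model-specific parts in small overridable hooks means the loop itself (device placement, gradient handling, metric averaging, callbacks) is written once and reused.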
`fasttrain` comes with batteries included and offers some useful callbacks. One of them is `Tqdm`, which shows a pretty progress bar (the `colab=True` option is there because I built this network in Google Colab; if you're running locally, you don't need it). Let's see how it looks:
Did you see it? The first line printed tells us that we're using CUDA, even though we never asked for it. `Trainer` is smart enough to use CUDA when it's available, but you can also choose the device yourself by passing, for example, `device='cpu'` to `train()`.

`train()` also returns the training history. What is it? It's a dict-like object that, given a metric name, returns that metric's statistics over the epochs, so you could plot them later with matplotlib. But `fasttrain` has a better option: plot them right away!
```python
history.plot("loss", with_val=True)
history.plot("accuracy", with_val=True)
```
Pretty-looking metrics with graphs, remember, batteries ARE included!
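For comparison, the manual matplotlib route mentioned earlier could look like the sketch below. It assumes the history can be read as a plain dict from metric names to per-epoch lists, with validation values stored under a hypothetical `val_` prefix; `plot_metric` is a hypothetical helper, and the numbers are placeholders, not real training results.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt

def plot_metric(history, key, with_val=False):
    """Hypothetical helper: plot one metric per epoch, optionally with its val_ twin."""
    fig, ax = plt.subplots()
    epochs = range(1, len(history[key]) + 1)
    ax.plot(epochs, history[key], label=key)
    if with_val and f"val_{key}" in history:
        ax.plot(epochs, history[f"val_{key}"], label=f"val_{key}")
    ax.set_xlabel("epoch")
    ax.set_ylabel(key)
    ax.legend()
    return fig, ax

# Placeholder values for illustration only, not real results
history = {"loss": [2.1, 1.6, 1.2], "val_loss": [2.0, 1.7, 1.4]}
fig, ax = plot_metric(history, "loss", with_val=True)
# fig.savefig("loss.png") to save it, or plt.show() in an interactive session
```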
File details
Details for the file fasttrain-0.0.4.tar.gz.
File metadata
- Download URL: fasttrain-0.0.4.tar.gz
- Upload date:
- Size: 19.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 822153730bd2d879e50576006b9fece22f4ef781803a6d395ce5767df29c07a6
MD5 | 9f2445c24003ca8f5d73fff5523ec492
BLAKE2b-256 | 7c00b8a439c370d3e145b03df2f28ebfb4231ca1ff840110ccd5f578d27e59b8
File details
Details for the file fasttrain-0.0.4-py3-none-any.whl.
File metadata
- Download URL: fasttrain-0.0.4-py3-none-any.whl
- Upload date:
- Size: 20.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | cc8c7a5347861c6ad41a8b7dd648cedac9eb246f7f814dd20bbe8de9e4c68e6a
MD5 | 24241790c3264dd110d6a00ccaf501cb
BLAKE2b-256 | 2d914d5d53e187e4e188248ec357d9ea75bb7efba7a7278cb4ba1f9a264f6e1a