A Python module whose main purpose is PyTorch training.
Project description
Introduction
This is a standalone module, mainly intended for training CNNs with PyTorch.
It also supports several convenient features, such as saving the model, recording the training process and plotting figures.
Install
```shell
pip install fau-tools
```
Usage
Import
The following imports are recommended.
```python
import fau_tools
from fau_tools import torch_tools
```
Quick start
This simple example will help you get started quickly.
It uses Fau-tools to train a model on the MNIST handwritten digits dataset.
```python
import torch
import torch.utils.data as tdata
import torchvision
from torch import nn

import fau_tools
from fau_tools import torch_tools


# A simple CNN network
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),   # -> (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),             # -> (16, 14, 14)
            nn.Conv2d(16, 32, 3, 1, 1),  # -> (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),             # -> (32, 7, 7)
        )
        self.output = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.flatten(1)  # same as x = x.view(x.size(0), -1)
        return self.output(x)


# Hyperparameters
total_epoch = 10
lr = 1E-3
batch_size = 1024

# Load dataset
train_data = torchvision.datasets.MNIST('Datasets', True, torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.MNIST('Datasets', False, torchvision.transforms.ToTensor(), download=True)

# Mini data: keep a small subset so the example runs quickly.
# Truncate the targets as well so they stay aligned with the images.
train_data.data, train_data.targets = train_data.data[:6000], train_data.targets[:6000]
test_data.data, test_data.targets = test_data.data[:2000], test_data.targets[:2000]

# Get data loaders
train_loader = tdata.DataLoader(train_data, batch_size, shuffle=True)
test_loader = tdata.DataLoader(test_data, batch_size)

# Initialize model, optimizer and loss function
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr)
loss_function = nn.CrossEntropyLoss()

# Train!
# `name` is used to name the files that save the model and the training process.
torch_tools.torch_train(model, train_loader, test_loader, optimizer, loss_function, total_epoch=total_epoch, name="MNIST")
```
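The size comments in the CNN (28 to 14 to 7) and the 32 * 7 * 7 input size of the final nn.Linear layer follow from the output-size formula for convolution and pooling given in the PyTorch documentation. The pure-Python sketch below traces the spatial size through the network; the helper names conv_out, pool_out and cnn_feature_shape are hypothetical and not part of Fau-tools, and dilation is assumed to be 1.

```python
# Output size of Conv2d / MaxPool2d along one spatial dimension (dilation = 1):
#   out = floor((in + 2 * padding - kernel_size) / stride) + 1
def conv_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size: int, kernel_size: int) -> int:
    # nn.MaxPool2d(k) defaults to stride = k and padding = 0
    return conv_out(size, kernel_size, stride=kernel_size)


def cnn_feature_shape(size: int = 28) -> tuple:
    """Trace the (channels, height, width) shape through the example CNN."""
    size = conv_out(size, kernel_size=3, stride=1, padding=1)  # Conv2d(1, 16, 3, 1, 1)
    size = pool_out(size, kernel_size=2)                        # MaxPool2d(2)
    size = conv_out(size, kernel_size=3, stride=1, padding=1)  # Conv2d(16, 32, 3, 1, 1)
    size = pool_out(size, kernel_size=2)                        # MaxPool2d(2)
    return (32, size, size)


channels, height, width = cnn_feature_shape(28)
print((channels, height, width))      # (32, 7, 7)
print(channels * height * width)      # 1568, the in_features of nn.Linear
```

Because padding 1 with a 3x3 kernel preserves the size, only the two pooling layers shrink the 28x28 image, which is why the linear layer expects 32 * 7 * 7 features.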
Now run the Python file; the training process will be visualized, as shown in the following picture.
Three files named MNIST_9846.pth, MNIST_9846.csv and MNIST_9846.txt will be saved. The first file is the trained model.
The second file records the training process, which you can visualize with matplotlib.
The third file saves some of the hyperparameters used for training.
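The column layout of the training-process CSV is not described here, so the following is only a minimal sketch, assuming the file has a header row and one row per recorded step. The function names load_history and plot_history and the output file history.png are hypothetical, not part of Fau-tools.

```python
import csv
from pathlib import Path


def load_history(path):
    """Read a training-process CSV into {column name: list of values}.

    Assumption: the first row is a header. Numeric cells become floats;
    any other cell is kept as-is.
    """
    columns = {}
    with Path(path).open(newline="") as f:
        for row in csv.DictReader(f):
            for name, value in row.items():
                try:
                    parsed = float(value)
                except (TypeError, ValueError):
                    parsed = value
                columns.setdefault(name, []).append(parsed)
    return columns


def plot_history(columns, out_file="history.png"):
    """Plot every fully numeric column against the row index."""
    import matplotlib  # imported lazily so load_history works without matplotlib
    matplotlib.use("Agg")  # render to a file; no display needed
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for name, values in columns.items():
        if all(isinstance(v, float) for v in values):
            ax.plot(range(1, len(values) + 1), values, label=str(name))
    ax.set_xlabel("record")
    ax.legend()
    fig.savefig(out_file)
    plt.close(fig)
```

For example, plot_history(load_history("MNIST_9846.csv")) would write history.png; check the actual header of your CSV first and drop any columns you do not want on the same axes.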
That covers the primary usage of this tool; its other features will be introduced later.
END
I hope you like it! Issues and pull requests are welcome.