
A Python module whose main purpose is to assist PyTorch training.

Project description

Introduction

This is a small tool built on the PyTorch framework that helps you complete classification tasks with CNNs.

Features: training models, printing training progress, saving training files, plotting figures, and more.

Install

pip install fau-tools

Usage

Import

The following import is recommended.

import fau_tools

Quick start

This tutorial uses a simple example to help you get started quickly!

The following example uses Fau-tools to train a model on the MNIST handwritten-digits dataset.

import torch
import torch.nn as nn
import torch.utils.data as tdata
import torchvision

import fau_tools


# A simple CNN network
class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(1, 16, 3, 1, 1),  # -> (16, 28, 28)
      nn.ReLU(),
      nn.MaxPool2d(2),  # -> (16, 14, 14)

      nn.Conv2d(16, 32, 3, 1, 1),  # -> (32, 14, 14)
      nn.ReLU(),
      nn.MaxPool2d(2)  # -> (32, 7, 7)
    )
    self.output = nn.Linear(32 * 7 * 7, 10)


  def forward(self, x):
    x = self.conv(x)
    x = x.flatten(1)
    return self.output(x)


# Hyper Parameters definition
total_epoch = 10
lr = 1E-2
batch_size = 1024

# Load dataset
train_data      = torchvision.datasets.MNIST('datasets', True, torchvision.transforms.ToTensor(), download=True)
test_data       = torchvision.datasets.MNIST('datasets', False, torchvision.transforms.ToTensor())
train_data.data = train_data.data[:6000]  # mini data
test_data.data  = test_data.data[:2000]  # mini data
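# NOTE: torchvision's MNIST indexes targets lazily and len(dataset) follows
# len(dataset.data), so truncating .data alone is enough to shrink the dataset.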

# Get data loader
train_loader = tdata.DataLoader(train_data, batch_size, True)
test_loader  = tdata.DataLoader(test_data, batch_size)

# Initialize model, optimizer and loss function
model = CNN()
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr)

# Train!
fau_tools.TaskRunner(model, train_loader, test_loader, loss_function, optimizer, total_epoch, exp_path="MNIST").train()
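
Before launching a full run, it can help to sanity-check the network with a dummy batch. The following is a minimal sketch (not part of the tutorial script; it assumes the imports and the CNN class above are in scope) that verifies the shape annotations in the model definition:

# Quick shape check for the CNN defined above.
dummy = torch.randn(4, 1, 28, 28)  # a fake batch of four MNIST-sized images
with torch.no_grad():
  logits = CNN()(dummy)
print(logits.shape)  # expected: torch.Size([4, 10]), one logit per digit class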

Now we can run the Python script, and the training progress will be visualized as in the following picture.

[Figure: training visualization]

Three files named best.pth, scalars.csv and exp_info.txt will be saved.

The first file stores the weights of the trained model.
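
To reuse the model later, the saved weights can be loaded back. A minimal sketch, assuming best.pth stores a state_dict and is written to the exp_path directory (MNIST/ here); both are assumptions about fau_tools' output layout:

model = CNN()
model.load_state_dict(torch.load("MNIST/best.pth", map_location="cpu"))  # path assumed from exp_path="MNIST"
model.eval()  # switch to inference mode before evaluating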

The second file records the scalar values tracked during training.

The third file saves information about the experiment.
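
The recorded scalars can be inspected with standard tools. Here is a hedged sketch using pandas and matplotlib; the CSV path is assumed from exp_path="MNIST", and the exact column names depend on fau_tools:

import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv("MNIST/scalars.csv")  # path assumed from exp_path="MNIST"
print(df.columns.tolist())             # inspect which scalars were recorded
df.plot(x=df.columns[0])               # plot each scalar against the first column
plt.show()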


The above covers the primary usage of this tool; there are also some other snazzy features, which will be introduced later. [TODO]

END

We hope you like it! Issues and pull requests are welcome.
