Skip to main content
Python Software Foundation 20th Year Anniversary Fundraiser  Donate today!

A library of toandaominh1997

Project description

Library of pytoan

Introduction

Installing

pip install pytoan

Usage

  1. Example model with MNIST
from pytoan.pytorch import Learning
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from pathlib import Path

# Hyper parameters
num_classes = 10
batch_size = 100
learning_rate = 0.001

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False,
                                          pin_memory=True)

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

def accuracy_score(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == target).item()
    return correct / len(target)

model = ConvNet(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
metric_ftns = [accuracy_score]
device = [0]
num_epoch = 100
gradient_clipping = 0.1
gradient_accumulation_steps = 1
early_stopping = 10
validation_frequency = 1
tensorboard = True
checkpoint_dir = Path('./', type(model).__name__)
checkpoint_dir.mkdir(exist_ok=True, parents=True)
resume_path = None
learning = Learning(model=model,
                    criterion=criterion,
                    optimizer=optimizer,
                    scheduler = scheduler,
                    metric_ftns=metric_ftns,
                    device=device,
                    num_epoch=num_epoch,
                    grad_clipping = gradient_clipping,
                    grad_accumulation_steps = gradient_accumulation_steps,
                    early_stopping = early_stopping,
                    validation_frequency = validation_frequency,
                    tensorboard = tensorboard,
                    checkpoint_dir = checkpoint_dir,
                    resume_path=resume_path)
  1. For Training and Validation
learning.train(train_loader, test_loader)

Log:

MNIST_EXAMPLE

  1. For Testing
learning.test(test_loader) # but not complete

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Files for pytoan, version 0.7.2.2.9
Filename, size File type Python version Upload date Hashes
Filename, size pytoan-0.7.2.2.9-py3-none-any.whl (13.0 kB) File type Wheel Python version py3 Upload date Hashes View
Filename, size pytoan-0.7.2.2.9.tar.gz (11.0 kB) File type Source Python version None Upload date Hashes View

Supported by

AWS AWS Cloud computing Datadog Datadog Monitoring DigiCert DigiCert EV certificate Facebook / Instagram Facebook / Instagram PSF Sponsor Fastly Fastly CDN Google Google Object Storage and Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Salesforce Salesforce PSF Sponsor Sentry Sentry Error logging StatusPage StatusPage Status page