A library by toandaominh1997
Project description
Library of pytoan
Introduction
Installing
pip install pytoan
Usage
- Example model with MNIST
from pytoan.pytorch import Learning
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from pathlib import Path
# Hyper parameters shared by the model, the optimizer and the data loaders.
num_classes = 10
batch_size = 100
learning_rate = 0.001

# MNIST dataset. Only the training split is created with download=True;
# both splits read from the same root directory.
train_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=False,
    transform=transforms.ToTensor(),
)

# Batched loaders: the training loader shuffles, the test loader does not.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
)
class ConvNet(nn.Module):
    """Small two-block convolutional classifier for single-channel images.

    Each block is Conv2d(5x5, padding=2) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    so the spatial size is halved twice while the channel count goes
    1 -> 16 -> 32. The final linear layer is sized for a 7x7 feature map,
    which is what a 28x28 input (e.g. MNIST) produces.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 32 channels on a 7x7 map after the two poolings of a 28x28 input.
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``.

        Args:
            x: Tensor of shape (batch, 1, H, W); H and W must reduce to 7x7
                after the two pooling stages (28x28 does).
        """
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.reshape(out.size(0), -1)
        return self.fc(out)
def accuracy_score(output, target):
    """Return the fraction of samples whose arg-max prediction equals the target.

    Args:
        output: Tensor of per-class scores; the predicted class is taken with
            ``argmax`` over dimension 1.
        target: Tensor of ground-truth class indices, one per sample.

    Returns:
        A Python float in [0, 1]. An empty batch yields 0.0 instead of
        dividing by zero.

    Raises:
        ValueError: If the number of predictions differs from ``len(target)``.
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        total = len(target)
        # Explicit raise rather than assert so the check survives `python -O`.
        if pred.shape[0] != total:
            raise ValueError(
                f"batch size mismatch: {pred.shape[0]} predictions "
                f"vs {total} targets"
            )
        if total == 0:
            return 0.0
        correct = torch.sum(pred == target).item()
    return correct / total
# Model, loss, optimizer and learning-rate scheduler.
model = ConvNet(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Settings forwarded to Learning below.
metric_ftns = [accuracy_score]
device = [0]  # presumably a list of GPU ids -- confirm against pytoan
num_epoch = 100
gradient_clipping = 0.1
gradient_accumulation_steps = 1
early_stopping = 10
validation_frequency = 1
tensorboard = True

# Checkpoints go into a directory named after the model class.
checkpoint_dir = Path('./', type(model).__name__)
checkpoint_dir.mkdir(exist_ok=True, parents=True)
resume_path = None

learning = Learning(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    metric_ftns=metric_ftns,
    device=device,
    num_epoch=num_epoch,
    grad_clipping=gradient_clipping,
    grad_accumulation_steps=gradient_accumulation_steps,
    early_stopping=early_stopping,
    validation_frequency=validation_frequency,
    tensorboard=tensorboard,
    checkpoint_dir=checkpoint_dir,
    resume_path=resume_path,
)
- For Training and Validation
# Start training; the second loader is presumably the validation set
# (this section is headed "For Training and Validation").
learning.train(train_loader, test_loader)
Log:
- For Testing
learning.test(test_loader)  # NOTE: per the author, testing is not complete yet
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
pytoan-0.7.2.2.9.tar.gz (11.0 kB, view details)
Built Distribution
File details
Details for the file pytoan-0.7.2.2.9.tar.gz.
File metadata
- Download URL: pytoan-0.7.2.2.9.tar.gz
- Upload date:
- Size: 11.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 72368ff6b41a13aed451c09959c431f8a029ae91faf2ff8556e5cfc29802fe2a |
| MD5 | 014099e1faec6a2230507c2b66fd469b |
| BLAKE2b-256 | b1003c3c82dcb9c8579a5c65dc4aa85fc155c6d6ed3d7facd1cf6afabd8091a5 |
File details
Details for the file pytoan-0.7.2.2.9-py3-none-any.whl.
File metadata
- Download URL: pytoan-0.7.2.2.9-py3-none-any.whl
- Upload date:
- Size: 13.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | eca2d4f03da7f4b34a263c202b0f7b783079fece6d2ed71b7bfe653ca513701f |
| MD5 | 285b1c4ae561e7646c57136ff03ab6c6 |
| BLAKE2b-256 | 736c8114eda6865f92fd6811155c9308a69ad2af3fbe4709e3c685fcfadd59e7 |