Skip to main content

Simple trainer for pytorch.

Project description

Train-Pytorch

Simlified pytorch training!

PyPI project: https://pypi.org/project/train-pytorch/

The package provide:

  • A basic Trainer class to facilidate pytorch model training.
  • Some functions to compute common accuracy metrics including:
    • binary_accuracy
    • multiple_class_accuracy
    • regression_r2

You can also define your own function to input into the Trainer class as long as your function can:

  • take 2 inputs: logits and labels
  • perform all computation on: torch.tensor
  • return a python value by: value.item()

An example of our provided binary_accuracy function is:

def binary_accuracy(logits, labels, cutoff=0):
    """
    Compute binary classification accuracy score.
        return accuracy value

    Args:
        logits: logits - outputs of the model
        labels: true labels of data
        cutoff: default is 0 - model outputs logits
                can be set to 1 - if model outputs probabilities
    """
    logits, labels = logits.cpu(), labels.cpu()
    predicts = (logits > cutoff).float()
    acc = (predicts == labels).float().mean()
    return acc.item()

1. Installation

From Github:

git clone https://github.com/datngu/train_pytorch
cd train_pytorch
pip install .

From PyPI:

pip install train-pytorch

1. Example on the MNIST dataset with multiple_class_accuracy

1.1 Load your libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

## import train_pytorch packages and metric functions
from train_pytorch import Trainer, binary_accuracy, multiple_class_accuracy, regression_r2

1.2 Load your dataset

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

1.3 Buid your model

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

1.4 Let's train it!

model = CNNModel()

## GPU: optional
#device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")
#model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


trainer = Trainer(model, criterion, optimizer, multiple_class_accuracy, num_epochs = 10, early_stoper = 5)

trainer.fit(train_loader, train_loader, './output_dir')

2. Example on the sklearn breast_cancer dataset with binary_accuracy

2.1 Load your libraries

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_breast_cancer
from torch.utils.data import Dataset, DataLoader

## import train_pytorch packages and metric functions
from train_pytorch import Trainer, binary_accuracy, multiple_class_accuracy, regression_r2

2.2 Load your dataset

data = load_breast_cancer()
x = data['data']
y = data['target']

sc = StandardScaler()
x = sc.fit_transform(x)

## create dataset class
class dataset(Dataset):
  def __init__(self,x,y):
    self.x = torch.tensor(x,dtype=torch.float32)
    self.y = torch.tensor(y,dtype=torch.float32)
    self.length = self.x.shape[0]
 
  def __getitem__(self,idx):
    return self.x[idx],self.y[idx]
  def __len__(self):
    return self.length
      
# a bit lazy to slipt train and test data, but it is okey for tutorial :D      
train_data = dataset(x,y)
val_data = dataset(x,y)

train_loader = DataLoader(train_data,batch_size=64,shuffle=False)
val_loader = DataLoader(val_data,batch_size=64,shuffle=False)

2.3 Buid your model

class Net(nn.Module):
  def __init__(self,input_shape):
    super(Net,self).__init__()
    self.fc1 = nn.Linear(input_shape,32)
    self.fc2 = nn.Linear(32,64)
    self.fc3 = nn.Linear(64,1)
  def forward(self,x):
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    x = self.fc3(x)
    return x

2.4 Let's train it!

model = Net(input_shape=x.shape[1])

## GPU: optional
#device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")
#model.to(device)

optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()


trainer = Trainer(model, loss_fn, optimizer, binary_accuracy, num_epochs=10)
trainer.fit(train_loader, val_loader, './output_dir')

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

train_pytorch-0.0.3.tar.gz (6.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

train_pytorch-0.0.3-py3-none-any.whl (6.8 kB view details)

Uploaded Python 3

File details

Details for the file train_pytorch-0.0.3.tar.gz.

File metadata

  • Download URL: train_pytorch-0.0.3.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for train_pytorch-0.0.3.tar.gz
Algorithm Hash digest
SHA256 8823a225214e214ec3ca6dffc19c92af325fc03d71e955ba217d2441da0a0383
MD5 397d57768bd910239571c11960c3931b
BLAKE2b-256 265b9439d0ef8f1f53e8b2dfd0dd127a44b9198a56d4076be67f1b6a8bcd117e

See more details on using hashes here.

File details

Details for the file train_pytorch-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: train_pytorch-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 6.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for train_pytorch-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 674bc0fae1dc4ebc0556c9aa72796083c7ecc2baf8e05420a0ef6dbdc5dd4851
MD5 0f70f4ed3f145d3f15a3e0c303f4dec2
BLAKE2b-256 159b9e08d6622e1b12aaf73a20f9d18434b066bae3d1b3b8949c16b9900bdcab

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page