A PyTorch-based deep learning solver framework.
Project description
torchsolver
A PyTorch-based deep learning solver framework.
Install
pip install torchsolver
Example
import torch
from torch import nn, optim
from torchvision.datasets import MNIST
from torchvision.transforms import *
from torchsolver.module import Module
from torchsolver.metrics import accuracy
class LeNet(nn.Module):
    """LeNet-style convolutional classifier that returns raw class logits.

    Two ``5x5 conv -> ReLU -> 2x2 max-pool`` stages feed a 512-unit fully
    connected layer, dropout, and a final linear layer with ``classes_num``
    outputs.

    Args:
        classes_num: Number of output classes, i.e. the size of the last
            dimension of the tensor returned by ``forward``.
    """

    def __init__(self, classes_num):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.act = nn.ReLU()
        # 1024 = 64 channels * 4 * 4: a 1x28x28 input shrinks 28 -> 24 -> 12
        # -> 8 -> 4 through the two conv/pool stages above.
        self.fc1 = nn.Linear(1024, 512)
        self.dropout = nn.Dropout(0.5)
        self.out = nn.Linear(512, classes_num)

    def forward(self, x):
        """Compute unnormalized class scores (logits) for a batch of images.

        Args:
            x: Input tensor; the layer sizes above require a single-channel
                input that flattens to 1024 features per sample
                (e.g. ``(batch, 1, 28, 28)``).

        Returns:
            Tensor of shape ``(batch, classes_num)`` holding raw logits.
        """
        x = self.pool1(self.act(self.conv1(x)))
        x = self.pool2(self.act(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        # NOTE(review): there is no nonlinearity between fc1 and out; kept
        # as in the original architecture -- confirm whether that is intended.
        x = self.dropout(self.fc1(x))
        # Deliberately no softmax here: nn.CrossEntropyLoss applies
        # log-softmax internally and expects raw logits, so normalizing in
        # the model would apply softmax twice and distort the loss and its
        # gradients. Argmax-based predictions are unaffected; call
        # torch.softmax on the output if probabilities are needed.
        return self.out(x)
class MnistSolver(Module):
    """Solver that wires a 10-class ``LeNet`` to a loss and an optimizer.

    Keyword arguments are forwarded unchanged to the ``Module`` base class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.model = LeNet(10)
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters())

        # num_device is not set in this class; presumably the base Module
        # provides it -- TODO confirm.
        use_data_parallel = self.num_device > 1
        if use_data_parallel:
            self.model = nn.DataParallel(self.model)

    def forward(self, img, label):
        """Run the model on one batch and report the step result.

        Args:
            img: Batch passed directly to the wrapped model.
            label: Targets handed to ``accuracy`` and to the loss.

        Returns:
            A 2-tuple. While ``self.training`` is true it is
            ``(loss, {"loss": loss, "acc": acc})``; otherwise ``(acc, {})``.
        """
        prediction = self.model(img)
        acc = accuracy(prediction, label)

        # Evaluation: the accuracy itself is the primary value.
        if not self.training:
            return acc, {}

        # Training: the loss is the primary value; both metrics are reported.
        loss = self.loss(prediction, label)
        return loss, {"loss": loss, "acc": acc}
if __name__ == '__main__':
    # Train and validate the MNIST solver on the standard train/test splits.
    # download=True fetches the dataset into "data" on the first run, so the
    # example works out of the box; torchvision skips the download when the
    # files are already present.
    train_data = MNIST("data", train=True, transform=ToTensor(), download=True)
    val_data = MNIST("data", train=False, transform=ToTensor(), download=True)
    MnistSolver(batch_size=128).fit(train_data=train_data, val_data=val_data)
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
torchsolver-1.1.tar.gz
(8.8 kB
view hashes)
Built Distribution
torchsolver-1.1-py3-none-any.whl
(12.6 kB
view hashes)
Close
Hashes for torchsolver-1.1-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | d1d9e2dae1f22c78f2698a775a98bced6a049634999197fc3ec6b3e387e3553f |
|
MD5 | c9f9af3d3e2ddc4a7b6ceaba96d32295 |
|
BLAKE2b-256 | ac5764eb58e2e93e60e3f155b7ca7e4c15a81beb23bbe110e4054eb90477ffa7 |