A PyTorch-based deep learning solver framework.
Project description
torchsolver
A PyTorch-based deep learning solver framework.
Install
pip install torchsolver
Example
import torch
from torch import nn, optim
from torchvision.datasets import MNIST
from torchvision.transforms import *
from torchsolver.module import Module
from torchsolver.metrics import accuracy
class LeNet(nn.Module):
    """LeNet-style CNN classifier for single-channel 28x28 images (e.g. MNIST).

    Two conv(5x5) -> ReLU -> max-pool(2) stages reduce a 1x28x28 input to a
    64x4x4 feature map (1024 values), which feeds a two-layer MLP head.

    Args:
        classes_num: Number of output classes.

    ``forward`` returns raw, unnormalized logits of shape
    ``(batch, classes_num)``. They are meant to be passed directly to
    ``nn.CrossEntropyLoss``, which applies log-softmax itself; call
    ``torch.softmax(logits, dim=-1)`` explicitly when probabilities are needed.
    """

    def __init__(self, classes_num):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)          # 1x28x28  -> 32x24x24
        self.pool1 = nn.MaxPool2d(2, stride=2)    # 32x24x24 -> 32x12x12
        self.conv2 = nn.Conv2d(32, 64, 5)         # 32x12x12 -> 64x8x8
        self.pool2 = nn.MaxPool2d(2, stride=2)    # 64x8x8   -> 64x4x4
        self.act = nn.ReLU()
        # 64 * 4 * 4 = 1024 flattened features; this fixes the input at 28x28.
        self.fc1 = nn.Linear(1024, 512)
        self.dropout = nn.Dropout(0.5)
        self.out = nn.Linear(512, classes_num)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, classes_num)`` holding unnormalized logits.
        """
        x = self.pool1(self.act(self.conv1(x)))
        x = self.pool2(self.act(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        # Without a nonlinearity here, fc1 and out would compose into a single
        # linear map and the 512-unit hidden layer would add no capacity.
        x = self.act(self.fc1(x))
        x = self.dropout(x)
        # No softmax: CrossEntropyLoss expects logits, and applying softmax
        # first would normalize twice and flatten the gradients. Softmax does
        # not change the argmax, so argmax-based metrics see the same labels.
        return self.out(x)
class MnistSolver(Module):
    """torchsolver ``Module`` that trains a 10-class ``LeNet`` on MNIST.

    Keyword arguments are forwarded unchanged to the ``Module`` base class
    (the example passes ``batch_size``). The loss is ``nn.CrossEntropyLoss``
    and the optimizer is Adam with default hyper-parameters.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        network = LeNet(10)
        self.model = network
        self.loss = nn.CrossEntropyLoss()
        # The optimizer holds the parameter tensors themselves, so it stays
        # valid if the model is wrapped in DataParallel below.
        self.optimizer = optim.Adam(network.parameters())
        # NOTE(review): num_device is provided by the base Module; presumably
        # the number of visible devices -- confirm in torchsolver.
        if self.num_device > 1:
            self.model = torch.nn.DataParallel(network)

    def forward(self, img, label):
        """Run one step on a batch of images and labels.

        Returns:
            While training: ``(loss, {"loss": loss, "acc": acc})``.
            Otherwise: ``(acc, {})``.
            How the framework consumes the first element is defined by the
            torchsolver base class, which is not shown here.
        """
        output = self.model(img)
        batch_acc = accuracy(output, label)
        # Evaluation path: no loss is computed outside training mode.
        if not self.training:
            return batch_acc, {}
        batch_loss = self.loss(output, label)
        return batch_loss, {"loss": batch_loss, "acc": batch_acc}
if __name__ == '__main__':
    # download=True lets torchvision fetch MNIST into ./data on the first run.
    # Without it, MNIST refuses to load when the files are not already on
    # disk, so the example fails on a fresh checkout. Existing files are reused.
    train_data = MNIST("data", train=True, download=True, transform=ToTensor())
    val_data = MNIST("data", train=False, download=True, transform=ToTensor())
    MnistSolver(batch_size=128).fit(train_data=train_data, val_data=val_data)
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
torchsolver-1.5.1.tar.gz (12.2 kB, view details)
Built Distribution
torchsolver-1.5.1-py3-none-any.whl (16.6 kB)
File details
Details for the file torchsolver-1.5.1.tar.gz.
File metadata
- Download URL: torchsolver-1.5.1.tar.gz
- Upload date:
- Size: 12.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 565965813e9c2f165d9624cd946decabc8b01b54e2ceb7ccfb9519ccfb584492 |
| MD5 | 6b0086a95217ad68a71d6895f6a1ec1c |
| BLAKE2b-256 | 788e7c1288190b4004968fa088bd19042032c29ac134be68d5f8de1ec9e87d43 |
File details
Details for the file torchsolver-1.5.1-py3-none-any.whl.
File metadata
- Download URL: torchsolver-1.5.1-py3-none-any.whl
- Upload date:
- Size: 16.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 386c208109314cb66313036ea0dd7bd73446433cbf0dd2977b7f46642134f707 |
| MD5 | 697a0602666fbbf670538e79324a233d |
| BLAKE2b-256 | 62cc686a3b297ae39e03fc4c5b4920905257787be07ba1f14a38bbb42b085743 |