
A PyTorch-based deep learning solver framework.

Project description

torchsolver


install

pip install torchsolver

example

import torch
from torch import nn, optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from torchsolver.module import Module
from torchsolver.metrics import accuracy


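class LeNet(nn.Module):
    """LeNet-style CNN for 1x28x28 MNIST images."""

    def __init__(self, classes_num):
        super().__init__()

        # 28x28 -> conv 5x5 -> 24x24 -> pool 2x2 -> 12x12
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        # 12x12 -> conv 5x5 -> 8x8 -> pool 2x2 -> 4x4
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        self.act = nn.ReLU()

        # 64 channels * 4 * 4 = 1024 flattened features
        self.fc1 = nn.Linear(1024, 512)
        self.dropout = nn.Dropout(0.5)
        self.out = nn.Linear(512, classes_num)

    def forward(self, x):
        x = self.pool1(self.act(self.conv1(x)))
        x = self.pool2(self.act(self.conv2(x)))

        x = torch.flatten(x, start_dim=1)

        # ReLU between the two linear layers so they do not collapse into one linear map
        x = self.dropout(self.act(self.fc1(x)))

        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so calling softmax here would apply it twice.
        return self.out(x)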


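class MnistSolver(Module):
    def __init__(self, **kwargs):
        # Options such as batch_size are forwarded to torchsolver's Module.
        super().__init__(**kwargs)

        self.model = LeNet(10)
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters())

        # num_device is inherited from Module; wrap the model when several GPUs are available.
        if self.num_device > 1:
            self.model = torch.nn.DataParallel(self.model)

    def forward(self, img, label):
        pred = self.model(img)

        acc = accuracy(pred, label)
        if self.training:
            # Training step: return the loss to optimize plus a dict of named metrics.
            loss = self.loss(pred, label)
            return loss, {"loss": loss, "acc": acc}
        else:
            # Validation step: return accuracy and no extra metrics.
            return acc, {}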


if __name__ == '__main__':
    # download=True fetches MNIST into ./data on the first run
    train_data = MNIST("data", train=True, download=True, transform=ToTensor())
    val_data = MNIST("data", train=False, download=True, transform=ToTensor())

    MnistSolver(batch_size=128).fit(train_data=train_data, val_data=val_data)
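The input size of `fc1` (1024) follows from the 28x28 MNIST images and the layer settings in `LeNet`. The pure-Python sketch below reproduces that arithmetic with the standard output-size formula for convolution and pooling layers (no padding, dilation 1, floor rounding as in PyTorch's defaults). `conv_out` and `lenet_flat_features` are hypothetical helpers for illustration, not part of torchsolver.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis for Conv2d/MaxPool2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_flat_features(input_size=28, channels=64):
    """Number of features after flattening the LeNet example's conv stack."""
    s = conv_out(input_size, kernel=5)    # conv1: 28 -> 24
    s = conv_out(s, kernel=2, stride=2)   # pool1: 24 -> 12
    s = conv_out(s, kernel=5)             # conv2: 12 -> 8
    s = conv_out(s, kernel=2, stride=2)   # pool2: 8 -> 4
    return channels * s * s               # 64 * 4 * 4


print(lenet_flat_features())  # 1024
```

If you change the input resolution or kernel sizes, recompute this value and update `nn.Linear(1024, 512)` to match; for example, 32x32 inputs would give `lenet_flat_features(32) == 1600`.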


Download files

Source Distribution

torchsolver-1.5.1.tar.gz (12.2 kB)

Uploaded Source

Built Distribution

torchsolver-1.5.1-py3-none-any.whl (16.6 kB)

Uploaded Python 3

File details

Details for the file torchsolver-1.5.1.tar.gz.

File metadata

  • Download URL: torchsolver-1.5.1.tar.gz
  • Upload date:
  • Size: 12.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.5

File hashes

Hashes for torchsolver-1.5.1.tar.gz
Algorithm Hash digest
SHA256 565965813e9c2f165d9624cd946decabc8b01b54e2ceb7ccfb9519ccfb584492
MD5 6b0086a95217ad68a71d6895f6a1ec1c
BLAKE2b-256 788e7c1288190b4004968fa088bd19042032c29ac134be68d5f8de1ec9e87d43


File details

Details for the file torchsolver-1.5.1-py3-none-any.whl.

File metadata

  • Download URL: torchsolver-1.5.1-py3-none-any.whl
  • Upload date:
  • Size: 16.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.5

File hashes

Hashes for torchsolver-1.5.1-py3-none-any.whl
Algorithm Hash digest
SHA256 386c208109314cb66313036ea0dd7bd73446433cbf0dd2977b7f46642134f707
MD5 697a0602666fbbf670538e79324a233d
BLAKE2b-256 62cc686a3b297ae39e03fc4c5b4920905257787be07ba1f14a38bbb42b085743
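The SHA256 digests listed for each file can be used to check a download before installing it. A minimal sketch using only the standard library's `hashlib` follows; `file_sha256` and `verify` are hypothetical helper names. pip can also enforce hashes itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:<digest>` entries in a requirements file).

```python
import hashlib


def file_sha256(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return file_sha256(path) == expected_hex.strip().lower()
```

For example, call `verify("torchsolver-1.5.1.tar.gz", digest)` with the SHA256 value listed for the source distribution; it returns `True` only if the file is byte-for-byte identical to the published one.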
