torchfurnace

torchfurnace is a tool package for training models, pre-processing datasets, and managing experiment records in PyTorch AI tasks.

Quick Start

Usage

pip install torchfurnace

Example

Training VGG16 on CIFAR10

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR
from torchfurnace import Engine, Parser

# define training process of your model
class VGGNetEngine(Engine):
    @staticmethod
    def _on_forward(training, model, inp, target, optimizer=None) -> dict:
        output = model(inp)
        loss = F.cross_entropy(output, target)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return {'loss': loss, 'preds': output}

    @staticmethod
    def _get_lr_scheduler(optim) -> list:
        return [MultiStepLR(optim, milestones=[150, 250, 350], gamma=0.1)]

def main():
    # define experiment name
    parser = Parser('TVGG16')
    args = parser.parse_args()
    experiment_name = '_'.join([args.dataset, args.exp_suffix])

    # Data
    ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = CIFAR10(root='data', train=True, download=True, transform=ts)
    testset = CIFAR10(root='data', train=False, download=True, transform=ts)

    # define model and optimizer
    net = torchvision.models.vgg16(pretrained=False, num_classes=10)
    net.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
    net.classifier = nn.Linear(512, 10)
    optimizer = torch.optim.Adam(net.parameters())

    # new engine instance
    eng = VGGNetEngine(parser).experiment_name(experiment_name)
    acc1 = eng.learning(net, optimizer, trainset, testset)
    print('Acc1:', acc1)

if __name__ == '__main__':
    import sys
    run_params = '--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr'
    sys.argv.extend(run_params.split())
    main()

Introduction

Why do this?

There are several deep learning frameworks for quickly building a training system in PyTorch AI tasks. However, I found that most of them are complex frameworks with a high learning cost that seriously invade the original code; for instance, you may have to modify your model class to fit the framework.

So torchfurnace was born to run your PyTorch AI tasks quickly and simply, without invasion, i.e. you don't have to change much of your existing code.

What features?

  1. torchfurnace consists of two independent components: engine and tracer. engine is the core component of the framework, and tracer is an experiment manager responsible for log writing, model saving, and training visualization.

  2. torchfurnace integrates some practical tools, such as converting a raw dataset to LMDB to relieve I/O bottlenecks, and computing model parameter counts and sizes.

Components

Engine

from torchfurnace import Engine
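
A subclass implements the per-batch step; as the Quick Start shows, _on_forward is the main hook and _get_lr_scheduler is optional. A minimal sketch for a regression task, assuming _on_forward is the only required override:

import torch.nn.functional as F
from torchfurnace import Engine

class RegressionEngine(Engine):
    @staticmethod
    def _on_forward(training, model, inp, target, optimizer=None) -> dict:
        # forward pass and loss; any criterion works here
        output = model(inp)
        loss = F.mse_loss(output, target)
        if training:
            # the optimizer is only used in training mode
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return {'loss': loss, 'preds': output}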

Tracer

from torchfurnace import Tracer
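
This page doesn't document Tracer's methods, so rather than guess at its API, here is a generic, self-contained sketch of the duties listed above (log writing, model saving), with directory names taken from the Directory Architecture section below. It is an illustration only, not torchfurnace's actual interface:

import time
from pathlib import Path
import torch

class ToyTracer:
    """Illustrative stand-in, NOT torchfurnace's Tracer API."""
    def __init__(self, exp_root):
        self.root = Path(exp_root)
        (self.root / 'logs').mkdir(parents=True, exist_ok=True)
        (self.root / 'models' / 'checkpoint' / 'best').mkdir(parents=True, exist_ok=True)

    def log(self, msg):
        # append to logs/log.txt, as in the directory tree below
        with open(self.root / 'logs' / 'log.txt', 'a') as f:
            f.write(f'{time.asctime()} {msg}\n')

    def save_checkpoint(self, state, best=False):
        # mirror the models/checkpoint[/best] layout shown below
        path = self.root / 'models' / 'checkpoint'
        if best:
            path = path / 'best'
        torch.save(state, path / 'ckpt.pth.tar')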

Parser

from torchfurnace import Parser
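
Parser builds the command-line arguments used throughout the framework; the fields read in the Quick Start (args.dataset, args.exp_suffix) come from here. A minimal sketch reusing the flags already shown above:

import sys
from torchfurnace import Parser

parser = Parser('TVGG16')  # experiment name, as in the Quick Start
sys.argv.extend('--dataset CIFAR10 -lr 0.1 -bs 128 --epochs 400'.split())
args = parser.parse_args()
experiment_name = '_'.join([args.dataset, args.exp_suffix])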

ImageFolderLMDB

from torchfurnace import ImageFolderLMDB
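
No usage is shown beyond the import. Assuming ImageFolderLMDB is an LMDB-backed, drop-in Dataset (the 'train.lmdb' path and constructor signature below are assumptions, not documented API), it would feed a DataLoader like any other dataset:

from torch.utils.data import DataLoader
from torchfurnace import ImageFolderLMDB

# hypothetical: an .lmdb file produced by the package's conversion tool
train_set = ImageFolderLMDB('train.lmdb')
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)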

ImageLMDB

from torchfurnace import ImageLMDB

Model Summary

This tool comes from pytorch-summary.

import torchvision
from torchfurnace.utils.torch_summary import summary, summary_string
net = torchvision.models.vgg16()

# this function prints the summary to the screen
summary(net, (3, 224, 224))

# this function returns the summary as a string
summary_string(net, (3, 224, 224))

Directory Architecture

TVGG16/
├── logs
│   └── CIFAR10
│       └── log.txt
├── models
│   └── CIFAR10
│       ├── architecture.txt
│       ├── checkpoint
│       │   └── best
│       └── run_config.json
└── tensorboard
    └── CIFAR10
        └── events.out.tfevents

Testing & Example

For this section, first clone the repository: git clone https://github.com/tianyu-su/torchfurnace.git.

  1. torchfurnace/tests/test_furnace.py A unit test for Engine.
  2. torchfurnace/tests/test_tracer.py A unit test for Tracer.
  3. torchfurnace/tests/test_img2lmdb.py A unit test for converting images to LMDB.
  4. torchfurnace/tests/test_vgg16.py A comparison experiment against pytorch-cifar to validate the pipeline of the proposed framework.

More Usages

  1. options.py flags: no_tb, p_bar, override, ext, exp_suffix (see the sketch below)
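
These flags aren't documented on this page; assuming the usual argparse '--' spelling, they can be appended the same way the Quick Start extends sys.argv:

import sys
from torchfurnace import Parser

# assumed spellings of the flags listed above
sys.argv.extend('--no_tb --p_bar --exp_suffix run1'.split())
args = Parser('TVGG16').parse_args()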

TODO

  • training with DistributedDataParallel
  • compute the mean and standard deviation of an image dataset
