A tool package for training models, pre-processing datasets, and managing experiment records in PyTorch AI tasks.
torchfurnace
torchfurnace is a tool package for training models, pre-processing datasets, and managing experiment records in PyTorch AI tasks.
Quick Start
Usage
pip install torchfurnace
Example
Training VGG16 on CIFAR10
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR
from torchfurnace import Engine, Parser

# define training process of your model
class VGGNetEngine(Engine):
    @staticmethod
    def _on_forward(training, model, inp, target, optimizer=None) -> dict:
        ret = {'loss': object, 'preds': object}
        output = model(inp)
        loss = F.cross_entropy(output, target)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        ret['loss'] = loss
        ret['preds'] = output
        return ret

    @staticmethod
    def _get_lr_scheduler(optim) -> list:
        return [MultiStepLR(optim, milestones=[150, 250, 350], gamma=0.1)]

def main():
    # define experiment name
    parser = Parser('TVGG16')
    args = parser.parse_args()
    experiment_name = '_'.join([args.dataset, args.exp_suffix])
    # Data
    ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = CIFAR10(root='data', train=True, download=True, transform=ts)
    testset = CIFAR10(root='data', train=False, download=True, transform=ts)
    # define model and optimizer
    net = torchvision.models.vgg16(pretrained=False, num_classes=10)
    net.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
    net.classifier = nn.Linear(512, 10)
    optimizer = torch.optim.Adam(net.parameters())
    # new engine instance
    eng = VGGNetEngine(parser).experiment_name(experiment_name)
    acc1 = eng.learning(net, optimizer, trainset, testset)
    print('Acc1:', acc1)

if __name__ == '__main__':
    import sys
    run_params = '--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr'
    sys.argv.extend(run_params.split())
    main()
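To make the schedule in the example concrete: the run passes -lr 0.1, and _get_lr_scheduler returns a MultiStepLR with milestones [150, 250, 350] and gamma 0.1, which multiplies the learning rate by gamma each time a milestone epoch is reached. The sketch below reproduces that rule in plain Python with no torch dependency. The helper name multistep_lr is hypothetical and is not part of torchfurnace or PyTorch. It also assumes that the engine steps the scheduler once per epoch and that the optimizer starts from the -lr value; this README confirms neither.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate at `epoch` under a MultiStepLR-style step decay.

    Hypothetical helper for illustration only; not a torchfurnace API.
    """
    # Count how many milestones have been reached by this epoch
    # (a milestone counts as reached on the milestone epoch itself).
    reached = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** reached


if __name__ == "__main__":
    # Values taken from the example: -lr 0.1, milestones 150/250/350, gamma 0.1.
    for epoch in (0, 149, 150, 250, 350, 399):
        print(epoch, multistep_lr(0.1, [150, 250, 350], 0.1, epoch))
```

Under these assumptions the learning rate stays at its initial value until epoch 149 and drops by a factor of ten at each of epochs 150, 250, and 350.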
Introduction
Why do this?
Several deep learning frameworks already exist for quickly building a training system in PyTorch. However, I found that most of them are complex, costly to learn, and intrusive to existing code; for instance, you may have to modify your model class to fit the framework.
So torchfurnace was born to let you run your PyTorch AI tasks quickly, simply, and non-intrusively, that is, without changing much of the code you have already defined.
What features?
- torchfurnace consists of two independent components: engine and tracer. engine is the core component of the framework, and tracer is an experiment manager whose responsibilities include log writing, model saving, and training visualization.
- torchfurnace integrates some practical tools, such as converting raw datasets to LMDB to relieve the I/O bottleneck and computing the number of model parameters.
Components
Engine
from torchfurnace import Engine
Tracer
from torchfurnace import Tracer
Parser
from torchfurnace import Parser
ImageFolderLMDB
from torchfurnace import ImageFolderLMDB
ImageLMDB
from torchfurnace import ImageLMDB
Model Summary
This tool comes from pytorch-summary.
import torchvision
from torchfurnace.utils.torch_summary import summary, summary_string
net = torchvision.models.vgg16()
# this function prints the result to the screen.
summary(net, (3, 224, 224))
# this function returns the description as a string.
summary_string(net, (3, 224, 224))
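As a cross-check on the parameter totals that a summary reports, the count can also be derived by hand. The sketch below is plain Python and does not call torchfurnace or pytorch-summary; all helper names are hypothetical. It assumes the standard VGG16 convolutional configuration (3x3 kernels with bias, no batch normalization) followed by the single Linear(512, 10) classifier used in the CIFAR10 example, not the default ImageNet classifier.

```python
# Output channels of the 13 convolutional layers in standard VGG16.
# Pooling layers carry no parameters, so they are omitted here.
VGG16_CONV_CHANNELS = [64, 64, 128, 128, 256, 256, 256,
                       512, 512, 512, 512, 512, 512]


def conv_params(in_ch, out_ch, kernel=3):
    """Weights plus biases of one square-kernel convolution layer."""
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


def vgg16_feature_params(in_ch=3):
    """Total parameters of the VGG16 convolutional stack."""
    total = 0
    for out_ch in VGG16_CONV_CHANNELS:
        total += conv_params(in_ch, out_ch)
        in_ch = out_ch  # the next layer consumes this layer's output
    return total


if __name__ == "__main__":
    features = vgg16_feature_params()
    classifier = linear_params(512, 10)
    print("features:", features)
    print("classifier:", classifier)
    print("total:", features + classifier)
```

For the CIFAR10 variant this gives 14,714,688 parameters in the convolutional stack and 5,130 in the classifier, which is the figure a summary of that modified network should agree with.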
Directory Architecture
TVGG16/
├── logs
│   └── CIFAR10
│       └── log.txt
├── models
│   └── CIFAR10
│       ├── architecture.txt
│       ├── checkpoint
│       │   └── best
│       └── run_config.json
└── tensorboard
    └── CIFAR10
        └── events.out.tfevents
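The layout above can be expressed as a small path helper, for example to locate the log or the best checkpoint after a run. This is a sketch using only pathlib. expected_layout is a hypothetical helper, not a torchfurnace API, and it assumes the directory names shown in the tree, with TVGG16 as the work directory and CIFAR10 as the experiment name.

```python
from pathlib import Path


def expected_layout(work_dir, experiment):
    """Map artifact names to the paths shown in the directory tree.

    Hypothetical helper for illustration; torchfurnace may expose
    these locations differently.
    """
    root = Path(work_dir)
    model_dir = root / "models" / experiment
    return {
        "log": root / "logs" / experiment / "log.txt",
        "architecture": model_dir / "architecture.txt",
        "best_checkpoint": model_dir / "checkpoint" / "best",
        "run_config": model_dir / "run_config.json",
        "tensorboard": root / "tensorboard" / experiment,
    }


if __name__ == "__main__":
    # Report which of the expected artifacts exist on disk.
    for name, path in expected_layout("TVGG16", "CIFAR10").items():
        status = "found" if path.exists() else "missing"
        print(f"{name}: {path} ({status})")
```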
Testing & Example
To run the tests and examples in this section, first clone the repository: git clone https://github.com/tianyu-su/torchfurnace.git
- torchfurnace/tests/test_furnace.py: a unit test for Engine.
- torchfurnace/tests/test_tracer.py: a unit test for Tracer.
- torchfurnace/tests/test_img2lmdb.py: a unit test for converting images to LMDB.
- torchfurnace/tests/test_vgg16.py: a comparison experiment against pytorch-cifar to validate the pipeline of the framework.
More Usages
options.py flags: no_tb, p_bar, override, ext, exp_suffix
TODO
- training with DistributedDataParallel
- computing the mean and standard deviation of an image dataset