Dual Optimizer Training
DualOpt
A variant of the SWATS (Switching from Adam to SGD) training paradigm that trains a model with two optimizers in sequence.
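The comment on `counter` in the usage examples suggests that each optimizer trains until accuracy has stopped improving for a fixed number of epochs. The switching logic can be sketched framework-free as follows. This is an illustration under that assumption, not the package's implementation. The names `run_until_plateau`, `dual_optimizer_training` and `EpochFn` are hypothetical, and each epoch function stands in for "one epoch with optimizer X on the shared model, then measure test accuracy".

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical signature: runs one training epoch with a particular optimizer
# (on a model shared across phases) and returns the test accuracy afterwards.
EpochFn = Callable[[int], float]


def run_until_plateau(epoch_fn: EpochFn, patience: int, max_epochs: int = 1000) -> List[float]:
    """Run epochs until `patience` consecutive epochs bring no new best accuracy."""
    history: List[float] = []
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        acc = epoch_fn(epoch)
        history.append(acc)
        if acc > best:
            best = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # plateau reached: hand over to the next optimizer
    return history


def dual_optimizer_training(
    phases: Sequence[Tuple[str, EpochFn]], patience: int
) -> List[Tuple[str, List[float]]]:
    """Train with each optimizer in order, switching when accuracy plateaus."""
    return [(name, run_until_plateau(epoch_fn, patience)) for name, epoch_fn in phases]
```

For example, with `patience=2` and a first-phase accuracy curve of 50, 60, 65, 64, 63, the first optimizer stops after five epochs (two epochs without beating 65) and the second optimizer takes over from the same model.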
Install
$ pip install dualopt
Usage
Image Classification
import torch
from dualopt import classification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10

# define model
# define datasets and dataloaders (trainloader, testloader)

top1 = []  # top-1 accuracy
top5 = []  # top-5 accuracy
traintime = []
testtime = []
counter = 20  # epochs without any improvement in accuracy before training stops for each optimizer
PATH = 'saved_model.pth'  # path where the model is saved

classification(model, trainloader, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")
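The results line above takes `max(top1)` and `max(top5)` independently, so the two numbers may come from different epochs. If you want all metrics from one epoch, a small helper like the hypothetical `best_epoch_summary` below can pick the epoch with the best top-1 accuracy. It assumes, as the `max`/`min` calls imply, that `classification` appends one entry per evaluated epoch to each list, in the same order; check this against your installed version.

```python
from typing import Dict, List


def best_epoch_summary(
    top1: List[float], top5: List[float], traintime: List[float], testtime: List[float]
) -> Dict[str, float]:
    """Report every metric from the single epoch with the highest top-1 accuracy.

    Assumes the four lists hold one entry per evaluated epoch, in the same order.
    """
    if not top1:
        raise ValueError("no epochs recorded")
    if not (len(top1) == len(top5) == len(traintime) == len(testtime)):
        raise ValueError("metric lists must have one entry per epoch")
    # index of the first epoch reaching the maximum top-1 accuracy
    best = max(range(len(top1)), key=top1.__getitem__)
    return {
        "epoch": best,
        "top1": top1[best],
        "top5": top5[best],
        "train_time": traintime[best],
        "test_time": testtime[best],
    }
```

After training, call `best_epoch_summary(top1, top5, traintime, testtime)` and print the fields you need.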
Post-Training
In our experiments, training with data augmentation such as TrivialAugment gave good results, and a subsequent post-training phase without any data augmentation improved them further.
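The only difference between the two phases is the input pipeline: phase 1 sees randomly augmented copies of the training images, while post-training sees the original images. A minimal, framework-free sketch of that difference follows. It uses a random horizontal flip on toy images (nested lists) as a stand-in for the full augmentation stack and omits TrivialAugment. `random_hflip` and `epoch_stream` are hypothetical names, not part of `dualopt`.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence

Image = List[List[int]]  # toy grayscale image: rows of pixel values


def random_hflip(img: Image, rng: random.Random, p: float = 0.5) -> Image:
    """Flip each row left-to-right with probability p (stand-in for augmentation)."""
    return [row[::-1] for row in img] if rng.random() < p else img


def epoch_stream(
    dataset: Sequence[Image],
    rng: random.Random,
    augment: Optional[Callable[[Image, random.Random], Image]] = None,
) -> Iterator[Image]:
    """Yield one shuffled pass over the dataset, augmenting each sample if requested."""
    order = list(range(len(dataset)))
    rng.shuffle(order)
    for i in order:
        img = dataset[i]
        yield augment(img, rng) if augment is not None else img
```

Phase 1 would iterate `epoch_stream(data, rng, augment=random_hflip)` each epoch, and post-training would iterate `epoch_stream(data, rng)`, which yields the unmodified images in shuffled order.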
Usage
import torch, torchvision
import torchvision.transforms as transforms
from dualopt import classification, post_train

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10

# define model

# set batch size according to the available GPU memory
batch_size = 512
# transforms
# phase 1: training with augmentations
transform_train_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
# phase 2: post-training without augmentations
transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
# evaluation: no augmentations
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4941, 0.4853, 0.4507), (0.2468, 0.2430, 0.2618))])
# datasets and dataloaders
trainset_1 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_1)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # training with augmentations
trainset_2 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # post-training without augmentations
testset = torchvision.datasets.CIFAR10(root='/workspace/', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)
top1 = []  # top-1 accuracy
top5 = []  # top-5 accuracy
traintime = []
testtime = []
counter = 20  # epochs without any improvement in accuracy before training stops for each optimizer
PATH = 'saved_model.pth'  # path where the model is saved

# phase 1: train with augmentations
classification(model, trainloader_1, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')

# phase 2: reload the model saved at PATH and post-train without augmentations
model.load_state_dict(torch.load(PATH, map_location=device))
post_train(model, trainloader_2, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Post-Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")
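The `transforms.Normalize` calls in the example above standardize each channel as `(x - mean) / std`, after `transforms.ToTensor()` has scaled 8-bit pixel values into the range [0, 1]. The worked example below reproduces that arithmetic for a single RGB pixel using the training statistics from the example. `normalize_pixel` is a hypothetical helper for illustration only, not part of torchvision.

```python
from typing import Sequence, Tuple

# Per-channel CIFAR-10 statistics copied from the training transforms above (R, G, B).
TRAIN_MEAN = (0.4914, 0.4822, 0.4465)
TRAIN_STD = (0.2470, 0.2435, 0.2616)


def normalize_pixel(
    rgb: Sequence[float],
    mean: Sequence[float] = TRAIN_MEAN,
    std: Sequence[float] = TRAIN_STD,
) -> Tuple[float, ...]:
    """Apply (x - mean) / std channel by channel, as transforms.Normalize does."""
    if not (len(rgb) == len(mean) == len(std)):
        raise ValueError("rgb, mean and std must have the same number of channels")
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))
```

For a mid-grey pixel `(0.5, 0.5, 0.5)` this gives roughly `(0.0348, 0.0731, 0.2045)`; for example, the red channel is (0.5 - 0.4914) / 0.2470 ≈ 0.0348. A pixel equal to the channel means maps to zero in every channel.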
If you use this package, please cite the following paper:
@misc{jeevan2022convolutional,
  title={Convolutional Xformers for Vision},
  author={Pranav Jeevan and Amit Sethi},
  year={2022},
  eprint={2201.10271},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
Download files
Source distribution: dualopt-0.1.7.tar.gz (4.7 kB)
- Uploaded via: twine/4.0.2 CPython/3.11.3 (not via Trusted Publishing)
- SHA256: ea03db1383ad61dc4015bdb3f58b178cb59d3dd70e542c2fd8599653ab478d5e
- MD5: 01bfbffb99e50d196b36023918273f8d
- BLAKE2b-256: 13065c7ecdfb734e01474fc0970f2814b4e98eb98e7d90a8e7f48c184190db07

Built distribution: dualopt-0.1.7-py3-none-any.whl (5.1 kB, Python 3)
- Uploaded via: twine/4.0.2 CPython/3.11.3 (not via Trusted Publishing)
- SHA256: f4e649245c0577bb075cea1dae4d289888c4e3399047522907710420290c3c4b
- MD5: 3d3c420fe0f0e01bceffc05290a780ad
- BLAKE2b-256: 459684fc698a84ecd8bc79bbf316323e53a1e79a0fff35a2510956aaf0f3ab98