DualOpt
Dual Optimizer Training
A variant of the SWATS training paradigm that uses two optimizers in sequence during training.
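For intuition, here is a minimal sketch of the idea, not dualopt's actual implementation: as in SWATS (which switches from Adam to SGD mid-training), train with Adam until test accuracy stops improving for a set number of epochs, then continue from those weights with SGD under the same stopping rule. The evaluate and train_phase helpers and all learning rates below are illustrative assumptions.

import torch
import torch.nn as nn

@torch.no_grad()
def evaluate(model, testloader, device):
    # Illustrative helper: top-1 test accuracy in percent.
    model.eval()
    correct = total = 0
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        correct += (model(images).argmax(1) == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / total

def train_phase(model, optimizer, trainloader, testloader, device, patience):
    # Train with one optimizer until `patience` epochs pass without a new
    # best test accuracy (a sketch, not dualopt's internals).
    criterion = nn.CrossEntropyLoss()
    best, stale = 0.0, 0
    while stale < patience:
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        acc = evaluate(model, testloader, device)
        best, stale = (acc, 0) if acc > best else (best, stale + 1)
    return best

def dual_optimizer_train(model, trainloader, testloader, device, patience=20):
    # Phase 1: adaptive optimizer; Phase 2: SGD from the phase-1 weights.
    # Learning rates are assumptions, not the package defaults.
    adam = torch.optim.Adam(model.parameters(), lr=1e-3)
    train_phase(model, adam, trainloader, testloader, device, patience)
    sgd = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    return train_phase(model, sgd, trainloader, testloader, device, patience)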
Install
$ pip install dualopt
Usage
Image Classification
import dualopt, torch
from dualopt import classification
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10
# define model
# define datasets and dataloaders
top1 = [] #top1 accuracy
top5 = [] #top5 accuracy
traintime = []
testtime = []
counter = 20 # number of epochs without any improvement in accuracy before we stop training for each optimizer
PATH = 'saved_model.pth' #path to save model
classification(model, trainloader, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} -Top 5 Accuracy : {max(top5):.2f} - Train Time: {min(traintime):.0f} -Test Time: {min(testtime):.0f}\n")
Post-Training
Experiments show that training with data augmentations such as TrivialAugment gives good results. We found that subsequent post-training without any data augmentation can improve the results further.
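In terms of the illustrative sketch near the top of this page, post-training amounts to one more patience-controlled fine-tuning phase on an un-augmented loader, starting from the best checkpoint; trainloader_no_aug and the learning rate here are placeholders, not dualopt internals:

# Illustrative only: fine-tune the best checkpoint on un-augmented data.
model.load_state_dict(torch.load(PATH))
sgd = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
train_phase(model, sgd, trainloader_no_aug, testloader, device, patience=20)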
Usage
import dualopt, torch, torchvision
import torchvision.transforms as transforms
from dualopt import classification, post_train
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10
# define model
# set batch size according to GPU memory
batch_size = 512
# transforms
transform_train_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4941, 0.4853, 0.4507), (0.2468, 0.2430, 0.2618))])
# Datasets
trainset_1 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_1)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader with augmentations
trainset_2 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader for post-training without augmentations
testset = torchvision.datasets.CIFAR10(root='/workspace/', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)
top1 = [] #top1 accuracy
top5 = [] #top5 accuracy
traintime = []
testtime = []
counter = 20 # number of epochs without any improvement in accuracy before we stop training for each optimizer
PATH = 'saved_model.pth' #path to save model
classification(model, trainloader_1, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')
model.load_state_dict(torch.load(PATH))  # reload the best checkpoint before post-training
post_train(model, trainloader_2, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Post-Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} -Top 5 Accuracy : {max(top5):.2f} - Train Time: {min(traintime):.0f} -Test Time: {min(testtime):.0f}\n")
Cite the following paper:

@misc{jeevan2022convolutional,
      title={Convolutional Xformers for Vision},
      author={Pranav Jeevan and Amit Sethi},
      year={2022},
      eprint={2201.10271},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
dualopt-0.1.8.tar.gz (4.7 kB)
Built Distribution
dualopt-0.1.8-py3-none-any.whl (5.0 kB)
File details
Details for the file dualopt-0.1.8.tar.gz.
File metadata
- Download URL: dualopt-0.1.8.tar.gz
- Upload date:
- Size: 4.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | e0d2b3238b8ec56eeb19c809697aefd6d3cc8f3e5c7f4f7de77056cbc82f7ca4
MD5 | bfdfb619fdcc9cecc1199f5dd63968d6
BLAKE2b-256 | badcd53d71e55ebed8da34ba26f6baffd759db53964bcad73d002a1b1fe1d166
File details
Details for the file dualopt-0.1.8-py3-none-any.whl.
File metadata
- Download URL: dualopt-0.1.8-py3-none-any.whl
- Upload date:
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9b40d14c28a531153b1bcd981a2b0b47268327e4fd4e24b11929d3439bfe0c01
MD5 | 7b65cab976e7827c5481cfca555663c2
BLAKE2b-256 | 5b6821bf2fc046389428e410d247c7f714fd96edc98ea7e946d44630b24232e2