Dual Optimizer Training
DualOpt
A variant of the SWATS (Switching from Adam to SGD) training paradigm that trains a model with two optimizers, used one after the other.
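The idea of switching optimizers partway through training can be illustrated on a toy one-dimensional problem. The sketch below minimizes f(x) = (x - 3)^2 first with Adam and then continues from the same point with plain SGD. It assumes the Adam-then-SGD pairing of the original SWATS method. This page does not say which two optimizers dualopt uses or exactly when it switches, and the fixed step counts and learning rates here are illustrative choices, not dualopt's settings.

```python
import math


def grad(x: float) -> float:
    """Gradient of the toy objective f(x) = (x - 3)^2."""
    return 2.0 * (x - 3.0)


def run_adam(x: float, steps: int, lr: float = 0.1,
             beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Phase 1: Adam updates with bias-corrected moment estimates."""
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g       # first moment (running mean of gradients)
        v = beta2 * v + (1 - beta2) * g * g   # second moment (running mean of squared gradients)
        m_hat = m / (1 - beta1 ** t)          # bias corrections for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


def run_sgd(x: float, steps: int, lr: float = 0.1) -> float:
    """Phase 2: plain gradient-descent updates, starting where Adam stopped."""
    for _ in range(steps):
        x -= lr * grad(x)
    return x


# Hypothetical schedule: a fixed number of steps for each optimizer
x_after_adam = run_adam(0.0, steps=50)
x_final = run_sgd(x_after_adam, steps=50)
print(f"after Adam: {x_after_adam:.4f}, after SGD: {x_final:.6f}")
```

SWATS itself decides the switch point automatically and derives the SGD learning rate from Adam's steps. A fixed step count keeps the sketch short.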
Install
$ pip install dualopt
Usage
Image Classification
import torch
from dualopt import classification
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10
# Define your model here (assign it to `model`)
# Define your datasets and the dataloaders `trainloader` and `testloader` here
top1 = []  # top-1 accuracy
top5 = []  # top-5 accuracy
traintime = []  # training times
testtime = []  # test times
counter = 20  # number of epochs without any improvement in accuracy before training stops for each optimizer
PATH = 'saved_model.pth'  # path where the model is saved
classification(model, trainloader, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")
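The `set_counter` argument is the number of epochs without any improvement in accuracy before training stops for each optimizer. A minimal sketch of such a patience rule, applied to a list of per-epoch accuracies, might look like the following. The function name `stop_epoch` and the accuracy values are hypothetical. Whether dualopt treats ties as improvements, or resets the best accuracy when it switches optimizers, is an assumption this page does not confirm.

```python
from typing import List, Optional


def stop_epoch(accuracies: List[float], patience: int) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Training stops once `patience` consecutive epochs pass without the
    accuracy exceeding the best value seen so far. In this sketch a tie
    with the best value does not count as an improvement.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, acc in enumerate(accuracies):
        if acc > best:
            best = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical per-epoch accuracies recorded while training with one optimizer
phase_1 = [40.0, 55.0, 61.0, 60.5, 61.0, 59.8]
print(stop_epoch(phase_1, patience=3))  # prints 5
```

In a two-optimizer setup, the same rule would be applied once per optimizer: when the first optimizer's run stops, training continues with the second until it stalls as well.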
Post-Training
In our experiments, training with data augmentations such as TrivialAugment gives good results, and a subsequent post-training phase without any data augmentation can improve the results further.
Usage
import torch, torchvision
import torchvision.transforms as transforms
from dualopt import classification, post_train
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10
# Define your model here (assign it to `model`)
# Set the batch size according to your GPU memory
batch_size = 512
# Transforms
# With augmentations, for the first training stage
transform_train_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
# Without augmentations, for post-training
transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4941, 0.4853, 0.4507), (0.2468, 0.2430, 0.2618))])
# Datasets and dataloaders
trainset_1 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_1)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader with augmentations
trainset_2 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader for post-training without augmentations
testset = torchvision.datasets.CIFAR10(root='/workspace/', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)
top1 = []  # top-1 accuracy
top5 = []  # top-5 accuracy
traintime = []  # training times
testtime = []  # test times
counter = 20  # number of epochs without any improvement in accuracy before training stops for each optimizer
PATH = 'saved_model.pth'  # path where the model is saved
# Stage 1: train with augmentations
classification(model, trainloader_1, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')
# Stage 2: reload the saved model and post-train without augmentations
model.load_state_dict(torch.load(PATH, map_location=device))
post_train(model, trainloader_2, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Post-Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")
Cite the following paper
@misc{jeevan2022convolutional,
title={Convolutional Xformers for Vision},
author={Pranav Jeevan and Amit Sethi},
year={2022},
eprint={2201.10271},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Download files
File details
Details for the file dualopt-0.1.8.tar.gz.
File metadata
- Download URL: dualopt-0.1.8.tar.gz
- Upload date:
- Size: 4.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e0d2b3238b8ec56eeb19c809697aefd6d3cc8f3e5c7f4f7de77056cbc82f7ca4 |
| MD5 | bfdfb619fdcc9cecc1199f5dd63968d6 |
| BLAKE2b-256 | badcd53d71e55ebed8da34ba26f6baffd759db53964bcad73d002a1b1fe1d166 |
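The SHA256 digests listed for the two release files let you check that a download is intact. A minimal sketch using Python's standard `hashlib` module computes a file's SHA256 digest in chunks and compares it with the published value. The helper names `sha256_of` and `matches_published_digest` are illustrative, and the usage line assumes the file was saved in the current directory under its original name.

```python
import hashlib

# SHA256 digests published for the dualopt 0.1.8 release files
PUBLISHED_SHA256 = {
    "dualopt-0.1.8.tar.gz": "e0d2b3238b8ec56eeb19c809697aefd6d3cc8f3e5c7f4f7de77056cbc82f7ca4",
    "dualopt-0.1.8-py3-none-any.whl": "9b40d14c28a531153b1bcd981a2b0b47268327e4fd4e24b11929d3439bfe0c01",
}


def sha256_of(path: str, chunk_size: int = 65536) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path: str, filename: str) -> bool:
    """Compare a local file's SHA256 digest with the published digest for `filename`."""
    return sha256_of(path) == PUBLISHED_SHA256[filename]


# Usage (assumes the source distribution was downloaded to the current directory):
# print(matches_published_digest("dualopt-0.1.8.tar.gz", "dualopt-0.1.8.tar.gz"))
```

pip can also enforce digests itself through its hash-checking mode: list the package with `--hash=sha256:...` entries in a requirements file and install with `--require-hashes`.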
File details
Details for the file dualopt-0.1.8-py3-none-any.whl.
File metadata
- Download URL: dualopt-0.1.8-py3-none-any.whl
- Upload date:
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9b40d14c28a531153b1bcd981a2b0b47268327e4fd4e24b11929d3439bfe0c01 |
| MD5 | 7b65cab976e7827c5481cfca555663c2 |
| BLAKE2b-256 | 5b6821bf2fc046389428e410d247c7f714fd96edc98ea7e946d44630b24232e2 |