
DualOpt

Dual Optimizer Training

A variant of the SWATS training paradigm that trains a model with two optimizers in sequence, switching to the second optimizer once accuracy stops improving under the first.
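For intuition, here is a minimal sketch of the dual-optimizer idea: train with an adaptive optimizer until accuracy stops improving for a set number of epochs, then switch to a second optimizer and repeat. The optimizer choices, hyperparameters, and `evaluate` helper below are illustrative assumptions, not dualopt's actual implementation.

import torch

def dual_optimizer_train(model, loss_fn, trainloader, evaluate, device, patience=20):
    # Illustrative two-phase loop: Adam first, then SGD (a SWATS-style switch).
    # `evaluate` is a hypothetical callable returning test accuracy.
    optimizers = [
        torch.optim.Adam(model.parameters(), lr=1e-3),
        torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
    ]
    for optimizer in optimizers:
        best_acc, stale_epochs = 0.0, 0
        while stale_epochs < patience:  # patience plays the role of `set_counter` below
            model.train()
            for inputs, targets in trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                loss_fn(model(inputs), targets).backward()
                optimizer.step()
            acc = evaluate(model)
            if acc > best_acc:
                best_acc, stale_epochs = acc, 0
            else:
                stale_epochs += 1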

Install

$ pip install dualopt

Usage

Image Classification

import dualopt, torch
from dualopt import classification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes = 10

# define model
# define datasets and dataloaders

top1 = []       # top-1 accuracy history
top5 = []       # top-5 accuracy history
traintime = []  # training times
testtime = []   # test times
counter = 20    # epochs without accuracy improvement before training stops, applied to each optimizer

PATH = 'saved_model.pth'  # path to save the model

classification(model, trainloader, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)

print('Finished Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")

Post-Training

Experiments show that training with data augmentations such as TrivialAugment gives good results. We found that subsequent post-training without any data augmentation can further improve them.

Usage

import dualopt, torch, torchvision
import torchvision.transforms as transforms
from dualopt import classification
from dualopt import post_train

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes = 10

# define model

# set batch size according to GPU memory
batch_size = 512

# transforms
transform_train_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4941, 0.4853, 0.4507), (0.2468, 0.2430, 0.2618))])

# datasets and dataloaders
trainset_1 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_1)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader with augmentations

trainset_2 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader for post-training without augmentations

testset = torchvision.datasets.CIFAR10(root='/workspace/', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

top1 = []       # top-1 accuracy history
top5 = []       # top-5 accuracy history
traintime = []  # training times
testtime = []   # test times
counter = 20    # epochs without accuracy improvement before training stops, applied to each optimizer

PATH = 'saved_model.pth'  # path to save the model

classification(model, trainloader_1, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')

model.load_state_dict(torch.load(PATH, map_location=device))

post_train(model, trainloader_2, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Post-Training')

print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")

If you use this work, please cite the following paper:

@misc{jeevan2022convolutional,
      title={Convolutional Xformers for Vision}, 
      author={Pranav Jeevan and Amit Sethi},
      year={2022},
      eprint={2201.10271},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

