Dual Optimizer Training

DualOpt

DualOpt is a variant of the SWATS (Switching from Adam to SGD) training paradigm that trains a model with two optimizers, one after the other.
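To make the idea of switching optimizers concrete, here is a toy sketch in plain Python. It is not dualopt's implementation: the quadratic loss, the learning rates, the `run_phase` helper and the patience-based switch are all assumptions chosen for illustration. It runs Adam until the loss stops improving, then continues from the best point with plain SGD.

```python
import math

def loss(w):
    # Toy objective (an assumption for illustration): minimum at w = 3.
    return (w - 3.0) ** 2

def grad(w):
    # Analytic gradient of the toy objective.
    return 2.0 * (w - 3.0)

def adam_step(w, state, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    # One Adam update with bias-corrected moment estimates.
    g = grad(w)
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * g
    state["v"] = b2 * state["v"] + (1 - b2) * g * g
    m_hat = state["m"] / (1 - b1 ** state["t"])
    v_hat = state["v"] / (1 - b2 ** state["t"])
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

def sgd_step(w, state, lr=0.1):
    # One plain gradient-descent update (state is unused).
    return w - lr * grad(w)

def run_phase(w, step, patience, max_steps=500):
    # Hypothetical helper: run one optimizer until the loss has not improved
    # for `patience` consecutive steps, and return the best point seen.
    state = {"t": 0, "m": 0.0, "v": 0.0}
    best_w, best_loss, stale = w, loss(w), 0
    for _ in range(max_steps):
        w = step(w, state)
        if loss(w) < best_loss:
            best_w, best_loss, stale = w, loss(w), 0
        else:
            stale += 1
            if stale >= patience:
                break
    return best_w

w_start = 0.0
w_after_adam = run_phase(w_start, adam_step, patience=5)      # phase 1
w_final = run_phase(w_after_adam, sgd_step, patience=5)       # phase 2
print(f"after Adam: {w_after_adam:.4f}, after SGD: {w_final:.6f}")
```

Returning the best point from each phase loosely mirrors saving the best checkpoint to a file before the second optimizer takes over.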

Install

$ pip install dualopt

Usage

Image Classification

import torch
from dualopt import classification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes = 10

# Define `model` (a torch.nn.Module) here.
# Define the datasets and the `trainloader` / `testloader` DataLoaders here.

top1 = []       # top-1 accuracy per evaluation
top5 = []       # top-5 accuracy per evaluation
traintime = []  # training time per epoch
testtime = []   # evaluation time per epoch
counter = 20    # epochs without accuracy improvement before training with each optimizer stops

PATH = 'saved_model.pth'  # path where the model is saved

classification(model, trainloader, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)

print('Finished Training')
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")
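
The `counter` value above acts as a patience setting. The exact stopping rule inside dualopt is not documented here, so the following is only one common interpretation, with a hypothetical helper `epochs_run`: training with an optimizer stops once accuracy has failed to beat its best value for `patience` consecutive epochs.

```python
def epochs_run(accuracies, patience):
    """Return how many epochs run before patience-based stopping triggers.

    `accuracies` is the per-epoch accuracy history; this helper is
    hypothetical and not part of dualopt.
    """
    best = float("-inf")
    stale = 0  # consecutive epochs without a new best accuracy
    for epoch, acc in enumerate(accuracies, start=1):
        if acc > best:
            best, stale = acc, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return len(accuracies)

history = [50.0, 60.0, 59.0, 61.0, 60.5, 60.9, 60.0]
print(epochs_run(history, patience=3))  # → 7
print(epochs_run(history, patience=2))  # → 6
```

A larger value gives each optimizer more time to recover from a plateau at the cost of longer training.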

Post-Training

In our experiments, training with data augmentations such as TrivialAugment gives good results. A subsequent post-training phase on the same training data without any augmentation can improve the results further.

Usage

import torch
import torchvision
import torchvision.transforms as transforms
from dualopt import classification, post_train

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes = 10

# Define `model` (a torch.nn.Module) here.

# Set the batch size according to available GPU memory.
batch_size = 512

# Training transforms with augmentation (main training run).
transform_train_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Training transforms without augmentation (post-training).
transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Test transforms.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4941, 0.4853, 0.4507), (0.2468, 0.2430, 0.2618)),
])

# Datasets and dataloaders (persistent_workers expects a bool)
trainset_1 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_1)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader with augmentations

trainset_2 = torchvision.datasets.CIFAR10(root='/workspace/', train=True, download=True, transform=transform_train_2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)  # trainloader for post-training without augmentations

testset = torchvision.datasets.CIFAR10(root='/workspace/', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)


top1 = []       # top-1 accuracy per evaluation
top5 = []       # top-5 accuracy per evaluation
traintime = []  # training time per epoch
testtime = []   # evaluation time per epoch
counter = 20    # epochs without accuracy improvement before training with each optimizer stops

PATH = 'saved_model.pth'  # path where the model is saved

# Main training run with augmentations.
classification(model, trainloader_1, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Training')

# Reload the saved weights before post-training.
model.load_state_dict(torch.load(PATH))

# Post-training without augmentations.
post_train(model, trainloader_2, testloader, device, PATH, top1, top5, traintime, testtime, num_classes=num_classes, set_counter=counter)
print('Finished Post-Training')

print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} - Top 5 Accuracy: {max(top5):.2f} - Train Time: {min(traintime):.0f} - Test Time: {min(testtime):.0f}\n")
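
For readers unfamiliar with the `top1` and `top5` metrics reported above, the sketch below shows how top-k accuracy is usually computed. The `topk_accuracy` helper and the percentage scale are assumptions for illustration; dualopt computes its own values internally.

```python
def topk_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k highest scores.

    `scores` is a list of per-class score lists, `labels` the true class
    indices. Hypothetical helper, not part of dualopt.
    """
    correct = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score, truncated to the top k.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += label in topk
    return 100.0 * correct / len(labels)

scores = [
    [0.1, 0.7, 0.2],  # predicted class 1
    [0.5, 0.3, 0.2],  # predicted class 0, class 1 is second
    [0.2, 0.3, 0.5],  # predicted class 2, class 0 is last
]
labels = [1, 1, 0]

print(f"top-1: {topk_accuracy(scores, labels, 1):.2f}")  # → top-1: 33.33
print(f"top-2: {topk_accuracy(scores, labels, 2):.2f}")  # → top-2: 66.67
```

Top-k accuracy never decreases as k grows, which is why the top-5 figure is always at least as high as the top-1 figure.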

Cite the following paper

@misc{jeevan2022convolutional,
      title={Convolutional Xformers for Vision},
      author={Pranav Jeevan and Amit Sethi},
      year={2022},
      eprint={2201.10271},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
