
Utilities and baselines for fast neural network training on CIFAR-10

Project description

CIFAR-10 Airbench 💨

How to run

To perform a fast training to 94% average accuracy on CIFAR-10, run either

git clone https://github.com/KellerJordan/cifar10-airbench.git airbench && python airbench/airbench94.py

or

pip install airbench && python -c "import airbench; airbench.train94()"
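
The same entry point can also be used from a Python script. Below is a minimal sketch; it assumes, following the project's usage examples, that train94() returns the trained network and that the package exposes an evaluate() helper, so check your installed version before relying on either.

import airbench

# Train to ~94% test accuracy and keep the resulting network.
# (Assumption: train94() returns the trained model, as in the project's examples.)
model = airbench.train94()

# Re-check accuracy on the test set.
# (Assumption: airbench exposes an evaluate(model, loader) helper.)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)
print(airbench.evaluate(model, test_loader))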

Motivation

CIFAR-10 is among the most widely used datasets in machine learning, facilitating thousands of research projects per year. However, many studies rely on poorly optimized training setups, which wastes time and sometimes produces contradictory results. To address this problem, airbench provides a set of training methods that are both (a) very easy to run and (b) state-of-the-art in training speed.

In particular, airbench training scripts attain 94%, 95%, and 96% accuracy on the CIFAR-10 test set in 3.29, 10.4, and 46.3 seconds on an NVIDIA A100. These methods can replace slower baselines such as training a ResNet-20 or ResNet-18.

Training methods

Script                   Mean accuracy   Time    PFLOPs
airbench94_compiled.py   94.01%          3.29s   0.36
airbench94.py            94.01%          3.83s   0.36
airbench95.py            95.01%          10.4s   1.4
airbench96.py            96.05%          46.3s   7.5

Timings are on a single NVIDIA A100. Note that the first run of training is always slower due to GPU warmup.

Using the GPU-accelerated dataloader independently

For writing custom fast CIFAR-10 training scripts, you may find GPU-accelerated dataloading useful:

import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)

for epoch in range(200):
    for inputs, labels in train_loader:
        # outputs = model(inputs)
        # loss = F.cross_entropy(outputs, labels)
        ...

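For reference, one way to flesh out this skeleton into a runnable script is sketched below. The network and optimizer are illustrative placeholders rather than the architecture used by the airbench scripts, and the batches are assumed to arrive as CUDA tensors (the point of the GPU-resident loader); the explicit .float() cast guards against the loader serving half-precision images.

import airbench
import torch
import torch.nn as nn
import torch.nn.functional as F

train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)

# Placeholder model for illustration only; not the airbench architecture.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1, stride=2), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 10),
).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

for epoch in range(10):
    model.train()
    for inputs, labels in train_loader:
        # Batches are already on the GPU, so no host-to-device copies are needed here.
        optimizer.zero_grad(set_to_none=True)
        loss = F.cross_entropy(model(inputs.float()), labels)
        loss.backward()
        optimizer.step()

    # Measure test accuracy at the end of each epoch.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            correct += (model(inputs.float()).argmax(1) == labels).sum().item()
            total += labels.size(0)
    print(f'epoch {epoch}: test accuracy {correct / total:.4f}')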
If you wish to modify the data used for training, you can do so like this:

import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
mask = (train_loader.labels < 6)
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60.
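
The same pattern works for other modifications. For example, here is a sketch that keeps a random 10% subset of the training set, again assuming only that the loader exposes its data through the images and labels tensors shown above.

import airbench
import torch

train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)

# Keep a random 10% of the 50,000 training images.
torch.manual_seed(0)
keep = torch.randperm(len(train_loader.labels), device=train_loader.labels.device)[:5000]
train_loader.images = train_loader.images[keep]
train_loader.labels = train_loader.labels[keep]
print(len(train_loader)) # 5,000 images at batch size 500, so this prints 10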

