
A PyTorch Extension by NVIDIA (Apex)

Project description

Introduction

This package provides NVIDIA-maintained utilities to streamline mixed precision and distributed training in PyTorch. Some of the code here will be included in upstream PyTorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.

Full API Documentation: https://nvidia.github.io/apex

GTC 2019 and PyTorch DevCon 2019 Slides

Contents

1. Amp: Automatic Mixed Precision

apex.amp is a tool to enable mixed precision training by changing only 3 lines of your script. Users can easily experiment with different pure and mixed precision training modes by supplying different flags to amp.initialize.
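A minimal sketch of those three changes, assuming model, optimizer, and loss already exist in a standard training script ("O1" is just one of the available opt_levels):

from apex import amp

# 1. Let Amp patch the model and optimizer for the chosen precision mode
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# 2. Scale the loss before computing gradients
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# 3. Step the optimizer as usual
optimizer.step()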

Webinar introducing Amp (The flag cast_batchnorm has been renamed to keep_batchnorm_fp32).

API Documentation

Comprehensive Imagenet example

DCGAN example coming soon...

Moving to the new Amp API (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

2. Distributed Training

apex.parallel.DistributedDataParallel is a module wrapper, similar to torch.nn.parallel.DistributedDataParallel. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.
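A minimal usage sketch, assuming one process per GPU has already been launched (e.g. via torch.distributed.launch) and local_rank identifies this process's GPU:

import torch
from apex.parallel import DistributedDataParallel as DDP

torch.cuda.set_device(local_rank)                       # one GPU per process
torch.distributed.init_process_group(backend='nccl',    # NCCL backend
                                     init_method='env://')

model = model.cuda()
model = DDP(model)  # unlike torch.nn.parallel.DistributedDataParallel, no device_ids needed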

API Documentation

Python Source

Example/Walkthrough

The Imagenet example shows use of apex.parallel.DistributedDataParallel along with apex.amp.

Synchronized Batch Normalization

apex.parallel.SyncBatchNorm extends torch.nn.modules.batchnorm._BatchNorm to support synchronized BN. It allreduces stats across processes during multiprocess (DistributedDataParallel) training. Synchronous BN has been used in cases where only a small local minibatch can fit on each GPU. Allreduced stats increase the effective batch size for the BN layer to the global batch size across all processes (which, technically, is the correct formulation). Synchronous BN has been observed to improve converged accuracy in some of our research models.
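One common pattern, sketched here under the assumption that the conversion happens before amp.initialize and the DistributedDataParallel wrapping, is to convert an existing model's BatchNorm layers in place:

from apex.parallel import convert_syncbn_model

# Recursively replace torch.nn BatchNorm layers with apex.parallel.SyncBatchNorm
model = convert_syncbn_model(model)
model = model.cuda()
# ... then continue with amp.initialize(...) and apex.parallel.DistributedDataParallel(model)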

Checkpointing

To properly save and load training runs that use Amp, we introduce amp.state_dict(), which contains all loss scalers and their corresponding unskipped steps, and amp.load_state_dict() to restore these attributes.

In order to get bitwise accuracy, we recommend the following workflow:

# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...

Note that we recommend restoring the model using the same opt_level. Also note that we recommend calling the load_state_dict methods after amp.initialize.

Requirements

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require PyTorch 1.0 or newer.

Quick Start

Linux

For performance and full functionality, we recommend installing with CUDA and C++ extensions via

pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension

For a Python-only build (required with PyTorch 0.4):

pip install -v --disable-pip-version-check --no-cache-dir pytorch-extension

A Python-only build omits:

  • Fused kernels required to use apex.optimizers.FusedAdam (see the brief sketch after this list).
  • Fused kernels required to use apex.normalization.FusedLayerNorm.
  • Fused kernels that improve the performance and numerical stability of apex.parallel.SyncBatchNorm.
  • Fused kernels that improve the performance of apex.parallel.DistributedDataParallel and apex.amp. DistributedDataParallel, amp, and SyncBatchNorm will still be usable, but they may be slower.
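When the extensions are installed, the fused modules listed above are intended as drop-in replacements for their torch counterparts. A brief, hedged sketch (lr and hidden_size are placeholder values):

from apex.optimizers import FusedAdam
from apex.normalization import FusedLayerNorm

# Drop-in replacement for torch.optim.Adam, backed by a fused CUDA kernel
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# Drop-in replacement for torch.nn.LayerNorm
norm = FusedLayerNorm(normalized_shape=hidden_size).cuda()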

Pyprof support has moved to its own dedicated repository; the copy remaining in Apex is deprecated and will be removed soon.

Windows support

Windows support is experimental, and Linux is recommended. pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" pytorch-extension may work if you were able to build PyTorch from source on your system. pip install -v --disable-pip-version-check --no-cache-dir pytorch-extension (without CUDA/C++ extensions) is more likely to work. If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.



Download files

Download the file for your platform.

Source Distribution

pytorch-extension-0.2.tar.gz (603.5 kB)

Uploaded Source

Built Distribution

pytorch_extension-0.2-py3-none-any.whl (214.0 kB)

Uploaded Python 3

File details

Details for the file pytorch-extension-0.2.tar.gz.

File metadata

  • Download URL: pytorch-extension-0.2.tar.gz
  • Upload date:
  • Size: 603.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.4 pkginfo/1.7.1 requests/2.23.0 requests-toolbelt/0.9.1 tqdm/4.62.0 CPython/3.7.11

File hashes

Hashes for pytorch-extension-0.2.tar.gz:

  • SHA256: 08dde16be2ef298a89bf5cbc3c3faeaf2e7f57c8713f56394b6273851fa9c2cb
  • MD5: b770a2b6d2ebf668f5f9da8de22b0aef
  • BLAKE2b-256: a8c6335ca28c5375e2bd1e3052cd06bd14b4d2786d38f3fb577251ebdf04d350


File details

Details for the file pytorch_extension-0.2-py3-none-any.whl.

File metadata

  • Download URL: pytorch_extension-0.2-py3-none-any.whl
  • Upload date:
  • Size: 214.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.4 pkginfo/1.7.1 requests/2.23.0 requests-toolbelt/0.9.1 tqdm/4.62.0 CPython/3.7.11

File hashes

Hashes for pytorch_extension-0.2-py3-none-any.whl:

  • SHA256: 1be1276d14b5bee090cfcbb24847db5927b0e7ee7f68ace73431a611629db42c
  • MD5: 6b22a969e0d28b90a23141004595db11
  • BLAKE2b-256: 1eff350e1479141e72dcc6cdb358c4ba080cd645fb6d8f5c7e46f8f3b382b4d7

