
Package for gradient accumulation in TensorFlow


GradientAccumulator

Seamless gradient accumulation for TensorFlow 2


GradientAccumulator was developed by SINTEF Health due to the lack of an easy-to-use method for gradient accumulation in TensorFlow 2.

The package is available on PyPI, has been tested against TensorFlow 2.2-2.10 and Python 3.6-3.11, and works cross-platform (Ubuntu, Windows, macOS).


Install

Stable release from PyPI:

pip install gradient-accumulator

Or from source:

pip install git+https://github.com/andreped/GradientAccumulator

Getting started

A simple way to add gradient accumulation to an existing model is:

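from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from gradient_accumulator import GradientAccumulateModel

# an existing functional Keras model (toy example)
inputs = Input(shape=(10,))
model = Model(inputs, Dense(1)(inputs))
# wrap it so gradients are accumulated over 4 batches before each weight update
model = GradientAccumulateModel(accum_steps=4, inputs=model.input, outputs=model.output)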

Then simply use the model as you normally would!

In practice, using gradient accumulation with a custom pipeline might require some extra overhead and tricks to get working.

For more information, see the documentation, hosted at gradientaccumulator.readthedocs.io.

What?

Gradient accumulation (GA) reduces GPU memory consumption by dividing a batch into smaller mini-batches and computing their gradients either in a distributed setting across multiple GPUs or sequentially on the same GPU. Once the full batch has been processed, the mini-batch gradients are accumulated to produce the full-batch gradient.

Note that our implementation performs gradient accumulation slightly differently in practice, so that the entire batch never needs to be held in CPU memory. More information on what goes on under the hood can be found in the documentation.
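The claim that accumulated mini-batch gradients reproduce the full-batch gradient can be checked numerically. The sketch below is plain NumPy on a toy linear-regression problem with hypothetical data, not the package's code; it assumes a loss averaged over samples and equal-sized micro-batches, which is what makes the averaged accumulated gradient match exactly.

```python
import numpy as np

# Toy linear-regression problem (hypothetical data, for illustration only).
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))  # full batch: 32 samples, 3 features
y = rng.normal(size=32)
w = rng.normal(size=3)        # current weights


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)


# Gradient computed on the whole batch at once.
full_grad = mse_grad(w, X, y)

# Gradient accumulation: process accum_steps equal micro-batches one after
# another, sum their gradients, then divide by the number of steps.
accum_steps = 4
accum_grad = np.zeros_like(w)
for X_mb, y_mb in zip(np.array_split(X, accum_steps), np.array_split(y, accum_steps)):
    accum_grad += mse_grad(w, X_mb, y_mb)
accum_grad /= accum_steps

print(np.allclose(full_grad, accum_grad))  # True
```

Only one micro-batch needs to be in memory at a time, which is where the memory savings come from.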

Why?

In TensorFlow 2, there was no plug-and-play method for using gradient accumulation with arbitrary custom pipelines. Hence, we have implemented two generic TF2-compatible approaches:

Method | Usage
(1) GradientAccumulateModel | model = GradientAccumulateModel(accum_steps=4, inputs=model.input, outputs=model.output)
(2) GradientAccumulateOptimizer | opt = GradientAccumulateOptimizer(accum_steps=4, optimizer=tf.keras.optimizers.SGD(1e-2))

Both approaches control how frequently the weights are updated, but each does so in its own way: approach (1) overrides the train_step method of a given Model, whereas approach (2) wraps the optimizer. Approach (1) only supports single-GPU training, whereas (2) also supports distributed (multi-GPU) training.
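To illustrate the idea behind wrapping an optimizer, here is a minimal sketch in NumPy: the hypothetical class `AccumulatingSGD` buffers incoming gradients and only applies a plain SGD step, using the averaged gradient, every `accum_steps` calls. It is not the GradientAccumulateOptimizer implementation, which operates on TensorFlow variables and supports distribution strategies; only the "update every N steps" logic is shown.

```python
import numpy as np


class AccumulatingSGD:
    """Hypothetical sketch of a gradient-accumulating optimizer wrapper.

    NOT the GradientAccumulateOptimizer implementation: weights are a single
    NumPy array and the wrapped optimizer is plain SGD.
    """

    def __init__(self, learning_rate, accum_steps):
        self.learning_rate = learning_rate
        self.accum_steps = accum_steps
        self.step = 0
        self.buffer = None  # running sum of gradients since the last update

    def apply_gradients(self, grad, weights):
        """Accumulate `grad`; every `accum_steps` calls, apply the averaged
        gradient with SGD and reset the buffer. Returns the (new) weights."""
        if self.buffer is None:
            self.buffer = np.zeros_like(weights)
        self.buffer += grad
        self.step += 1
        if self.step % self.accum_steps == 0:
            weights = weights - self.learning_rate * self.buffer / self.accum_steps
            self.buffer = np.zeros_like(weights)
        return weights


opt = AccumulatingSGD(learning_rate=0.1, accum_steps=4)
w = np.array([1.0, -1.0])
for _ in range(3):
    w = opt.apply_gradients(np.array([1.0, 0.0]), w)  # buffered, w unchanged
w = opt.apply_gradients(np.array([1.0, 0.0]), w)      # 4th call: one SGD step
print(w)  # weights after one update with mean gradient [1, 0]: 0.9 and -1.0
```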

In theory, our implementations allow arbitrarily large effective batch sizes with the same memory consumption as a regular mini-batch. On a single GPU, this comes at the cost of longer training time; multiple GPUs can be used to reduce runtime.
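The trade-off above comes down to simple arithmetic. The hypothetical helper below computes how many samples contribute to each weight update, assuming synchronous data-parallel training where every replica processes `batch_size_per_replica` samples per pass and gradients are accumulated over `accum_steps` passes. How the batch size passed to training maps to per-replica batches depends on your distribution setup, so treat the names as assumptions.

```python
def effective_batch_size(batch_size_per_replica, accum_steps, num_replicas=1):
    """Samples contributing to each weight update (hypothetical helper).

    Memory per device is governed by batch_size_per_replica alone; the
    accumulation steps and replicas only multiply the effective batch.
    """
    return batch_size_per_replica * num_replicas * accum_steps


# Single GPU: memory of a batch of 8, but each update sees 32 samples.
print(effective_batch_size(8, accum_steps=4))                  # 32
# Two GPUs: same per-GPU memory, 64 samples per update.
print(effective_batch_size(8, accum_steps=4, num_replicas=2))  # 64
```

On a single GPU the 4 passes run one after another, which is why the runtime per update grows; with more replicas they can run in parallel.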

Technique | Usage
Batch Normalization | layer = AccumBatchNormalization(accum_steps=4)
Adaptive Gradient Clipping | model = GradientAccumulateModel(accum_steps=4, agc=True, inputs=model.input, outputs=model.output)
Mixed precision | model = GradientAccumulateModel(accum_steps=4, mixed_precision=True, inputs=model.input, outputs=model.output)
  • As batch normalization (BN) is not natively compatible with GA, we have implemented a custom BN layer that can be used as a drop-in replacement.
  • Support for adaptive gradient clipping (AGC) has been added as an alternative to BN.
  • Mixed precision can also be used on both GPUs and TPUs.
  • Multi-GPU distributed training is supported through the generic optimizer wrapper.
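One way to see why standard BN clashes with GA: each micro-batch is normalized with its own statistics, computed from only a fraction of the full batch. The NumPy sketch below uses hypothetical activations for a single channel and shows that accumulating running sums across micro-batches recovers the full-batch mean and (biased) variance. It illustrates the problem only and is not how AccumBatchNormalization is implemented; see the documentation for that.

```python
import numpy as np

# Hypothetical activations of one feature channel: a batch of 32 split into
# 4 micro-batches of 8, as with accum_steps=4.
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=32)
micro_batches = np.split(x, 4)

# Plain BN would normalize each micro-batch with its own mean, which
# generally differs from the full-batch mean.
print([round(float(mb.mean()), 3) for mb in micro_batches], round(float(x.mean()), 3))

# Accumulating count, sum and sum of squares across micro-batches recovers
# the full-batch statistics.
n = s = sq = 0.0
for mb in micro_batches:
    n += mb.size
    s += mb.sum()
    sq += (mb ** 2).sum()
mean = s / n
var = sq / n - mean ** 2  # E[x^2] - E[x]^2; can lose precision for large means

print(np.isclose(mean, x.mean()), np.isclose(var, x.var()))  # True True
```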

For more information on usage, supported techniques, and examples, refer to the documentation.
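Adaptive gradient clipping, listed above as an alternative to BN, was introduced by Brock et al. (2021) for normalizer-free networks: each unit's gradient is rescaled whenever its norm exceeds `clip_factor` times the norm of the corresponding weights. The NumPy sketch below follows that published formula; the unit-wise axis convention (one unit per output column of a kernel shaped `(in, out)`) and the values `clip_factor=0.01`, `eps=1e-3` are assumptions for illustration, not necessarily what `agc=True` uses in this package.

```python
import numpy as np


def unitwise_norm(x):
    """L2 norm per output unit (last axis); 0-d/1-d tensors count as one unit."""
    if x.ndim <= 1:
        return np.sqrt(np.sum(x ** 2, keepdims=True))
    return np.sqrt(np.sum(x ** 2, axis=tuple(range(x.ndim - 1)), keepdims=True))


def adaptive_clip(grad, weight, clip_factor=0.01, eps=1e-3):
    """Rescale each unit's gradient so ||g|| / max(||w||, eps) <= clip_factor."""
    w_norm = np.maximum(unitwise_norm(weight), eps)
    g_norm = unitwise_norm(grad)
    max_norm = clip_factor * w_norm
    # Clip only where the ratio is exceeded; guard against zero gradients.
    scale = np.where(g_norm > max_norm, max_norm / np.maximum(g_norm, 1e-6), 1.0)
    return grad * scale


W = np.ones((3, 2))                       # column norms: sqrt(3)
G = np.array([[1.0, 0.001]] * 3)          # column 0 large, column 1 small
G_clipped = adaptive_clip(G, W)
print(G_clipped)  # column 0 rescaled to 0.01, column 1 left at 0.001
```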

Acknowledgements

The gradient accumulator model wrapper is based on the implementation presented in this thread on Stack Overflow. The adaptive gradient clipping method is based on the implementation by @sayakpaul. The optimizer wrapper is derived from the implementations by @fsx950223 and @stefan-falk.

The documentation hosted here was made possible by the incredible Read the Docs team, who offer free documentation hosting!

How to cite?

If you used this package or found the project relevant in your research, please include the following citation:

@software{andre_pedersen_2023_7905351,
  author       = {André Pedersen and Tor-Arne Schmidt Nordmo and Javier Pérez de Frutos and David Bouget},
  title        = {andreped/GradientAccumulator: v0.5.0},
  month        = may,
  year         = 2023,
  publisher    = {Zenodo},
  version      = {v0.5.0},
  doi          = {10.5281/zenodo.7905351},
  url          = {https://doi.org/10.5281/zenodo.7905351}
}
