
Implement Gradient Centralization in TensorFlow


Gradient Centralization TensorFlow


This Python package implements Gradient Centralization (GC) in TensorFlow. GC is a simple and effective optimization technique for Deep Neural Networks (DNNs) proposed by Yong et al. in the paper Gradient Centralization: A New Optimization Technique for Deep Neural Networks. According to the paper, it can both speed up the training process and improve the final generalization performance of DNNs.
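For reference, here is the centralization operation as we understand it from the paper (a summary in our own notation, not a quotation). For the weight vector $\mathbf{w}_i \in \mathbb{R}^M$ of output unit $i$ (a column of a fully connected layer's weight matrix, or an unfolded convolution filter), GC subtracts the mean of the gradient from each of its $M$ components:

```latex
\Phi_{GC}\left(\nabla_{\mathbf{w}_i}\mathcal{L}\right)
  = \nabla_{\mathbf{w}_i}\mathcal{L} - \mu_{\nabla_{\mathbf{w}_i}\mathcal{L}},
\qquad
\mu_{\nabla_{\mathbf{w}_i}\mathcal{L}} = \frac{1}{M}\sum_{j=1}^{M} \nabla_{w_{i,j}}\mathcal{L}
```

By construction, each centralized gradient vector sums to zero.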

Installation

Run the following to install:

pip install gradient-centralization-tf

Usage

gctf.centralized_gradients_for_optimizer

Creates a centralized gradients function for the specified optimizer.

Arguments:

  • optimizer: a tf.keras.optimizers.Optimizer object. The optimizer you are using.

Example:

>>> import tensorflow as tf
>>> import gctf
>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
>>> opt.get_gradients = gctf.centralized_gradients_for_optimizer(opt)
>>> model.compile(optimizer=opt, ...)
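The example above works by replacing the optimizer's get_gradients method with a function that returns centralized gradients. The pure-Python/NumPy sketch below illustrates that patching pattern under stated assumptions: `ToyOptimizer`, its `grad_fn` argument and `make_centralized_get_gradients` are hypothetical stand-ins (not part of gctf or Keras), gradients are NumPy arrays rather than TensorFlow tensors, and the last axis is assumed to be the output axis. It does not reproduce gctf's actual implementation.

```python
import numpy as np


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer exposing get_gradients(loss, params)."""

    def __init__(self, grad_fn):
        # grad_fn(loss, param) -> gradient array with the same shape as param
        self.grad_fn = grad_fn

    def get_gradients(self, loss, params):
        return [self.grad_fn(loss, p) for p in params]


def make_centralized_get_gradients(optimizer):
    """Hypothetical factory: wrap optimizer.get_gradients so its output is centralized."""
    original = optimizer.get_gradients  # bound method captured before patching

    def get_gradients(loss, params):
        grads = []
        for g in original(loss, params):
            if g.ndim > 1:
                # Subtract the mean over every axis except the last (output) axis.
                g = g - g.mean(axis=tuple(range(g.ndim - 1)), keepdims=True)
            grads.append(g)  # rank-1 gradients (e.g. biases) pass through unchanged
        return grads

    return get_gradients


# Toy loss 0.5 * ||w||^2, whose gradient with respect to w is w itself.
opt = ToyOptimizer(grad_fn=lambda loss, p: p.copy())
opt.get_gradients = make_centralized_get_gradients(opt)

kernel = np.array([[1.0, 2.0], [3.0, 6.0]])
bias = np.array([1.0, -1.0])
grads = opt.get_gradients(loss=None, params=[kernel, bias])
```

Because the original method is captured before the attribute is overwritten, the wrapper can still delegate to it; this is the same reason the example above passes opt into the factory before assigning the result back to opt.get_gradients.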

gctf.get_centralized_gradients

Computes the centralized gradients.

This function is generally not meant to be called directly. If you are building a custom optimizer, you can point the optimizer's get_gradients to this function. It is a modified version of tf.keras.optimizers.Optimizer.get_gradients.

Arguments:

  • optimizer: a tf.keras.optimizers.Optimizer object. The optimizer you are using.
  • loss: Scalar tensor to minimize.
  • params: List of variables.

Returns:

A list of centralized gradient tensors, one for each variable in params.
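As a rough illustration of what is computed for each gradient, here is a NumPy sketch of the per-tensor rule. It assumes TensorFlow's kernel layouts, where the output axis is last (`(input units, output units)` for dense kernels and `(kh, kw, in channels, filters)` for 2-D convolution kernels); `centralize_gradient` is a hypothetical helper, not a gctf function.

```python
import numpy as np


def centralize_gradient(grad):
    """Hypothetical helper: apply GC to one gradient array.

    Gradients of rank > 1 (kernels) have their mean over every axis except
    the last, output axis subtracted; rank-1 gradients such as biases are
    returned unchanged.
    """
    if grad.ndim > 1:
        reduce_axes = tuple(range(grad.ndim - 1))
        return grad - grad.mean(axis=reduce_axes, keepdims=True)
    return grad


rng = np.random.default_rng(0)
dense_grad = rng.normal(size=(4, 3))       # (input units, output units)
conv_grad = rng.normal(size=(3, 3, 2, 5))  # (kh, kw, in channels, filters)
bias_grad = rng.normal(size=(5,))

gc_dense = centralize_gradient(dense_grad)
gc_conv = centralize_gradient(conv_grad)
gc_bias = centralize_gradient(bias_grad)
```

After centralization, the gradient of every output unit (or convolution filter) has zero mean, while shapes are preserved.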

gctf.optimizers

Pre-built optimizers updated to implement GC.

This module is specially built for testing out GC. In most cases you would use gctf.centralized_gradients_for_optimizer instead; this module is itself implemented with gctf.centralized_gradients_for_optimizer. All optimizers in tf.keras.optimizers are available here, updated for GC, and can be used directly.

Example:

>>> model.compile(optimizer=gctf.optimizers.adam(learning_rate=0.01), ...)
>>> model.compile(optimizer=gctf.optimizers.rmsprop(learning_rate=0.01, rho=0.91), ...)
>>> model.compile(optimizer=gctf.optimizers.sgd(), ...)

Returns:

A tf.keras.optimizers.Optimizer object.
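To illustrate what a GC-enabled optimizer changes, here is a NumPy sketch of plain SGD with and without GC. The names `sgd_step` and `centralize_gradient` are hypothetical, and random arrays stand in for real loss gradients. Since centralized gradients have zero column mean, plain SGD with GC keeps each column mean of the weight matrix at its initial value; the paper describes this as constraining the weights to a hyperplane. The property is specific to the plain SGD update shown here: adaptive optimizers such as Adam rescale gradient components individually, so it does not carry over exactly.

```python
import numpy as np


def centralize_gradient(grad):
    # GC for a 2-D kernel: remove the per-column (per-output-unit) mean.
    return grad - grad.mean(axis=0, keepdims=True)


def sgd_step(weights, grad, learning_rate=0.1, use_gc=True):
    """Hypothetical plain SGD update, optionally with gradient centralization."""
    if use_gc:
        grad = centralize_gradient(grad)
    return weights - learning_rate * grad


rng = np.random.default_rng(42)
w0 = rng.normal(size=(4, 3))
w_gc, w_plain = w0.copy(), w0.copy()
for _ in range(10):
    g = rng.normal(size=w0.shape)  # stand-in for a real loss gradient
    w_gc = sgd_step(w_gc, g, use_gc=True)
    w_plain = sgd_step(w_plain, g, use_gc=False)

# Column means of w_gc match those of w0 up to floating-point rounding;
# the column means of w_plain drift with the accumulated gradient means.
gc_mean_drift = np.abs(w_gc.mean(axis=0) - w0.mean(axis=0)).max()
plain_mean_drift = np.abs(w_plain.mean(axis=0) - w0.mean(axis=0)).max()
```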

Developing gctf

To install gradient-centralization-tf along with the tools you need to develop and test it, run the following in your virtualenv:

git clone git@github.com:Rishit-dagli/Gradient-Centralization-TensorFlow
# or clone your own fork
cd Gradient-Centralization-TensorFlow

pip install -e .[dev]

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

