
Adam Layer-wise LR Decay

Project description

ELECTRA, published by Stanford University and Google Brain, used a layer-wise LR decay technique with the Adam optimizer to prevent catastrophic forgetting of the pre-trained model.

This repo contains an implementation of layer-wise LR decay for Adam, built on the new Optimizer API introduced in TensorFlow 2.11.

Usage

Installation:

$ pip install adam-lr-decay  # this method does not install tensorflow

For CPU:

$ pip install adam-lr-decay[cpu]  # this method installs tensorflow-cpu>=2.11

For GPU:

$ pip install adam-lr-decay[gpu]  # this method installs tensorflow>=2.11

Then use AdamLRDecay as a drop-in Keras optimizer:

from tensorflow.keras import layers, models
from adam_lr_decay import AdamLRDecay

# ... prepare training data

# model definition
model = models.Sequential([
    layers.Dense(3, input_shape=(2,), name='hidden_dense'),
    layers.Dense(1, name='output')
])

# optimizer definition with layerwise lr decay
adam = AdamLRDecay(learning_rate=1e-3)
adam.apply_layerwise_lr_decay(var_name_dicts={
    'hidden_dense': 0.1,
    'output': 0.
})
# this config decays each matching layer by its value,
# i.e. the effective learning rate becomes lr * (1. - decay_rate)

# compile the model (the loss here is only an example)
model.compile(optimizer=adam, loss='mse')

# ... training loop
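
For reference, here is a minimal sketch of actually running the training loop and the per-layer rates the config above implies (the numpy toy data is illustrative only, not part of the package API):

import numpy as np

# hypothetical toy data matching the (2,) -> (1,) model above
x = np.random.rand(32, 2)
y = np.random.rand(32, 1)
model.fit(x, y, epochs=1)

# effective per-layer learning rates implied by the decay config:
#   hidden_dense: 1e-3 * (1. - 0.1) = 9e-4
#   output:       1e-3 * (1. - 0.0) = 1e-3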

In the official ELECTRA repo, the per-layer decay rates are defined in code. The adapted version is as follows:

import collections
from adam_lr_decay import AdamLRDecay

def _get_layer_lrs(layer_decay, n_layers):
    key_to_depths = collections.OrderedDict({
        '/embeddings/': 0,
        '/embeddings_project/': 0,
        'task_specific/': n_layers + 2,
    })
    for layer in range(n_layers):
        key_to_depths['encoder/layer_' + str(layer) + '/'] = layer + 1
    return {
        key: 1. - (layer_decay ** (n_layers + 2 - depth))
        for key, depth in key_to_depths.items()
    }

# ... ELECTRA model definition

adam = AdamLRDecay(learning_rate=1e-3)
adam.apply_layerwise_lr_decay(var_name_dicts=_get_layer_lrs(0.9, 8))

# ... custom training loop

The generated decay rates should look like the following: 0.0 means no decay, and 1.0 means a zero learning rate (i.e. the layer is effectively non-trainable).

{
  "/embeddings/": 0.6513215599,
  "/embeddings_project/": 0.6513215599, 
  "task_specific/": 0.0, 
  "encoder/layer_0/": 0.6125795109999999, 
  "encoder/layer_1/": 0.5695327899999999, 
  "encoder/layer_2/": 0.5217030999999999, 
  "encoder/layer_3/": 0.46855899999999995, 
  "encoder/layer_4/": 0.40950999999999993, 
  "encoder/layer_5/": 0.3439, 
  "encoder/layer_6/": 0.2709999999999999, 
  "encoder/layer_7/": 0.18999999999999995
}
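
These values follow directly from the 1. - layer_decay ** (n_layers + 2 - depth) formula above. A quick sanity check in plain Python (no TensorFlow needed; numbers assume layer_decay=0.9 and n_layers=8):

layer_decay, n_layers = 0.9, 8

# depth 0 (embeddings): the shallowest parameters get the strongest decay
print(1. - layer_decay ** (n_layers + 2 - 0))  # 0.6513215599...
# depth 1 (encoder/layer_0)
print(1. - layer_decay ** (n_layers + 2 - 1))  # 0.6125795109...
# depth n_layers + 2 (task_specific): no decay
print(1. - layer_decay ** 0)                   # 0.0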

Citation

@article{clark2020electra,
  title={Electra: Pre-training text encoders as discriminators rather than generators},
  author={Clark, Kevin and Luong, Minh-Thang and Le, Quoc V and Manning, Christopher D},
  journal={arXiv preprint arXiv:2003.10555},
  year={2020}
}

Download files

Download the file for your platform.

Source Distribution

adam_lr_decay-0.0.7.tar.gz (4.7 kB)

Uploaded Source

Built Distribution

adam_lr_decay-0.0.7-py3-none-any.whl (5.3 kB)

Uploaded Python 3

File details

Details for the file adam_lr_decay-0.0.7.tar.gz.

File metadata

  • Download URL: adam_lr_decay-0.0.7.tar.gz
  • Upload date:
  • Size: 4.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.10.12 Linux/5.15.0-1042-azure

File hashes

Hashes for adam_lr_decay-0.0.7.tar.gz
  • SHA256: fa07dcc5b19a309f2d5b61b8859a7479a809e7278c6650bc584b7c30b3148a31
  • MD5: ec4662756a119a3914f76fbe47553b48
  • BLAKE2b-256: 5132b208a9f84747ed7c76c4c38e4a67724fc3041d2904c6629760fb4eec206f


File details

Details for the file adam_lr_decay-0.0.7-py3-none-any.whl.

File metadata

  • Download URL: adam_lr_decay-0.0.7-py3-none-any.whl
  • Upload date:
  • Size: 5.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.10.12 Linux/5.15.0-1042-azure

File hashes

Hashes for adam_lr_decay-0.0.7-py3-none-any.whl
  • SHA256: 555a829fc6c00f09b8b14eacde1b89ce8dfdd843379b3a13c97ba1d226a3347f
  • MD5: e41dde84cab0d6255d20b6a14b1e1bee
  • BLAKE2b-256: a8c0a4580b05ded054baa84ee90a085f4f485a4e2e25133bdd6b53f8ea5f325c

