Smart gradient clippers
Project description
AutoClip
PyTorch and TensorFlow implementations (and variations) of the AutoClip adaptive gradient clipping procedure from Seetharaman et al.
Prem Seetharaman, Gordon Wichern, Bryan Pardo, Jonathan Le Roux. "AutoClip: Adaptive Gradient Clipping for Source Separation Networks." 2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP). IEEE, 2020.
About
While training your model, AutoClip keeps a running history of your model's gradient magnitudes. Using this history, the clipper can adaptively clamp outlier gradient values before they reach the optimizer of your choice.
While AutoClip is great as a preventative measure against exploding gradients, it can also shorten training time and encourage the optimizer to find better optima. At an intuitive level, AutoClip compensates for the stochastic nature of training over mini-batches, regularizing the resulting updates.
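To make the idea concrete, here is a minimal sketch of quantile-based adaptive clipping in PyTorch. It is only an illustration of the concept, not autoclip's internal implementation; the grad_history list and the autoclip_step helper are hypothetical names.
import torch

# Illustrative sketch only -- not autoclip's internals.
grad_history = []  # running history of observed gradient norms

def autoclip_step(model, quantile=0.9, history_length=1000):
    # Total gradient norm across all parameters for this step.
    norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack(norms))
    # Update the (bounded) history of observed norms.
    grad_history.append(total_norm.item())
    del grad_history[:-history_length]
    # Clip this step's gradients to the chosen quantile of the history.
    clip_value = torch.quantile(torch.tensor(grad_history), quantile).item()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
After loss.backward(), calling autoclip_step(model) and then your optimizer's step() would apply this kind of adaptive clipping by hand.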
Installation
AutoClip is listed on PyPI. To install AutoClip, simply run the following command:
pip install autoclip
and the autoclip package will be installed in your currently active environment.
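To verify that the install worked, the clippers used in the examples below should be importable from a Python session:
# Quick sanity check: these imports should succeed after installation.
from autoclip.torch import QuantileClip      # PyTorch clipper
# from autoclip.tf import QuantileClip       # TensorFlow clipper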
Torch API
Below are some examples of how to use autoclip's torch API.
Clippers as Optimizer Wrappers
Using the optimizer wrapping pattern is the recommended way to use AutoClip, and autoclip's torch API supports wrapping arbitrary PyTorch optimizers. The wrapping pattern lets you avoid changing your training code when you want to add an AutoClip clipper. This is especially useful if you do not own the training code for whatever reason (say, for example, you are using someone else's Trainer class, as is often the case with frameworks like huggingface).
The following is an example of how to integrate AutoClip into your model training using this pattern:
import torch
from autoclip.torch import QuantileClip
model = torch.nn.Sequential(
    torch.nn.Linear(100, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 2),
)
optimizer = torch.optim.AdamW(model.parameters())
optimizer = QuantileClip.as_optimizer(optimizer=optimizer, quantile=0.9, history_length=1000)
Now you can use the optimizer just like you would have before adding the clipper, and the clipping will be applied automatically.
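For illustration, a standard training loop then needs no changes at all; loss_function and training_dataset below are placeholder names for your own loss and data loader:
# The loop is unchanged from ordinary training; the wrapped optimizer
# applies clipping automatically when it steps.
for batch in training_dataset:
    optimizer.zero_grad()
    prediction = model(batch['data'])
    loss = loss_function(prediction, batch['targets'])
    loss.backward()
    optimizer.step()  # gradients are clipped before the parameter update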
Raw AutoClip Clippers
You can still use the clipper manually if you prefer. In that case, create your clipper like this:
import torch
from autoclip.torch import QuantileClip
model = torch.nn.Sequential(
    torch.nn.Linear(100, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 2),
)
clipper = QuantileClip(model.parameters(), quantile=0.9, history_length=1000)
Then, to clip the model's gradients, simply call the clipper's .step() method during your training loop. Note that you should call the clipper's step() before you call your optimizer's step(). Calling it afterwards would mean that the clipping has no effect, since the model will already have been updated using the unclipped gradients. For example:
for batch_num, batch in enumerate(training_dataset):
    optimizer.zero_grad()  # reset gradients from the previous batch
    model_prediction = model(batch['data'])
    loss = loss_function(model_prediction, batch['targets'])
    loss.backward()
    clipper.step()     # clipper comes before optimizer
    optimizer.step()
Global vs Local Clipping
autoclip's torch clippers support two clipping modes. The first is global_clipping, which is the original AutoClip as described in Seetharaman et al. The second is local, or parameter-wise, clipping: in this mode a separate history is kept for every parameter, and each is clipped according to its own history. By default, the autoclip clippers use parameter-wise clipping.
To use the global mode, simply pass the appropriate flag:
clipper = QuantileClip(
    model.parameters(),
    quantile=0.9,
    history_length=1000,
    global_clipping=True,
)
Checkpointing
The torch clippers also support checkpointing through state_dict() and load_state_dict(), just like torch models and optimizers. For example, if you want to checkpoint a clipper to clipper.pth:
clipper = QuantileClip(model.parameters())
torch.save(clipper.state_dict(), 'clipper.pth')
# Then later
clipper = QuantileClip(model.parameters())
clipper.load_state_dict(torch.load('clipper.pth'))
Keep in mind that, just like a torch optimizer, this will raise an error if you give the clipper differently sized model parameters.
While it is generally recommended to use state_dicts instead (see the PyTorch documentation on this subject for more info), you may also use torch.save and torch.load directly to pickle the entire clipper object.
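As a brief sketch of that alternative (the filename here is just an example):
# Pickle the entire clipper object (less portable than a state_dict).
torch.save(clipper, 'clipper_full.pth')
# Later, there is no need to construct the clipper first:
clipper = torch.load('clipper_full.pth')  # newer torch versions may require weights_only=False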
TensorFlow
autoclip's TensorFlow API does not currently have feature parity with the torch API (if you would like to change this, feel free to contribute).
As it stands, the TensorFlow API only supports the original AutoClip algorithm and does not support checkpointing. Below is a short example:
import tensorflow as tf
from autoclip.tf import QuantileClip
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Dense(50),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(
            2,
            activation=tf.keras.activations.tanh,
        ),
    ]
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=0.001,
        gradient_transformers=[
            QuantileClip(
                quantile=0.9,
                history_length=1000,
            )
        ],
    ),
    loss="mean_absolute_error",
    metrics=["accuracy"],
)
model.fit(train_data, train_targets)
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
autoclip-0.2.1.tar.gz (11.1 kB)
Built Distribution
autoclip-0.2.1-py3-none-any.whl (11.2 kB)
File details
Details for the file autoclip-0.2.1.tar.gz.
File metadata
- Download URL: autoclip-0.2.1.tar.gz
- Upload date:
- Size: 11.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 85cd478ff54bb41658757916b59f352912744a8078075ed1278367d047d27de3
MD5 | 97aa491ecb0ff02f5c18e4851155b76d
BLAKE2b-256 | b2c455cad7944883a4813b0f8048a5146beef2bf0f28cca660fd46173fe84900
File details
Details for the file autoclip-0.2.1-py3-none-any.whl.
File metadata
- Download URL: autoclip-0.2.1-py3-none-any.whl
- Upload date:
- Size: 11.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 33682d198d0190bd6fcee1e9ad71d776ad2c6c3552b8c60a55b65d8ce6fbe4ce
MD5 | 0dec9fcca7c1adf05fcb38e12b6d32ec
BLAKE2b-256 | 85117c85521c95f04fbcefbd11bf1fbd84bb74cff8b60ecd722a1e2d09d1e71b