NFNets, PyTorch
PyTorch implementation of Normalizer-Free Networks and Adaptive Gradient Clipping
Paper: https://arxiv.org/abs/2102.06171
Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Blog post: https://tourdeml.github.io/blog/posts/2021-03-31-adaptive-gradient-clipping/. Feel free to subscribe to the newsletter and to leave a comment if you have anything to add or suggest publicly.
Do star this repository if it helps your work, and don't forget to cite if you use this code in your research!
Installation
Install from PyPI:
pip3 install nfnets-pytorch
or install the latest code using:
pip3 install git+https://github.com/vballoli/nfnets-pytorch
Usage
WSConv2d
Use WSConv1d, WSConv2d, ScaledStdConv2d (from timm) and WSConvTranspose2d like any other torch.nn.Conv2d or torch.nn.ConvTranspose2d module.
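import torch
from torch import nn
from nfnets import WSConv2d, WSConvTranspose2d, ScaledStdConv2d
# Drop-in replacements: same constructor arguments as the torch.nn layers
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
s_conv = ScaledStdConv2d(3,6,3)  # timm's variant
conv_t = nn.ConvTranspose2d(3,6,3)
w_conv_t = WSConvTranspose2d(3,6,3)
# Forward passes work as usual, e.g. on a 1x3x32x32 input:
x = torch.randn(1, 3, 32, 32)
y = w_conv(x)      # shape (1, 6, 30, 30)
y_t = w_conv_t(x)  # shape (1, 6, 34, 34)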
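For intuition about what a weight-standardized convolution computes, here is a minimal sketch of the Scaled Weight Standardization used by NF-ResNets and NFNets, where each output channel's weights are normalized as Ŵ = (W − μ)/(σ√N) with N the fan-in. The class name `ScaledWSConv2dSketch`, the learnable per-channel `gain` and the `eps` floor are assumptions made for illustration; this is not the source of the package's `WSConv2d`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ScaledWSConv2dSketch(nn.Conv2d):
    """Hypothetical Conv2d with Scaled Weight Standardization (sketch only).

    Assumes padding_mode='zeros'; the gain and eps handling are assumptions.
    """

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        # Learnable per-output-channel gain, initialised to 1 (assumption).
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight
        fan_in = w[0].numel()  # N = in_channels / groups * kh * kw
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = ((w - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
        # W_hat = gain * (W - mean) / sqrt(max(N * var, eps))
        return self.gain * (w - mean) / torch.sqrt(torch.clamp(var * fan_in, min=self.eps))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same convolution as nn.Conv2d, but with the standardized weight.
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


conv = ScaledWSConv2dSketch(3, 6, 3)
y = conv(torch.randn(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 6, 30, 30])
```

Before the gain is applied, every output channel of the standardized weight has zero mean and unit L2 norm, so the scale of the activations no longer depends on the raw weight magnitude.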
Generic AGC (recommended)
import torch
from torch import nn
from torchvision.models import resnet18
from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
# Wrap any torch optimizer with AGC. Naming the optimizer `sgd`
# avoids shadowing the torch.optim module.
sgd = torch.optim.SGD(conv.parameters(), 1e-3)
sgd_agc = AGC(conv.parameters(), sgd) # Needs testing
# Skip AGC for the model's final fc layer.
model = resnet18()
sgd = torch.optim.SGD(model.parameters(), 1e-3)
sgd_agc = AGC(model.parameters(), sgd, model=model, ignore_agc=['fc'])
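The paper's unit-wise rule rescales the gradient G_i of each output unit whenever ||G_i|| / max(||W_i||, ε) exceeds the clipping threshold λ, so that its norm becomes λ·max(||W_i||, ε). A minimal sketch of that rule in plain PyTorch, with a skip list in the spirit of `ignore_agc`, could look like the following. `unitwise_norm`, `adaptive_clip_grad_` and its `skip` parameter are hypothetical names and the default values are illustrative; this is not the code behind `nfnets.agc.AGC`.

```python
from collections import OrderedDict

import torch
from torch import nn


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit (dim 0); whole-tensor norm for biases/scalars."""
    if x.ndim <= 1:
        return x.norm(2)
    return x.norm(2, dim=tuple(range(1, x.ndim)), keepdim=True)


@torch.no_grad()
def adaptive_clip_grad_(named_params, clipping=0.01, eps=1e-3, skip=("fc",)):
    """Hypothetical helper: clip gradients in place so that each unit obeys
    ||G_i|| <= clipping * max(||W_i||, eps). Parameters belonging to a module
    whose name is in `skip` are left untouched."""
    for name, p in named_params:
        module_name = name.rpartition(".")[0]  # "fc.weight" -> "fc"
        if p.grad is None or module_name in skip:
            continue
        max_norm = unitwise_norm(p).clamp(min=eps) * clipping
        grad_norm = unitwise_norm(p.grad).clamp(min=1e-6)
        # Factor is < 1 only for units whose gradient exceeds its budget.
        p.grad.mul_((max_norm / grad_norm).clamp(max=1.0))


model = nn.Sequential(OrderedDict([
    ("conv", nn.Conv2d(3, 8, 3)),
    ("pool", nn.AdaptiveAvgPool2d(1)),
    ("flat", nn.Flatten()),
    ("fc", nn.Linear(8, 10)),
]))
loss = model(torch.randn(4, 3, 16, 16)).square().sum() * 1e4
loss.backward()
fc_grad_before = model.fc.weight.grad.clone()  # kept to show fc is skipped
adaptive_clip_grad_(model.named_parameters(), clipping=0.01, skip=("fc",))
# ...then call optimizer.step() as usual.
```

Clipping relative to the weight norm, rather than with a fixed global threshold, is what lets the allowed update size grow with each unit's weights.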
SGD - Adaptive Gradient Clipping
Similarly, use SGD_AGC like torch.optim.SGD.
# The generic AGC is preferable, since the paper recommends not applying AGC to the final fc layer.
import torch
from torch import nn
from nfnets import WSConv2d, SGD_AGC
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
# `sgd` instead of `optim` avoids shadowing the torch.optim module
sgd = torch.optim.SGD(conv.parameters(), 1e-3)
# Drop-in replacement for torch.optim.SGD
sgd_agc = SGD_AGC(conv.parameters(), 1e-3)
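SGD_AGC folds the clipping into the optimizer itself. One way to sketch that idea is a `torch.optim.SGD` subclass that clips every gradient unit-wise and then delegates to the normal SGD update. `SGDWithAGCSketch` and its `clipping`/`eps` arguments are hypothetical, not the package's `SGD_AGC`. Note that this sketch clips all parameters, including a final fc layer, which the paper recommends against.

```python
import torch
from torch import nn


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    # Per-output-unit L2 norm; whole-tensor norm for biases and scalars.
    if x.ndim <= 1:
        return x.norm(2)
    return x.norm(2, dim=tuple(range(1, x.ndim)), keepdim=True)


class SGDWithAGCSketch(torch.optim.SGD):
    """Hypothetical SGD subclass: unit-wise AGC on every gradient, then SGD."""

    def __init__(self, params, lr, clipping=0.01, eps=1e-3, **sgd_kwargs):
        super().__init__(params, lr=lr, **sgd_kwargs)
        self.clipping = clipping
        self.agc_eps = eps

    def step(self, closure=None):
        loss = None
        if closure is not None:
            # Recompute gradients first so the clipping sees fresh values.
            with torch.enable_grad():
                loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    max_norm = unitwise_norm(p).clamp(min=self.agc_eps) * self.clipping
                    grad_norm = unitwise_norm(p.grad).clamp(min=1e-6)
                    p.grad.mul_((max_norm / grad_norm).clamp(max=1.0))
        super().step()  # plain SGD update on the clipped gradients
        return loss


torch.manual_seed(0)
layer = nn.Linear(4, 3)
opt = SGDWithAGCSketch(layer.parameters(), lr=1.0, clipping=0.01)
w_before = layer.weight.detach().clone()
loss = (100 * layer(torch.randn(8, 4))).square().mean()
opt.zero_grad()
loss.backward()
opt.step()
```

With momentum and weight decay left at zero, each weight row moves by at most lr · clipping · max(||W_i||, eps), no matter how large the raw gradient was.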
Using it within any PyTorch model
replace_conv replaces the convolutions in your model with the given convolution class and replaces the batchnorm layers with identity modules. Leaving identity modules in the model is not ideal, but they should not noticeably affect latency.
import torch
from torch import nn
from torchvision.models import resnet18
from nfnets import replace_conv, WSConv2d, ScaledStdConv2d
model = resnet18()
replace_conv(model, WSConv2d) # This repo's original implementation
# Alternatively, use timm's convolution on a fresh model
model = resnet18()
replace_conv(model, ScaledStdConv2d) # From timm
# Or pass your own nn.Conv2d subclass:
# class YourCustomClass(nn.Conv2d):
#     ...
# replace_conv(model, YourCustomClass)
Docs
Find the docs at readthedocs
Cite Original Work
To cite the original paper, use:
@article{brock2021high,
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
title={High-Performance Large-Scale Image Recognition Without Normalization},
journal={arXiv preprint arXiv:},
year={2021}
}
Cite this repository
To cite this repository, use:
@misc{nfnets2021pytorch,
author = {Vaibhav Balloli},
title = {A PyTorch implementation of NFNets and Adaptive Gradient Clipping},
year = {2021},
howpublished = {\url{https://github.com/vballoli/nfnets-pytorch}}
}
File details
Details for the file nfnets-pytorch-0.1.3.tar.gz.
File metadata
- Download URL: nfnets-pytorch-0.1.3.tar.gz
- Upload date:
- Size: 16.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.9.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4a5a692305621a40db46abb65ecb22e64492b57577dbd519022b16e238681a1b
MD5 | b6abe497a3bdec61c3c5eedc333d070e
BLAKE2b-256 | b5d734580b61d743e87c8fdaadce5939e95d3323c12c43c57a20f97cd14e4d7f
File details
Details for the file nfnets_pytorch-0.1.3-py3-none-any.whl.
File metadata
- Download URL: nfnets_pytorch-0.1.3-py3-none-any.whl
- Upload date:
- Size: 16.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.9.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | bfd9023642fd70c96fbfcae7615e89d1f60ac6872830125c18129d36b668a9c1
MD5 | 76188139cf60101a8c39545fd6679414
BLAKE2b-256 | 727a4fc40176d0026750fd22a8d23e27a8637372452c77cb81329455a2c4d31e