
NFNets, PyTorch


PyTorch implementation of Normalizer-Free Networks and Adaptive Gradient Clipping


Paper: https://arxiv.org/abs/2102.06171

Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

Blog post: https://tourdeml.github.io/blog/posts/2021-03-31-adaptive-gradient-clipping/. Feel free to subscribe to the newsletter, and leave a comment if you have anything to add/suggest publicly.

Do star this repository if it helps your work, and don't forget to cite if you use this code in your research!

Installation

Install from PyPI:

pip3 install nfnets-pytorch

or install the latest code using:

pip3 install git+https://github.com/vballoli/nfnets-pytorch

Usage

WSConv2d

Use WSConv1d, WSConv2d, ScaledStdConv2d (from timm), and WSConvTranspose2d as drop-in replacements for the corresponding torch.nn convolution modules, such as torch.nn.Conv2d and torch.nn.ConvTranspose2d.

import torch
from torch import nn
from nfnets import WSConv2d, WSConvTranspose2d, ScaledStdConv2d

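# Standard convolution and its weight-standardized counterpart
conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Standard transposed convolution and its weight-standardized counterpart
conv_t = nn.ConvTranspose2d(3, 6, 3)
w_conv_t = WSConvTranspose2d(3, 6, 3)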
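
For intuition about what these layers do, here is a minimal sketch of weight standardization as described in the NFNets paper, written in plain PyTorch. SketchWSConv2d and standardized_weight are hypothetical names used only for illustration; this is not the library's WSConv2d, which may include details (such as a learnable gain) that the sketch omits.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SketchWSConv2d(nn.Conv2d):
    """Hypothetical illustration of weight standardization; not the nfnets class."""

    def standardized_weight(self, eps: float = 1e-4) -> torch.Tensor:
        w = self.weight  # shape: (out_channels, in_channels / groups, kH, kW)
        fan_in = w[0].numel()
        # Statistics are taken per output channel, over the fan-in dimensions.
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = ((w - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
        # Result has zero mean and variance close to 1 / fan_in per channel.
        return (w - mean) / torch.sqrt(var * fan_in + eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Standardize on every forward pass, then convolve as usual.
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


layer = SketchWSConv2d(3, 6, 3)
out = layer(torch.randn(2, 3, 8, 8))  # shape: (2, 6, 6, 6)
```

Because the standardization happens in forward, the stored parameter stays unconstrained and the optimizer updates it as it would for any other convolution.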

Generic AGC (recommended)

import torch
from torch import nn, optim
from torchvision.models import resnet18

from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing

conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Wrap a torch optimizer with AGC. Distinct names avoid shadowing the `optim` module.
sgd = optim.SGD(conv.parameters(), 1e-3)
optim_agc = AGC(conv.parameters(), sgd) # Needs testing

# Ignore the fc layer of a model while applying AGC.
model = resnet18()
model_sgd = optim.SGD(model.parameters(), 1e-3)
model_optim = AGC(model.parameters(), model_sgd, model=model, ignore_agc=['fc'])
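
The clipping rule that AGC applies can be sketched in plain PyTorch. This is an illustration of the unit-wise rule from the paper, not the library's implementation: unitwise_norm and agc_clip_ are hypothetical names, and the default values for clipping and eps are assumptions chosen for the example. A gradient unit is rescaled only when its norm exceeds clipping times the norm of the matching weight unit.

```python
import torch


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit; whole-tensor norm for biases and gains."""
    if x.ndim <= 1:
        return x.pow(2).sum().sqrt()
    dims = tuple(range(1, x.ndim))  # every dimension except the output one
    return x.pow(2).sum(dim=dims, keepdim=True).sqrt()


@torch.no_grad()
def agc_clip_(params, clipping: float = 1e-2, eps: float = 1e-3) -> None:
    """Hypothetical helper: clip gradients in place, unit by unit."""
    for p in params:
        if p.grad is None:
            continue
        # Largest allowed gradient norm; eps keeps zero-initialized weights trainable.
        max_norm = unitwise_norm(p).clamp(min=eps) * clipping
        grad_norm = unitwise_norm(p.grad)
        scale = max_norm / grad_norm.clamp(min=1e-6)
        # Rescale only the units whose gradient norm exceeds the limit.
        p.grad.mul_(torch.where(grad_norm > max_norm, scale, torch.ones_like(scale)))
```

In a training step, a helper like this would run after loss.backward() and before the optimizer's step(). To leave the final fc layer unclipped, as the paper recommends, pass only the other parameters to it.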

SGD - Adaptive Gradient Clipping

Similarly, use SGD_AGC as you would use torch.optim.SGD.

# The generic AGC is preferable since the paper recommends not applying AGC to the last fc layer.
import torch
from torch import nn, optim
from nfnets import WSConv2d, SGD_AGC

conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Plain SGD and SGD with AGC built in. Distinct names avoid shadowing the `optim` module.
sgd = optim.SGD(conv.parameters(), 1e-3)
optim_agc = SGD_AGC(conv.parameters(), 1e-3)

Using it within any PyTorch model

replace_conv replaces every convolution in your model with the convolution class you pass in, and replaces every batch normalization layer with an identity. The leftover identity modules are not ideal, but they should not noticeably affect latency.

import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import replace_conv, WSConv2d, ScaledStdConv2d

model = resnet18()
# Pick one replacement class:
replace_conv(model, WSConv2d) # This repo's original implementation
# replace_conv(model, ScaledStdConv2d) # From timm

"""
class YourCustomClass(nn.Conv2d):
  ...
replace_conv(model, YourCustomClass)
"""

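The kind of in-place module swap described above can be sketched as a recursive walk over a model's children. This is a rough illustration under stated assumptions, not the library's replace_conv: swap_layers and MyConv are hypothetical names, and the sketch creates fresh layers without copying the original weights.

```python
import torch
from torch import nn


def swap_layers(module: nn.Module, conv_cls=nn.Conv2d) -> nn.Module:
    """Hypothetical helper: swap Conv2d for conv_cls and BatchNorm2d for Identity."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d):
            # Rebuild the layer with the same hyperparameters (weights are re-initialized).
            new_conv = conv_cls(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding,
                dilation=child.dilation, groups=child.groups,
                bias=child.bias is not None,
            )
            setattr(module, name, new_conv)
        elif isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.Identity())
        else:
            swap_layers(child, conv_cls)  # descend into containers and blocks
    return module


class MyConv(nn.Conv2d):
    """Stand-in for a custom convolution class."""


net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 4, 1)),
)
swap_layers(net, MyConv)
```

Replacing batch normalization with nn.Identity keeps the model's attribute names and forward code unchanged, which is why the swap can be applied to an existing architecture without editing its source.
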
Docs

Find the docs on Read the Docs.

Cite Original Work

To cite the original paper, use:

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}

Cite this repository

To cite this repository, use:

@misc{nfnets2021pytorch,
  author = {Vaibhav Balloli},
  title = {A PyTorch implementation of NFNets and Adaptive Gradient Clipping},
  year = {2021},
  howpublished = {\url{https://github.com/vballoli/nfnets-pytorch}}
}
