NFNets, PyTorch
PyTorch implementation of Normalizer-Free Networks and SGD - Adaptive Gradient Clipping
Paper: https://arxiv.org/abs/2102.06171
Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Do star this repository if it helps your work!
Note: See this comment for a generic implementation for any optimizer as a temporary reference for anyone who needs it.
Installation
Install from PyPI:
pip3 install nfnets-pytorch
or install the latest code using:
pip3 install git+https://github.com/vballoli/nfnets-pytorch
Usage
WSConv2d
Use WSConv2d like any other torch.nn.Conv2d.
import torch
from torch import nn
from nfnets import WSConv2d
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
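For readers who want to see what a weight-standardized convolution does internally, here is a minimal sketch. It is not the code of nfnets.WSConv2d: the class name SketchWSConv2d, the eps default, the per-channel gain, and the exact scaling by fan-in are assumptions chosen for illustration, and padding_mode is ignored for brevity.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SketchWSConv2d(nn.Conv2d):
    """Hypothetical illustration of a weight-standardized Conv2d.

    Not the nfnets.WSConv2d implementation; names and constants are assumptions.
    """

    def __init__(self, *args, eps=1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        # One learnable gain per output channel, initialised to 1 (assumption).
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))

    def standardized_weight(self):
        w = self.weight
        fan_in = w[0].numel()  # in_channels / groups * kH * kW
        # Statistics are taken per output channel, over the fan-in dimensions.
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        # Zero mean and variance of roughly 1 / fan_in per output channel.
        scale = torch.rsqrt(var * fan_in + self.eps)
        return (w - mean) * scale * self.gain

    def forward(self, x):
        # Same convolution as nn.Conv2d, but with the standardized weight.
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


layer = SketchWSConv2d(3, 6, 3)
out = layer(torch.randn(1, 3, 8, 8))
```

The raw weight stays the trainable parameter; standardization is recomputed on every forward pass, so gradients flow through the mean and variance terms as well.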
SGD - Adaptive Gradient Clipping
Similarly, use SGD_AGC like torch.optim.SGD.
import torch
from torch import nn, optim
from nfnets import WSConv2d, SGD_AGC
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
sgd = optim.SGD(conv.parameters(), 1e-3)
sgd_agc = SGD_AGC(conv.parameters(), 1e-3)
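The clipping rule behind adaptive gradient clipping can be sketched in a few lines. This is a minimal illustration, not the code inside SGD_AGC: the function names unitwise_norm and clip_gradient_adaptively, and the defaults clipping=0.01 and eps=1e-3, are assumptions. Each output unit's gradient is rescaled whenever its norm exceeds `clipping` times the norm of the matching weights.

```python
import torch


def unitwise_norm(t: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit (dim 0); whole-tensor norm for biases and scalars."""
    if t.ndim <= 1:
        return t.norm(p=2)
    return t.norm(p=2, dim=tuple(range(1, t.ndim)), keepdim=True)


def clip_gradient_adaptively(param, grad, clipping=0.01, eps=1e-3):
    """Return a copy of `grad`, clipped unit-wise relative to `param` (hypothetical helper)."""
    # eps keeps zero-initialised weights from forcing their gradients to zero.
    w_norm = unitwise_norm(param).clamp(min=eps)
    g_norm = unitwise_norm(grad)
    max_norm = clipping * w_norm
    # Rescale only the units whose gradient norm exceeds the allowed maximum.
    scaled = grad * (max_norm / g_norm.clamp(min=1e-6))
    return torch.where(g_norm > max_norm, scaled, grad)


weight = torch.tensor([[3.0, 4.0], [0.3, 0.4]])    # unit norms: 5.0 and 0.5
grad = torch.tensor([[0.003, 0.004], [3.0, 4.0]])  # unit norms: 0.005 and 5.0
clipped = clip_gradient_adaptively(weight, grad)
```

In this example the first row is left untouched because its gradient norm is below 0.01 times the weight norm, while the second row is scaled down to that limit.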
Generic AGC
import torch
from torch import nn, optim
from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
sgd = optim.SGD(conv.parameters(), 1e-3)
optim_agc = AGC(conv.parameters(), sgd) # Needs testing
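The idea of a generic wrapper can also be approximated without any wrapper class, by clipping gradients in place just before calling step() on whichever optimizer is in use. The helper below is a rough sketch under that assumption; agc_ is a hypothetical name, not part of nfnets, and its defaults are illustrative.

```python
import torch
from torch import nn, optim


def _unit_norm(t):
    # Per-output-unit L2 norm; whole-tensor norm for 0-D and 1-D parameters.
    if t.ndim <= 1:
        return t.norm(p=2)
    return t.norm(p=2, dim=tuple(range(1, t.ndim)), keepdim=True)


@torch.no_grad()
def agc_(parameters, clipping=0.01, eps=1e-3):
    """Clip the .grad of every parameter in place, unit-wise (hypothetical helper)."""
    for p in parameters:
        if p.grad is None:
            continue
        max_norm = clipping * _unit_norm(p).clamp(min=eps)
        g_norm = _unit_norm(p.grad).clamp(min=1e-6)
        # A factor below 1 shrinks oversized gradients; others are left unchanged.
        p.grad.mul_((max_norm / g_norm).clamp(max=1.0))


model = nn.Conv2d(3, 6, 3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # any optimizer works here

optimizer.zero_grad()
loss = model(torch.randn(2, 3, 8, 8)).pow(2).sum()
loss.backward()
agc_(model.parameters())  # clip first
optimizer.step()          # then update
```

Keeping the clipping separate from the optimizer means the same call works with SGD, Adam, or a custom optimizer, at the cost of one extra line in the training loop.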
Using it within any PyTorch model
import torch
from torch import nn
from torchvision.models import resnet18
from nfnets import replace_conv
model = resnet18()
replace_conv(model)
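A recursive replacement of this kind can be sketched as follows. This is not the package's replace_conv: swap_convs and MarkerConv2d are hypothetical names, and MarkerConv2d is only a stand-in for a weight-standardized layer so that the example runs with PyTorch alone.

```python
import torch
from torch import nn


class MarkerConv2d(nn.Conv2d):
    """Hypothetical stand-in for a weight-standardized convolution."""


def swap_convs(module, new_cls=MarkerConv2d):
    """Recursively replace plain nn.Conv2d children with `new_cls`, keeping weights."""
    # list(...) takes a snapshot so children can be reassigned while looping.
    for name, child in list(module.named_children()):
        if type(child) is nn.Conv2d:
            new = new_cls(child.in_channels, child.out_channels, child.kernel_size,
                          stride=child.stride, padding=child.padding,
                          dilation=child.dilation, groups=child.groups,
                          bias=child.bias is not None,
                          padding_mode=child.padding_mode)
            # strict=False tolerates extra parameters (e.g. a gain) in new_cls.
            new.load_state_dict(child.state_dict(), strict=False)
            setattr(module, name, new)
        else:
            swap_convs(child, new_cls)
    return module


model = swap_convs(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
out = model(torch.randn(1, 3, 8, 8))
```

The exact type check skips subclasses of nn.Conv2d, so layers that were already replaced are not replaced a second time.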
Docs
Find the docs at readthedocs
TODO
- WSConv2d
- SGD - Adaptive Gradient Clipping
- Function to automatically replace Convolutions in any module with WSConv2d
- Documentation
- Generic AGC wrapper (see this comment for a reference implementation; needs testing for now)
- WSConvTranspose2d
- NFNets
- NF-ResNets
Cite Original Work
To cite the original paper, use:
@article{brock2021high,
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
title={High-Performance Large-Scale Image Recognition Without Normalization},
journal={arXiv preprint arXiv:},
year={2021}
}