Momentum Residual Neural Networks
This repository hosts Python code for Momentum ResNets.
See the documentation and our ICML 2021 paper.
Installation
pip
To install momentumnet, you first need to install its dependencies:
$ pip install numpy matplotlib numexpr scipy
Then install momentumnet with pip:
$ pip install momentumnet
or to get the latest version of the code:
$ pip install git+https://github.com/michaelsdr/momentumnet.git#egg=momentumnet
If you do not have admin privileges on the computer, use the --user flag with pip. To upgrade, use the --upgrade flag provided by pip.
Check
To check if everything worked fine, you can do:
$ python -c 'import momentumnet'
and it should not give any error message.
Quickstart
The main class is MomentumNet. It creates a Momentum ResNet whose forward equations can be inverted in closed form, so that activations do not have to be stored for backpropagation. This trades memory for computation.
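The closed-form reversal can be sketched in plain NumPy. This is an illustrative toy reimplementation, not the library's code: with the momentum update v ← γv + (1 − γ)f(x), x ← x + v, the previous activations are recovered exactly (up to floating-point error), so they need not be stored during the forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
d, gamma = 4, 0.9
W = rng.standard_normal((d, d))
f = lambda x: np.tanh(W @ x)  # toy residual function

def forward(x, v):
    # Momentum residual update.
    v = gamma * v + (1 - gamma) * f(x)
    x = x + v
    return x, v

def inverse(x, v):
    # Exact reversal of the update above, in the opposite order.
    x = x - v
    v = (v - (1 - gamma) * f(x)) / gamma
    return x, v

x0, v0 = rng.standard_normal(d), np.zeros(d)
x1, v1 = forward(x0, v0)
x0_rec, v0_rec = inverse(x1, v1)
print(np.allclose(x0, x0_rec), np.allclose(v0, v0_rec))  # True True
```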
To get started, you can create a toy Momentum ResNet by specifying the functions f for the forward pass and the value of the momentum term, gamma.
>>> from torch import nn
>>> from momentumnet import MomentumNet
>>> hidden = 8
>>> d = 500
>>> function = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))
>>> mresnet = MomentumNet([function,] * 10, gamma=0.99)
Momentum ResNets are a drop-in replacement for ResNets
We can transform a ResNet into a MomentumNet with the same parameters in two lines of code. For instance, the following code instantiates a Momentum ResNet with the weights of a ResNet-18 pretrained on ImageNet. We set use_backprop to False so that activations are not saved during the forward pass, reducing memory consumption.
>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> from torchvision.models import resnet18
>>> resnet = resnet18(pretrained=True)
>>> mresnet18 = transform_to_momentumnet(resnet, gamma=0.99, use_backprop=False)
Importantly, this method also works with PyTorch's Transformer module, by specifying which residual layers to turn into their momentum versions.
>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> transformer = torch.nn.Transformer(num_encoder_layers=6, num_decoder_layers=6)
>>> mtransformer = transform_to_momentumnet(transformer, residual_layers=["encoder.layers", "decoder.layers"],
...                                         gamma=0.99, use_backprop=False, keep_first_layer=False)
This instantiates a Momentum Transformer with the same weights as the original Transformer.
Dependencies
These are the dependencies to use momentumnet:
numpy (>=1.8)
matplotlib (>=1.3)
torch (>= 1.9)
memory_profiler
vit_pytorch
Cite
If you use this code in your project, please cite:
Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyré. Momentum Residual Neural Networks. Proceedings of the 38th International Conference on Machine Learning, PMLR 139:9276-9287, 2021. https://arxiv.org/abs/2102.07870