Momentum Residual Neural Networks
This repository hosts Python code for Momentum ResNets.
See the documentation, our ICML 2021 paper and a 5-minute presentation.
Model
Official library for using Momentum Residual Neural Networks [1]. These models extend any residual architecture (for instance, it also works with Transformers) to a larger class of deep learning models that consume less memory. They can be initialized with the same weights as a pretrained ResNet and are promising in fine-tuning applications.
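For reference, here is the forward pass of a Momentum ResNet as we read it in [1]. A velocity v is carried alongside the hidden state x, f(·, θ_n) is the n-th residual function and γ is the momentum term (`gamma` in the code below). The last two lines invert one layer: for γ > 0, the input (x_n, v_n) of each layer can be recovered exactly from its output (x_{n+1}, v_{n+1}).

```latex
\begin{aligned}
v_{n+1} &= \gamma\, v_n + (1 - \gamma)\, f(x_n, \theta_n), \\
x_{n+1} &= x_n + v_{n+1}, \\[4pt]
x_n &= x_{n+1} - v_{n+1}, \\
v_n &= \frac{1}{\gamma}\Bigl(v_{n+1} - (1 - \gamma)\, f(x_n, \theta_n)\Bigr).
\end{aligned}
```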
Installation
pip
To install momentumnet, you first need to install its dependencies:
$ pip install numpy matplotlib torch
Then install momentumnet with pip:
$ pip install momentumnet
or to get the latest version of the code:
$ pip install git+https://github.com/michaelsdr/momentumnet.git#egg=momentumnet
If you do not have admin privileges on the computer, use the --user flag with pip. To upgrade, use the --upgrade flag provided by pip.
Check
To check that everything worked, run:
$ python -c 'import momentumnet'
It should complete without any error message.
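If you also want to see which version pip installed, the standard library's `importlib.metadata` (Python 3.8+) can read it from the installed package metadata. This is a small sketch; `installed_version` is a hypothetical helper name, not part of momentumnet.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints a version string if momentumnet is installed, otherwise None.
print(installed_version("momentumnet"))
```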
Quickstart
The main class is MomentumNet. It creates a Momentum ResNet whose forward equations can be reversed in closed form, so activations can be recomputed during the backward pass instead of being stored. This enables learning without standard, memory-consuming backpropagation and trades memory for computation.
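The closed-form reversal can be illustrated without the library. Below is a minimal pure-Python sketch, assuming the update rule from [1] (v ← γv + (1 − γ)f(x), then x ← x + v) and using hypothetical scalar functions in place of neural network layers; `momentum_forward` and `momentum_inverse` are illustrative names, not momentumnet's API. It shows that the inputs of the network can be recomputed from its outputs alone.

```python
import math
from typing import Callable, List, Tuple

# A residual function maps the hidden state to an update; here a scalar
# stand-in for a neural network layer f(., theta_n).
Residual = Callable[[float], float]


def momentum_forward(
    x: float, v: float, fs: List[Residual], gamma: float
) -> Tuple[float, float]:
    """Apply v <- gamma * v + (1 - gamma) * f(x), then x <- x + v, for each f."""
    for f in fs:
        v = gamma * v + (1.0 - gamma) * f(x)
        x = x + v
    return x, v


def momentum_inverse(
    x: float, v: float, fs: List[Residual], gamma: float
) -> Tuple[float, float]:
    """Recover the inputs of momentum_forward from its outputs, last layer first."""
    if gamma <= 0.0:
        raise ValueError("gamma must be positive to recover the velocity")
    for f in reversed(fs):
        x = x - v  # x_n = x_{n+1} - v_{n+1}
        v = (v - (1.0 - gamma) * f(x)) / gamma  # v_n uses x_n computed just above
    return x, v


# Hypothetical toy layers: tanh(a * x) for a few fixed scalars a.
fs = [lambda x, a=a: math.tanh(a * x) for a in (0.5, -1.0, 2.0)]

x_out, v_out = momentum_forward(1.0, 0.0, fs, gamma=0.9)
# Only (x_out, v_out) are kept; the intermediate states are recomputed.
x_in, v_in = momentum_inverse(x_out, v_out, fs, gamma=0.9)
print(x_in, v_in)  # close to the inputs 1.0 and 0.0, up to floating-point error
```

Because the inverse divides by γ at every layer, it requires γ > 0, and rounding errors are amplified more strongly when γ is small.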
To get started, you can create a toy Momentum ResNet by specifying the functions f for the forward pass and the value of the momentum term, gamma.
>>> from torch import nn
>>> from momentumnet import MomentumNet
>>> hidden = 8
>>> d = 500
>>> function = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))
>>> mresnet = MomentumNet([function,] * 10, gamma=0.9)
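In this toy example, `[function,] * 10` puts the same `function` object in the list ten times. Assuming MomentumNet uses the modules as given, all ten residual functions therefore share one set of weights; to get ten independent layers, build one module per layer instead. The sketch below shows this standard Python list behavior with a hypothetical stand-in class `Block` in place of an `nn.Module`, so it runs without PyTorch; `copy.deepcopy` is one way to obtain independent copies.

```python
import copy


class Block:
    """Hypothetical stand-in for an nn.Module holding a single weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight


block = Block(weight=1.0)

shared = [block] * 10  # ten references to the same object
independent = [copy.deepcopy(block) for _ in range(10)]  # ten separate copies

shared[0].weight = 5.0  # visible through every entry of `shared`
independent[0].weight = 5.0  # changes only the first copy

print(all(b is block for b in shared))  # True
print(sum(b.weight == 5.0 for b in independent))  # 1
```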
Momentum ResNets are a drop-in replacement for ResNets
We can transform a ResNet into a MomentumNet with the same parameters in two lines of code. For instance, the following code instantiates a Momentum ResNet with the weights of a ResNet-101 pretrained on ImageNet. We set use_backprop=False so that activations are not saved during the forward pass, which reduces memory consumption.
>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> from torchvision.models import resnet101
>>> resnet = resnet101(pretrained=True)
>>> mresnet101 = transform_to_momentumnet(resnet, gamma=0.9, use_backprop=False)
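One way to see why a ResNet's weights can be reused: in the update rule of [1], v ← γv + (1 − γ)f(x) followed by x ← x + v, setting γ = 0 gives x ← x + f(x), the ordinary residual update. The same residual functions therefore define both networks, and γ controls how far the momentum version departs from the original. The minimal pure-Python check below uses hypothetical scalar functions; it illustrates the equations only and does not show how `transform_to_momentumnet` handles weights internally.

```python
import math
from typing import Callable, List

Residual = Callable[[float], float]


def resnet_forward(x: float, fs: List[Residual]) -> float:
    # Standard residual update: x <- x + f(x)
    for f in fs:
        x = x + f(x)
    return x


def momentum_resnet_forward(
    x: float, fs: List[Residual], gamma: float, v: float = 0.0
) -> float:
    # Momentum update as in [1]: v <- gamma * v + (1 - gamma) * f(x); x <- x + v
    for f in fs:
        v = gamma * v + (1.0 - gamma) * f(x)
        x = x + v
    return x


# Hypothetical toy residual functions.
fs = [math.sin, math.tanh, lambda x: 0.5 * x]

# With gamma = 0 the velocity equals f(x) exactly, so both outputs coincide.
print(resnet_forward(0.3, fs) == momentum_resnet_forward(0.3, fs, gamma=0.0))  # True
```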
This method also works with PyTorch's Transformer module: use sub_layers to specify which residual layers should be turned into their momentum version.
>>> import torch
>>> from momentumnet import transform_to_momentumnet
>>> transformer = torch.nn.Transformer(num_encoder_layers=6, num_decoder_layers=6)
>>> mtransformer = transform_to_momentumnet(transformer, sub_layers=["encoder.layers", "decoder.layers"], gamma=0.9,
...                                         use_backprop=False, keep_first_layer=False)
This instantiates a Momentum Transformer with the same weights as the original Transformer.
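Entries of sub_layers such as "encoder.layers" read as dotted attribute paths into the model (for torch.nn.Transformer, `transformer.encoder.layers` holds the encoder's layer stack). Before applying `transform_to_momentumnet` to another architecture, it can help to check that each path resolves. The helpers below are a hypothetical sketch, not part of momentumnet; they use `types.SimpleNamespace` as a stand-in model so the example runs without PyTorch. Recent PyTorch versions also provide `nn.Module.get_submodule` for the same lookup.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Iterable, List


def resolve_path(model: Any, path: str) -> Any:
    """Follow a dotted attribute path such as 'encoder.layers' starting from model."""
    return reduce(getattr, path.split("."), model)


def missing_paths(model: Any, paths: Iterable[str]) -> List[str]:
    """Return the dotted paths that do not resolve on model."""
    missing = []
    for path in paths:
        try:
            resolve_path(model, path)
        except AttributeError:
            missing.append(path)
    return missing


# Stand-in with the same attribute layout as nn.Transformer (encoder.layers, decoder.layers).
toy_model = SimpleNamespace(
    encoder=SimpleNamespace(layers=["enc0", "enc1"]),
    decoder=SimpleNamespace(layers=["dec0", "dec1"]),
)

print(resolve_path(toy_model, "encoder.layers"))  # ['enc0', 'enc1']
print(missing_paths(toy_model, ["encoder.layers", "decoder.blocks"]))  # ['decoder.blocks']
```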
Memory savings when applying Momentum ResNets to Transformers
Here is a short tutorial showing the memory gains when using Momentum Transformers.
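As a rough illustration of where the savings come from, here is a back-of-the-envelope sketch under loudly stated assumptions; it is not a measurement and not taken from the tutorial. Standard backpropagation keeps the hidden state of every layer, so activation memory grows linearly with depth, while reversing the momentum equations only requires keeping the final x and v. `activation_bytes` is a hypothetical helper, and parameters, optimizer state and per-layer temporaries are ignored.

```python
from typing import Tuple


def activation_bytes(
    depth: int, batch: int, width: int, bytes_per_value: int = 4
) -> Tuple[int, int]:
    """Rough activation memory, in bytes, for `depth` residual layers.

    Simplifying assumptions (not measurements):
    - every hidden state is a batch x width tensor of float32 values (4 bytes);
    - standard backprop stores the input of every layer plus the final output;
    - reversing the momentum equations stores only the final x and v.
    """
    per_tensor = batch * width * bytes_per_value
    with_backprop = (depth + 1) * per_tensor
    with_reversal = 2 * per_tensor  # final position x and velocity v
    return with_backprop, with_reversal


print(activation_bytes(depth=100, batch=64, width=500))  # (12928000, 256000)
```

Under these assumptions the backprop cost grows with depth while the reversal cost does not; actual savings depend on the architecture and are best measured directly, as the tutorial does.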
Dependencies
These are the dependencies to use momentumnet:
numpy (>=1.8)
matplotlib (>=1.3)
torch (>=1.9)
memory_profiler
vit_pytorch
Cite
If you use this code in your project, please cite:
[1] Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyré. Momentum Residual Neural Networks. Proceedings of the 38th International Conference on Machine Learning (ICML 2021), PMLR 139:9276-9287. https://arxiv.org/abs/2102.07870