A PyTorch framework for developing memory-efficient deep invertible networks.
- Free software: MIT license (please cite our work if you use it)
- Documentation: https://memcnn.readthedocs.io
- Installation: https://memcnn.readthedocs.io/en/latest/installation.html
Features
- Enable memory savings during training by wrapping arbitrary invertible PyTorch functions with the InvertibleModuleWrapper class.
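- Simple toggling of memory savings through the keep_input property of the InvertibleModuleWrapper: with keep_input=False the input is discarded after the forward pass and reconstructed during the backward pass.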
- Turn arbitrary non-linear PyTorch functions into invertible versions using the AdditiveCoupling or the AffineCoupling classes.
- Training and evaluation code for reproducing RevNet experiments using MemCNN.
- CI tests for Python v3.7 and torch v1.0, v1.1, v1.4 and v1.7 with good code coverage.
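To make the coupling idea concrete: an additive coupling splits its input into halves x1 and x2 and computes y1 = x1 + Fm(x2), y2 = x2 + Gm(y1) (the RevNet scheme; to our understanding this is what AdditiveCoupling does, splitting over the channel dimension). The inverse only evaluates Fm and Gm, it never inverts them: x2 = y2 - Gm(y1), x1 = y1 - Fm(x2). The pure-Python sketch below uses plain lists instead of tensors, and F, G, S and T are hypothetical stand-ins for sub-networks. It also shows a generic RealNVP-style affine coupling, which is an illustration of the technique and not necessarily identical to memcnn's AffineCoupling.

```python
import math

# Hypothetical stand-ins for arbitrary sub-networks; none of them is invertible.
def F(v):
    return [max(0.0, 2.0 * a + 1.0) for a in v]  # ReLU(2a + 1)

def G(v):
    return [math.tanh(a) ** 2 for a in v]

def S(v):  # log-scale network used by the affine coupling
    return [0.5 * math.sin(a) for a in v]

def T(v):  # shift network used by the affine coupling
    return [a * a for a in v]

def split(x):
    # Split an even-length input into two equal halves.
    h = len(x) // 2
    return x[:h], x[h:]

def additive_forward(x):
    x1, x2 = split(x)
    y1 = [a + f for a, f in zip(x1, F(x2))]
    y2 = [b + g for b, g in zip(x2, G(y1))]
    return y1 + y2

def additive_inverse(y):
    y1, y2 = split(y)
    # Undo the two steps in reverse order; F and G are evaluated, never inverted.
    x2 = [b - g for b, g in zip(y2, G(y1))]
    x1 = [a - f for a, f in zip(y1, F(x2))]
    return x1 + x2

def affine_forward(x):
    # Generic affine coupling: x2 passes through unchanged and conditions a
    # scale exp(S(x2)), which is never zero, and a shift T(x2) applied to x1.
    x1, x2 = split(x)
    y1 = [a * math.exp(s) + t for a, s, t in zip(x1, S(x2), T(x2))]
    return y1 + x2

def affine_inverse(y):
    y1, y2 = split(y)
    x1 = [(a - t) * math.exp(-s) for a, s, t in zip(y1, S(y2), T(y2))]
    return x1 + y2

x = [0.3, -1.2, 2.0, 0.7]
print(additive_inverse(additive_forward(x)))  # approximately x again
print(affine_inverse(affine_forward(x)))      # approximately x again
```

The exp in the affine scale is the design choice that guarantees invertibility: the scale is strictly positive, so dividing it out is always possible.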
Examples
Creating an AdditiveCoupling with memory savings
```python
import torch
import torch.nn as nn

import memcnn


# define a new torch Module with a sequence of operations: ReLU o BatchNorm2d o Conv2d
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.seq(x)


# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()
Y = model_normal(X)

# turn the ExampleOperation invertible using an additive coupling;
# the coupling splits the 10 input channels into two halves of 5,
# so Fm and Gm each operate on 10 // 2 channels
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)

# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)

# wrap our invertible_module using the InvertibleModuleWrapper;
# keep_input=True retains the input after the forward pass, while keep_input=False
# (during training) discards it and reconstructs it in the backward pass,
# which is where the memory savings come from
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module,
                                                           keep_input=True,
                                                           keep_input_inverse=True)

# modules start in training mode; switch to evaluation mode
# note that this is required to pass input tensors to the model with requires_grad=False (inference only)
invertible_module_wrapper.eval()

# test that the wrapped module is also a valid invertible module
assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)

# compute the forward pass using the wrapper (calling the module invokes its forward method)
Y2 = invertible_module_wrapper(X)

# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2
X2 = invertible_module_wrapper.inverse(Y2)

# test that the input and approximation are similar
assert torch.allclose(X, X2, atol=1e-06)
```
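The memory savings of the InvertibleModuleWrapper rest on one property: because each layer is invertible, its input does not have to be stored for the backward pass, since it can be recomputed from its output. The following is a minimal pure-Python sketch of that idea, assuming a chain of additive couplings on a two-element state with hypothetical sub-networks f and g. It is not memcnn's implementation, which performs the reconstruction on tensors inside autograd.

```python
import math

# Hypothetical sub-networks; any deterministic functions work here.
def f(b):
    return 0.5 * max(0.0, b) + 0.1

def g(a):
    return math.sin(a)

def layer_forward(state):
    # One additive coupling on a two-element state (x1, x2).
    x1, x2 = state
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return (y1, y2)

def layer_inverse(state):
    y1, y2 = state
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return (x1, x2)

def forward_storing_all(x, depth):
    """Conventional training: keep every layer input, so memory grows with depth."""
    acts = [x]
    for _ in range(depth):
        acts.append(layer_forward(acts[-1]))
    return acts

def forward_output_only(x, depth):
    """Keep only the current activation; earlier ones are discarded."""
    for _ in range(depth):
        x = layer_forward(x)
    return x

def rebuild_inputs(y, depth):
    """Recover all layer inputs from the final output, last layer first,
    which is the order in which a backward pass consumes them."""
    acts = [y]
    for _ in range(depth):
        acts.append(layer_inverse(acts[-1]))
    return acts[::-1]

depth = 8
x0 = (0.4, -0.9)
stored = forward_storing_all(x0, depth)
rebuilt = rebuild_inputs(forward_output_only(x0, depth), depth)
```

The trade-off is extra compute: every layer is evaluated once more (in inverse) during the backward pass, in exchange for storing only a single activation instead of one per layer.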
Run PyTorch Experiments
After installing MemCNN run:
python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda]
- Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110 and revnet164.
- Available values for DATASET are cifar10 and cifar100.
- Use the --fresh flag to remove earlier experiment results.
- Use the --no-cuda flag to train on the CPU rather than the GPU through CUDA.
Datasets are automatically downloaded if they are not available.
When using Python 3.*, replace the python command with the appropriate Python 3 interpreter. For example, when using the MemCNN Docker image, use python3.6.
If MemCNN was installed using pip or from source, you might need to set up a configuration file before running this command. See the corresponding section of the installation guide: https://memcnn.readthedocs.io/en/latest/installation.html
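To reproduce several experiments in one go, the MODEL and DATASET values listed above can be swept from a small Python script. This is a sketch that assumes only the command-line interface described here; build_commands and run_all are hypothetical helper names. Using sys.executable launches the same interpreter that runs the script, which sidesteps the question of python versus python3.x.

```python
import shlex
import subprocess
import sys

# MODEL and DATASET values as listed in the MemCNN README.
MODELS = ("resnet32", "resnet110", "resnet164", "revnet38", "revnet110", "revnet164")
DATASETS = ("cifar10", "cifar100")


def build_commands(models=MODELS, datasets=DATASETS, fresh=False, no_cuda=False):
    """Return one argv list per (model, dataset) pair for memcnn.train."""
    commands = []
    for dataset in datasets:
        for model in models:
            # Positional order follows the documented usage: MODEL, then DATASET.
            cmd = [sys.executable, "-m", "memcnn.train", model, dataset]
            if fresh:
                cmd.append("--fresh")
            if no_cuda:
                cmd.append("--no-cuda")
            commands.append(cmd)
    return commands


def run_all(commands):
    """Run the experiments one after another; check=True stops at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    commands = build_commands(models=("revnet38",), datasets=("cifar10",), no_cuda=True)
    for cmd in commands:
        print(" ".join(shlex.quote(part) for part in cmd))
    # run_all(commands)  # uncomment to actually start training
```

Running the experiments sequentially keeps GPU memory usage predictable; launching them in parallel would require managing devices yourself.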
Download files
Filename | Size | File type | Python version |
---|---|---|---|
memcnn-1.5.0-py2.py3-none-any.whl | 50.2 kB | Wheel | py2.py3 |
memcnn-1.5.0.tar.gz | 45.9 kB | Source | None |
Hashes for memcnn-1.5.0-py2.py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | a9d3150bdf5941ebc3ff2f74ca786c19d81c4c57e222b2d3ad153427f2192c85 |
MD5 | 6390b8707212d569af1329d220a0471a |
BLAKE2-256 | 496bdf4aad8829ead84ebdd04f1383897e7f0208c8908da73ba1a5a6a1e53e0d |
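A downloaded wheel can be checked against these digests before it is installed. Below is a minimal sketch using Python's standard hashlib module. The helper names are hypothetical, and BLAKE2-256 is taken to mean BLAKE2b with a 32-byte digest (what PyPI calls BLAKE2b-256).

```python
import hashlib

# SHA256 digest of memcnn-1.5.0-py2.py3-none-any.whl, copied from the table.
EXPECTED_WHEEL_SHA256 = "a9d3150bdf5941ebc3ff2f74ca786c19d81c4c57e222b2d3ad153427f2192c85"


def digests_of_chunks(chunks):
    """Feed byte chunks to all three hash functions and return their hex digests."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        # Assumption: BLAKE2-256 means BLAKE2b with a 32-byte (256-bit) digest.
        "BLAKE2-256": hashlib.blake2b(digest_size=32),
    }
    for chunk in chunks:
        for hasher in hashers.values():
            hasher.update(chunk)
    return {name: hasher.hexdigest() for name, hasher in hashers.items()}


def file_digests(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so it never has to fit in memory at once."""
    with open(path, "rb") as f:
        return digests_of_chunks(iter(lambda: f.read(chunk_size), b""))


def matches(path, expected, algorithm="SHA256"):
    """Return True if the file's digest equals the expected hex string."""
    return file_digests(path)[algorithm] == expected.strip().lower()
```

For example, matches("memcnn-1.5.0-py2.py3-none-any.whl", EXPECTED_WHEEL_SHA256) should return True for an intact download. pip can enforce the same check at install time through its hash-checking mode (a requirements line of the form memcnn==1.5.0 --hash=sha256:<digest>); in that mode pip requires every requirement, including dependencies, to be pinned and hashed.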