
A PyTorch framework for developing memory-efficient deep invertible networks.

Project description


A PyTorch framework for developing memory-efficient invertible neural networks.

Features

  • Enable memory savings during training by wrapping arbitrary invertible PyTorch functions with the InvertibleModuleWrapper class.

  • Simple toggling of memory saving by setting the keep_input property of the InvertibleModuleWrapper (see the sketch after this list).

  • Turn arbitrary non-linear PyTorch functions into invertible versions using the AdditiveCoupling or the AffineCoupling classes.

  • Training and evaluation code for reproducing RevNet experiments using MemCNN.

  • CI tests for Python v3.7 and torch v1.0, v1.1, v1.4 and v1.7 with good code coverage.
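
As noted above, toggling memory saving does not require rebuilding the model. A minimal sketch (the 1x1 convolutions are illustrative placeholders, not part of MemCNN):

import torch.nn as nn
import memcnn

# wrap a toy additive coupling; Fm and Gm each operate on half of the input channels
coupling = memcnn.AdditiveCoupling(
    Fm=nn.Conv2d(5, 5, kernel_size=(1, 1)),
    Gm=nn.Conv2d(5, 5, kernel_size=(1, 1))
)
wrapper = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=False)

# memory saving on: inputs are discarded after the forward pass and
# recomputed from the outputs during the backward pass
wrapper.keep_input = False

# memory saving off: inputs are retained, as in a regular PyTorch module
wrapper.keep_input = True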

Examples

Creating an AdditiveCoupling with memory savings

import torch
import torch.nn as nn
import memcnn


# define a new torch Module with a sequence of operations: ReLU o BatchNorm2d o Conv2d
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.seq(x)


# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()

Y = model_normal(X)

# make the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)

# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)

# wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)

# by default the module is set to training, the following sets this to evaluation
# note that this is required to pass input tensors to the model with requires_grad=False (inference only)
invertible_module_wrapper.eval()

# test that the wrapped module is also a valid invertible module
assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)

# compute the forward pass using the wrapper
Y2 = invertible_module_wrapper(X)

# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2
X2 = invertible_module_wrapper.inverse(Y2)

# test that the input and approximation are similar
assert torch.allclose(X, X2, atol=1e-06)
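
The example above runs in evaluation mode with keep_input=True, so no memory is actually saved. A minimal sketch of the memory-saving training configuration, reusing ExampleOperation from above (the sum loss is a placeholder):

# rebuild the coupling and wrap it with input discarding enabled
coupling = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)
wrapper = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=False, keep_input_inverse=False)
wrapper.train()

# in training mode, input tensors must be passed with requires_grad=True (see the note above)
X_train = torch.rand(2, 10, 8, 8, requires_grad=True)
Y_train = wrapper(X_train)

# the storage of X_train is freed after the forward pass and recomputed
# from Y_train during the backward pass
loss = Y_train.sum()
loss.backward()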

Run PyTorch Experiments

After installing MemCNN, run:

python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda]

  • Available values for DATASET are cifar10 and cifar100.

  • Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110, and revnet164.

  • Use the --fresh flag to remove earlier experiment results.

  • Use the --no-cuda flag to train on the CPU rather than the GPU through CUDA.
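
For example, the following trains a 38-layer RevNet on cifar10 from scratch on the GPU:

python -m memcnn.train revnet38 cifar10 --fresh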

Datasets are automatically downloaded if they are not available.

When using Python 3, replace the python command with the appropriate Python 3 command; for example, use python3.6 when working inside the MemCNN Docker image.

If MemCNN was installed using pip or from source, you might need to set up a configuration file before running this command. See the corresponding section of the documentation: https://memcnn.readthedocs.io/en/latest/installation.html

Download files

Download the file for your platform.

Source Distribution

memcnn-1.5.2.tar.gz (50.6 kB)

Built Distribution

memcnn-1.5.2-py2.py3-none-any.whl (50.3 kB)

File details

Details for the file memcnn-1.5.2.tar.gz.

File metadata

  • Download URL: memcnn-1.5.2.tar.gz
  • Size: 50.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.4

File hashes

Hashes for memcnn-1.5.2.tar.gz

  • SHA256: 5a090f58867fbe0463d81c7ef38868db80886c7d1abbc085676153d7887d824a
  • MD5: 9218c7318278a6faf5807e1bdba547ea
  • BLAKE2b-256: 01ca469972769f97393650e1d10b6ea394fbb7475c73c22fd107640538a92bdc

See the PyPI help pages for more details on using hashes.
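
To check a downloaded file against these digests, a minimal sketch using Python's standard hashlib module (assumes the archive sits in the current directory):

import hashlib

# compute the SHA256 digest of the downloaded archive in chunks
sha256 = hashlib.sha256()
with open("memcnn-1.5.2.tar.gz", "rb") as f:
    for chunk in iter(lambda: f.read(8192), b""):
        sha256.update(chunk)

# compare against the published digest listed above
expected = "5a090f58867fbe0463d81c7ef38868db80886c7d1abbc085676153d7887d824a"
assert sha256.hexdigest() == expected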

File details

Details for the file memcnn-1.5.2-py2.py3-none-any.whl.

File metadata

  • Download URL: memcnn-1.5.2-py2.py3-none-any.whl
  • Size: 50.3 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.4

File hashes

Hashes for memcnn-1.5.2-py2.py3-none-any.whl

  • SHA256: 6414a70e18564e3b3354cebcf86b151f778db03f3b55b2719bbf83d3ba71d9dc
  • MD5: 5244faf6f3094bb49892ecd95821681a
  • BLAKE2b-256: 5f5a02e2609a2e4e9ab475ea20d50796f6435882b4bca53add92ce2865e4724b

See the PyPI help pages for more details on using hashes.
