
A PyTorch framework for developing memory-efficient deep invertible networks.


Features

  • Simple ReversibleBlock wrapper class that converts arbitrary PyTorch Modules into invertible versions.

  • Simple switching between additive and affine invertible coupling schemes and different implementations.

  • Simple toggling of memory saving by setting the keep_input property of the ReversibleBlock.

  • Training and evaluation code for reproducing RevNet experiments using MemCNN.

  • CI tests for Python v2.7 and v3.6 with torch v0.4, v1.0, and v1.1, plus good test coverage.
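For reference, the additive coupling used by RevNets splits the input channel-wise into two halves $x_1, x_2$ and applies the two wrapped modules $F$ and $G$ as shown; the inverse is obtained by subtracting in reverse order, so neither $F$ nor $G$ has to be invertible itself. The affine lines show one common general form of affine coupling, using hypothetical scale and shift functions $s_F, t_F, s_G, t_G$ (with nonzero scales); memcnn's exact affine parameterization may differ.

```latex
\begin{aligned}
\text{additive:}\quad & y_1 = x_1 + F(x_2), && y_2 = x_2 + G(y_1) \\
\text{additive inverse:}\quad & x_2 = y_2 - G(y_1), && x_1 = y_1 - F(x_2) \\
\text{affine:}\quad & y_1 = x_1 \odot s_F(x_2) + t_F(x_2), && y_2 = x_2 \odot s_G(y_1) + t_G(y_1) \\
\text{affine inverse:}\quad & x_2 = \bigl(y_2 - t_G(y_1)\bigr) \oslash s_G(y_1), && x_1 = \bigl(y_1 - t_F(x_2)\bigr) \oslash s_F(x_2)
\end{aligned}
```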

Example usage: ReversibleBlock

# some required imports
import torch
import torch.nn as nn
import numpy as np
import memcnn.models.revop


# define a new class of operation(s) PyTorch style
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.seq(x)


# generate some random input data (b, c, y, x)
data = np.random.random((2, 10, 8, 8)).astype(np.float32)
X = torch.from_numpy(data)

# application of the operation(s) the normal way
Y = ExampleOperation(channels=10)(X)

# application of the operation(s) using the reversible block
# the coupling splits the 10 input channels into two halves of 5,
# so F and G are each built for channels=10 // 2
F, G = ExampleOperation(channels=10 // 2), ExampleOperation(channels=10 // 2)
Y = memcnn.models.revop.ReversibleBlock(F, G, coupling='additive')(X)
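In reversible networks such as RevNets, the memory saving comes from the fact that a block's input can be recomputed exactly from its output during the backward pass, so it does not have to be stored. The following minimal sketch illustrates that property with a hypothetical `AdditiveCouplingSketch` module written in plain PyTorch. It assumes the RevNet-style additive coupling and is not memcnn's own implementation; the helper `make_op` is also illustrative, and BatchNorm is left out only to keep the sketch simple.

```python
import torch
import torch.nn as nn


class AdditiveCouplingSketch(nn.Module):
    """Hypothetical RevNet-style additive coupling (not memcnn's own class)."""

    def __init__(self, Fm, Gm):
        super(AdditiveCouplingSketch, self).__init__()
        self.Fm = Fm
        self.Gm = Gm

    def forward(self, x):
        # split the channel dimension into two equal halves
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1 + self.Fm(x2)
        y2 = x2 + self.Gm(y1)
        return torch.cat([y1, y2], dim=1)

    def inverse(self, y):
        # undo the two additive steps in reverse order
        y1, y2 = torch.chunk(y, 2, dim=1)
        x2 = y2 - self.Gm(y1)
        x1 = y1 - self.Fm(x2)
        return torch.cat([x1, x2], dim=1)


def make_op(channels):
    # hypothetical small conv block; it does not need to be invertible itself
    return nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        nn.ReLU(),
    )


torch.manual_seed(0)
block = AdditiveCouplingSketch(make_op(5), make_op(5))
x = torch.rand(2, 10, 8, 8)
with torch.no_grad():
    y = block(x)
    x_rec = block.inverse(y)
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```

The reconstruction is exact up to floating-point rounding, which is why the comparison uses a small tolerance rather than strict equality.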

Run PyTorch Experiments

./train.py [MODEL] [DATASET] --fresh

Available values for DATASET are cifar10 and cifar100.

Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110, and revnet164.

Datasets that are not available locally are downloaded automatically.
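To sweep every MODEL and DATASET combination listed above, a small driver script can build the command lines. This is a hypothetical helper: `build_commands` and `run_all` are illustrative names, not part of memcnn, and it assumes `train.py` is executable, sits in the current directory, and accepts exactly the arguments shown. With `dry_run=True` it only prints the commands instead of running them.

```python
import itertools
import subprocess

# Values listed in this README; --fresh is passed through unchanged.
MODELS = ["resnet32", "resnet110", "resnet164", "revnet38", "revnet110", "revnet164"]
DATASETS = ["cifar10", "cifar100"]


def build_commands(script="./train.py"):
    """Return one argument list per MODEL/DATASET combination."""
    return [[script, model, dataset, "--fresh"]
            for model, dataset in itertools.product(MODELS, DATASETS)]


def run_all(dry_run=True):
    for cmd in build_commands():
        if dry_run:
            print(" ".join(cmd))
        else:
            # assumes train.py is executable and in the current directory
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing argument lists to `subprocess.run` avoids shell quoting issues, and `check=True` stops the sweep at the first failing run.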


Download files


Source distribution: memcnn-0.3.1.tar.gz (38.8 kB)

Built distribution: memcnn-0.3.1-py2.py3-none-any.whl (41.8 kB), for Python 2 and Python 3
