A PyTorch framework for developing memory efficient deep invertible networks.

Project description

Reference: Sil C. van de Leemput, Jonas Teuwen, Rashindra Manniesing. MemCNN: a Framework for Developing Memory Efficient Deep Invertible Networks. International Conference on Learning Representations (ICLR) 2018 Workshop Track. (https://iclr.cc/)

Licensing

This repository is released under the MIT license, which grants everyone the right to use, copy, distribute and/or modify this work. If you do, please cite our work.

Features

  • Simple ReversibleBlock wrapper class to wrap and convert arbitrary PyTorch Modules to invertible versions.

  • Simple switching between additive and affine invertible coupling schemes and different implementations.

  • Simple toggling of memory saving by setting the keep_input property of the ReversibleBlock.

  • Training and evaluation code for reproducing the RevNet experiments using MemCNN.

  • CI tests for Python v2.7 and v3.6 and torch v0.4, v1.0, and v1.1.
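The features above mention additive and affine coupling schemes. The NumPy sketch below illustrates the general idea behind both. It assumes the RevNet-style formulation, in which the input is split into two halves along the channel axis.

`F`, `G`, `F_affine` and `G_affine` are hypothetical stand-ins for the wrapped sub-networks. The exponential scale in the affine variant is one common choice and may differ from memcnn's own parameterisation. This is not memcnn's implementation.

```python
import numpy as np


# Hypothetical stand-ins for the wrapped sub-networks F and G.
# Each must return an array with the same shape as its input.
def F(h):
    return np.tanh(h)


def G(h):
    return 0.5 * h ** 2


# Hypothetical affine variants: each returns a (log-scale, shift) pair.
def F_affine(h):
    return 0.1 * np.tanh(h), np.sin(h)


def G_affine(h):
    return 0.1 * np.cos(h), h / 3.0


def additive_forward(x, f, g):
    # Split along the channel axis of a (batch, channels, height, width) array.
    x1, x2 = np.split(x, 2, axis=1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return np.concatenate([y1, y2], axis=1)


def additive_inverse(y, f, g):
    # Undo the two steps in reverse order; f and g are only re-evaluated, never inverted.
    y1, y2 = np.split(y, 2, axis=1)
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return np.concatenate([x1, x2], axis=1)


def affine_forward(x, f, g):
    x1, x2 = np.split(x, 2, axis=1)
    s1, t1 = f(x2)
    y1 = x1 * np.exp(s1) + t1  # exp keeps the scale strictly positive
    s2, t2 = g(y1)
    y2 = x2 * np.exp(s2) + t2
    return np.concatenate([y1, y2], axis=1)


def affine_inverse(y, f, g):
    y1, y2 = np.split(y, 2, axis=1)
    s2, t2 = g(y1)
    x2 = (y2 - t2) * np.exp(-s2)
    s1, t1 = f(x2)
    x1 = (y1 - t1) * np.exp(-s1)
    return np.concatenate([x1, x2], axis=1)


x = np.random.default_rng(0).random((2, 10, 8, 8))
print(np.allclose(additive_inverse(additive_forward(x, F, G), F, G), x))  # True
print(np.allclose(affine_inverse(affine_forward(x, F_affine, G_affine),
                                 F_affine, G_affine), x))  # True
```

Neither F nor G has to be invertible themselves, because the inverse only re-evaluates them. This is what allows arbitrary shape-preserving modules to be wrapped.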

Example usage: ReversibleBlock

# some required imports
import torch
import torch.nn as nn
import numpy as np
import memcnn.models.revop


# define a new class of operation(s) PyTorch style
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.seq(x)


# generate some random input data with layout (batch, channels, height, width)
data = np.random.random((2, 10, 8, 8)).astype(np.float32)
X = torch.from_numpy(data)

# application of the operation(s) the normal way
Y = ExampleOperation(channels=10)(X)

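# application of the operation(s) using the reversible block:
# the input is split into two halves along the channel dimension,
# so F and G each operate on 10 // 2 = 5 channels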
F, G = ExampleOperation(channels=10 // 2), ExampleOperation(channels=10 // 2)
Y = memcnn.models.revop.ReversibleBlock(F, G, coupling='additive')(X)
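Disabling keep_input saves memory because each block's input can be recomputed from its output. The NumPy sketch below illustrates that principle for a chain of additive blocks. It assumes an additive coupling in which the input is split into two channel halves.

`make_block`, `run_chain` and `recover_inputs` are hypothetical helpers. memcnn itself performs this recomputation inside PyTorch's autograd during the backward pass, not as a separate step like this.

```python
import numpy as np


def make_block(seed, half_channels=5):
    # Hypothetical f and g: a fixed random 1x1 channel mixing followed by tanh,
    # so each sub-network preserves the shape of its half-channel input.
    rng = np.random.default_rng(seed)
    wf = 0.3 * rng.standard_normal((half_channels, half_channels))
    wg = 0.3 * rng.standard_normal((half_channels, half_channels))

    def f(h):
        return np.tanh(np.einsum("oc,bcyx->boyx", wf, h))

    def g(h):
        return np.tanh(np.einsum("oc,bcyx->boyx", wg, h))

    return f, g


def forward(x, f, g):
    x1, x2 = np.split(x, 2, axis=1)
    y1 = x1 + f(x2)
    return np.concatenate([y1, x2 + g(y1)], axis=1)


def inverse(y, f, g):
    y1, y2 = np.split(y, 2, axis=1)
    x2 = y2 - g(y1)
    return np.concatenate([y1 - f(x2), x2], axis=1)


def run_chain(x, blocks):
    # Forward pass that only holds the current activation, loosely mimicking
    # keep_input=False: each block's input is dropped once its output exists.
    h = x
    for f, g in blocks:
        h = forward(h, f, g)
    return h


def recover_inputs(y, blocks):
    # Walk the chain backwards, recomputing every block's input from its output.
    inputs = []
    h = y
    for f, g in reversed(blocks):
        h = inverse(h, f, g)
        inputs.append(h)
    return inputs[::-1]  # inputs[i] is the input that was fed to block i


blocks = [make_block(seed) for seed in range(4)]
x = np.random.default_rng(42).random((2, 10, 8, 8))
y = run_chain(x, blocks)
print(np.allclose(recover_inputs(y, blocks)[0], x))  # True
```

The trade-off is extra computation (re-evaluating f and g once more per block) in exchange for not storing intermediate activations.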

Run PyTorch Experiments

./train.py [MODEL] [DATASET] --fresh

Available values for DATASET are cifar10 and cifar100.

Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110 and revnet164.

Datasets that are not available locally are downloaded automatically.
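To sweep every documented MODEL and DATASET combination, a small Python helper can build the command lines. This is a sketch: `train_command` and `sweep` are hypothetical names. The helper only validates the arguments against the values listed above, and it assumes ./train.py is reachable from the current working directory.

```python
import itertools
import subprocess

# The MODEL and DATASET values documented above.
MODELS = ["resnet32", "resnet110", "resnet164", "revnet38", "revnet110", "revnet164"]
DATASETS = ["cifar10", "cifar100"]


def train_command(model, dataset):
    """Return the argv list for one ./train.py run, validating both arguments."""
    if model not in MODELS:
        raise ValueError("unknown MODEL: {!r}".format(model))
    if dataset not in DATASETS:
        raise ValueError("unknown DATASET: {!r}".format(dataset))
    return ["./train.py", model, dataset, "--fresh"]


def sweep(dry_run=True):
    # Print every MODEL x DATASET combination; only launch runs when dry_run is False.
    for model, dataset in itertools.product(MODELS, DATASETS):
        argv = train_command(model, dataset)
        print(" ".join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    sweep(dry_run=True)
```

Keeping `dry_run=True` as the default makes it easy to inspect the 12 commands before committing GPU time to them.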

Download files

Source Distribution

memcnn-0.2.1.tar.gz (11.4 kB)

Built Distribution

memcnn-0.2.1-py2.py3-none-any.whl (5.7 kB; Python 2 and Python 3)

File details

Details for the file memcnn-0.2.1.tar.gz.

File metadata

  • Download URL: memcnn-0.2.1.tar.gz
  • Size: 11.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/40.6.2 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8

File hashes

Hashes for memcnn-0.2.1.tar.gz
Algorithm Hash digest
SHA256 61274cf6678b1dbb76b42c1158bfd0d9376e812fd7c6ef388d83fa27b5ce97b5
MD5 5a45b8e710c9e4e708622b69f3423d0c
BLAKE2b-256 abbb4e01f3a42632ec0c981c1ff5eeb78ff6884dfabd2d66ee773447be75e280

File details

Details for the file memcnn-0.2.1-py2.py3-none-any.whl.

File metadata

  • Download URL: memcnn-0.2.1-py2.py3-none-any.whl
  • Size: 5.7 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/40.6.2 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8

File hashes

Hashes for memcnn-0.2.1-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 bcd9524786089c6937f465c147da542e85411e814e4a98c0190dd0d8c17d6b5d
MD5 e997a23f42301f35a22e21b17265e6f2
BLAKE2b-256 6c880db1f9c1fd0a7f21384c829d6fc99a1b0967b548ab994564b3e7f001f1ec
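The listed digests can be used to check a downloaded file before installing it. The sketch below uses Python's standard hashlib. `EXPECTED_SHA256`, `sha256_of` and `verify` are hypothetical names, and the digests are copied from the tables above.

```python
import hashlib
import os

# SHA256 digests listed above for the two memcnn 0.2.1 release files.
EXPECTED_SHA256 = {
    "memcnn-0.2.1.tar.gz":
        "61274cf6678b1dbb76b42c1158bfd0d9376e812fd7c6ef388d83fa27b5ce97b5",
    "memcnn-0.2.1-py2.py3-none-any.whl":
        "bcd9524786089c6937f465c147da542e85411e814e4a98c0190dd0d8c17d6b5d",
}


def sha256_of(path, chunk_size=1 << 16):
    # Hash the file in chunks so large files need not fit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, name=None):
    # Compare against the listed digest, keyed by file name (defaults to the basename).
    name = name or os.path.basename(path)
    return sha256_of(path) == EXPECTED_SHA256[name]
```

For example, `verify("memcnn-0.2.1.tar.gz")` returns True only for an unmodified copy of the source archive. pip can enforce the same check at install time in hash-checking mode. You would add a requirements line of the form `memcnn==0.2.1 --hash=sha256:<digest>` and install with `pip install --require-hashes -r requirements.txt`.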
