A PyTorch framework for developing memory efficient deep invertible networks.
Free software: MIT license
Documentation: https://memcnn.readthedocs.io.
Installation: https://memcnn.readthedocs.io/en/latest/installation.html
Reference: Sil C. van de Leemput, Jonas Teuwen, Rashindra Manniesing. MemCNN: a Framework for Developing Memory Efficient Deep Invertible Networks. International Conference on Learning Representations (ICLR) 2018 Workshop Track. (https://iclr.cc/)
Licensing
This repository is released under the MIT license, which means everyone is free to use, copy, distribute, and/or modify this work. If you do, please cite our work.
Features
Simple ReversibleBlock wrapper class to wrap and convert arbitrary PyTorch Modules to invertible versions.
Simple switching between additive and affine invertible coupling schemes and different implementations.
Simple toggling of memory saving by setting the keep_input property of the ReversibleBlock.
Train and evaluation code for reproducing RevNet experiments using MemCNN.
CI tests for Python v2.7 and v3.6, and torch v0.4, v1.0, and v1.1.
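The additive coupling scheme that ReversibleBlock implements can be illustrated with a small pure-Python sketch. The toy functions f and g below are hypothetical stand-ins for arbitrary sub-networks; this mirrors the RevNet-style coupling equations, not memcnn's actual implementation:

```python
# Additive coupling: the input is split into two halves (x1, x2).
# Because only additions and subtractions are used, the inverse is exact,
# so intermediate inputs never need to be stored for the backward pass.

def f(x):  # stand-in for an arbitrary sub-network F
    return 3 * x + 1

def g(x):  # stand-in for an arbitrary sub-network G
    return x * x

def forward(x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def inverse(y1, y2):
    x2 = y2 - g(y1)   # recover x2 using the already-known y1
    x1 = y1 - f(x2)   # then recover x1
    return x1, x2

y1, y2 = forward(2.0, 5.0)
x1, x2 = inverse(y1, y2)
assert (x1, x2) == (2.0, 5.0)  # inputs reconstructed exactly
```

Because the inputs can always be recomputed from the outputs, the block's inputs can be freed after the forward pass, which is where the memory saving comes from.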
Example usage: ReversibleBlock
# some required imports
import torch
import torch.nn as nn
import numpy as np
import memcnn.models.revop


# define a new class of operation(s) PyTorch style
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels,
                      kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.seq(x)


# generate some random input data (b, c, y, x)
data = np.random.random((2, 10, 8, 8)).astype(np.float32)
X = torch.from_numpy(data)

# application of the operation(s) the normal way
Y = ExampleOperation(channels=10)(X)

# application of the operation(s) using the reversible block;
# the input channels are split in half, so F and G each act on 10 // 2 channels
F, G = ExampleOperation(channels=10 // 2), ExampleOperation(channels=10 // 2)
Y = memcnn.models.revop.ReversibleBlock(F, G, coupling='additive')(X)
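The affine coupling scheme (coupling='affine') additionally scales one half of the input rather than only shifting it. The following plain-Python sketch uses a Real NVP-style formulation with hypothetical stand-in functions s and t; memcnn's exact affine parameterization may differ:

```python
import math

def s(x):  # stand-in for a scale sub-network
    return 0.5 * x

def t(x):  # stand-in for a translation sub-network
    return x + 1.0

def affine_forward(x1, x2):
    # one half is scaled and shifted conditioned on the other half
    y1 = x1 * math.exp(s(x2)) + t(x2)
    y2 = x2
    return y1, y2

def affine_inverse(y1, y2):
    # invertible because exp(.) is never zero
    x2 = y2
    x1 = (y1 - t(x2)) * math.exp(-s(x2))
    return x1, x2

y1, y2 = affine_forward(2.0, 4.0)
x1, x2 = affine_inverse(y1, y2)
assert math.isclose(x1, 2.0) and x2 == 4.0  # round-trip up to float precision
```

As with the additive scheme, the exact inverse is what allows inputs to be freed and recomputed during backpropagation instead of stored.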
Run PyTorch Experiments
./train.py [MODEL] [DATASET] --fresh
Available values for DATASET are cifar10 and cifar100.
Available values for MODEL are resnet32, resnet110, resnet164, revnet38, revnet110, and revnet164.
Datasets are downloaded automatically if they are not already available.
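For example, a training run could be composed from one of the listed models and datasets as follows (an illustrative invocation based on the usage line above):

```shell
# pick one of the listed models and datasets
MODEL=revnet110
DATASET=cifar10

# compose the training command; --fresh is passed as shown above
CMD="./train.py $MODEL $DATASET --fresh"
echo "$CMD"
```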
Hashes for memcnn-0.2.1-py2.py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | bcd9524786089c6937f465c147da542e85411e814e4a98c0190dd0d8c17d6b5d
MD5 | e997a23f42301f35a22e21b17265e6f2
BLAKE2b-256 | 6c880db1f9c1fd0a7f21384c829d6fc99a1b0967b548ab994564b3e7f001f1ec