
RevTorch

Framework for creating (partially) reversible neural networks with PyTorch

RevTorch is introduced and explained in our paper A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation, which was accepted for presentation at MICCAI 2019.

If you find this code helpful in your research, please cite the following paper:

@article{PartiallyRevUnet2019Bruegger,
         author={Br{\"u}gger, Robin and Baumgartner, Christian F. and Konukoglu, Ender},
         title={A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation},
         journal={arXiv:1906.06148},
         year={2019}
}

Installation

Use pip to install RevTorch:

$ pip install revtorch

RevTorch requires PyTorch. However, PyTorch is not listed as an installation dependency, since the required PyTorch build depends on your system. Please install PyTorch by following the instructions on the PyTorch website.
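
For example, a typical CPU-only setup can usually be installed with pip (the exact command for your platform and CUDA version is given on the PyTorch website); the torchvision package is only needed for the example below:

$ pip install torch torchvision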

Usage

This example shows how to use the RevTorch framework.

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import revtorch as rv

def train():
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms.ToTensor())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

    net = PartiallyReversibleNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(2):

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            #logging
            running_loss += loss.item()
            LOG_INTERVAL = 200
            if i % LOG_INTERVAL == (LOG_INTERVAL - 1):  #print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / LOG_INTERVAL))
                running_loss = 0.0

class PartiallyReversibleNet(nn.Module):
    def __init__(self):
        super(PartiallyReversibleNet, self).__init__()

        #initial non-reversible convolution to get to 32 channels
        self.conv1 = nn.Conv2d(3, 32, 3)

        #construct a reversible sequence with 4 reversible blocks
        blocks = []
        for i in range(4):

            #f and g must each be an nn.Module whose output has the same shape as its input;
            #they operate on 16 channels because the block splits the 32 input channels in half
            f_func = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
            g_func = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

            #we construct a reversible block with our F and G functions
            blocks.append(rv.ReversibleBlock(f_func, g_func))

        #pack all reversible blocks into a reversible sequence
        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))

        #non-reversible convolution to get to 10 channels (one for each label)
        self.conv2 = nn.Conv2d(32, 10, 3)

    def forward(self, x):
        x = self.conv1(x)

        #the reversible sequence can be used like any other nn.Module. Memory-saving backpropagation is used automatically
        x = self.sequence(x)

        x = self.conv2(F.relu(x))
        x = F.avg_pool2d(x, (x.shape[2], x.shape[3]))
        x = x.view(x.shape[0], x.shape[1])
        return x

if __name__ == "__main__":
    train()
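
The memory savings come from additive coupling: the input to the reversible sequence is split into two halves, and each block mixes them in a way that lets its input be recomputed exactly from its output during the backward pass, so activations inside the sequence do not have to be stored. The following minimal sketch in plain PyTorch illustrates the idea (a hand-written illustration of the coupling, not RevTorch's actual implementation):

import torch
import torch.nn as nn

f = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
g = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

x = torch.randn(2, 32, 8, 8)
x1, x2 = torch.chunk(x, 2, dim=1)  #split along the channel dimension

#forward coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)
y1 = x1 + f(x2)
y2 = x2 + g(y1)

#inverse coupling: reconstruct the input from the output alone
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)

assert torch.allclose(torch.cat((x1_rec, x2_rec), dim=1), x, atol=1e-5)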

Python version

Tested with Python 3.6 and PyTorch 1.1.0. It should work with any version of Python 3.

Changelog

Version 0.2.4

  • Added option to disable eager discarding of variables to allow for multiple backward() calls (see the sketch after this changelog)

Version 0.2.3

  • Added option to use the same random seed for the forward and backward pass (Pull request)

Version 0.2.1

  • Added option to select the dimension along which the tensor is split (Pull request)

Version 0.2.0

  • Fixed memory leak when not consuming output of the reversible block (Issue)

Version 0.1.0

  • Initial release
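
For reference, here is a hedged sketch of how these options might be passed. The keyword names below (split_along_dim, fix_random_seed, eagerly_discard_variables) are taken from a reading of the repository source and should be treated as assumptions; check the source for the exact signatures:

import torch.nn as nn
import revtorch as rv

f = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
g = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

#assumed keyword names for the options listed in the changelog above
block = rv.ReversibleBlock(f, g, split_along_dim=1, fix_random_seed=True)
sequence = rv.ReversibleSequence(nn.ModuleList([block]), eagerly_discard_variables=False)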

Download files

Download the file for your platform.

Source Distribution

revtorch-0.2.4.tar.gz (5.6 kB)


File details

Details for the file revtorch-0.2.4.tar.gz.

File metadata

  • Download URL: revtorch-0.2.4.tar.gz
  • Upload date:
  • Size: 5.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.12.4 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.6.7

File hashes

Hashes for revtorch-0.2.4.tar.gz:

  • SHA256: b028324b7430a7ee9311b0472e7bf741e40c51f8c266fb0cd2fd032f2679b0ec
  • MD5: bd103589b8b674e7a2a9b00030357c99
  • BLAKE2b-256: 7b7f6b2247e5ce4b8969dedfcaec064c59ce0417cddbe638bfa6169ff586eaea

