RevLib

Simple and efficient RevNet-Library with DeepSpeed support

Features

  • Half the constant memory usage of other RevNet libraries, and faster
  • Less memory than gradient checkpointing (1 * output_size instead of (n_layers + 1) * output_size)
  • Same speed as activation checkpointing
  • Extensible
  • Trivial code (<100 Lines)
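
The memory comparison in the feature list can be made concrete with a little arithmetic. The sketch below is illustrative only: the helper names, the layer count, and the activation shape are assumptions chosen for the example, not measurements and not part of revlib.

```python
# Illustrative arithmetic only; helper names and numbers are hypothetical.
def checkpointing_activation_bytes(n_layers: int, output_size: int) -> int:
    # Formula from the feature list: (n_layers + 1) * output_size
    return (n_layers + 1) * output_size


def revlib_activation_bytes(n_layers: int, output_size: int) -> int:
    # Formula from the feature list: 1 * output_size, independent of depth
    return 1 * output_size


# Assumed example: one float32 activation of shape (16, 128, 224, 224)
output_size = 16 * 128 * 224 * 224 * 4
n_layers = 16

ratio = (checkpointing_activation_bytes(n_layers, output_size)
         / revlib_activation_bytes(n_layers, output_size))
print(f"{ratio:.0f}x")  # prints "17x"
```

Under these formulas the gap grows linearly with depth, which is why the savings matter most for deep models.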

Getting started

Installation

python3 -m pip install revlib

Examples

Sequential CNN like iRevNet

iRevNet is not merely partially reversible but a fully invertible model. Its source code looks complex at first glance. It also doesn't realize the memory savings it could, as RevNet requires custom autograd functions that are hard to maintain. With revlib, an iRevNet can be implemented like this:

import torch
from torch import nn
import revlib

channels = 64
channel_multiplier = 4
depth = 16
classes = 1000


# Create a basic function that's reversibly executed multiple times. (Like f() in ResNet)
def conv(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, (3, 3), padding=1)


def block_conv(in_channels, out_channels):
    return nn.Sequential(conv(in_channels, out_channels),
                         nn.Dropout(0.2),
                         nn.BatchNorm2d(out_channels),  # normalize over the conv's output channels
                         nn.ReLU())


def block():
    return nn.Sequential(block_conv(channels, channels * channel_multiplier),
                         block_conv(channels * channel_multiplier, channels),
                         nn.Conv2d(channels, channels, (3, 3), padding=1))


# Create a reversible model. f() is invoked depth-times with different weights.
rev_model = revlib.ReversibleSequential(*[block() for _ in range(depth)])

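# Wrap reversible model with non-reversible layers.
# padding=1 keeps the 224x224 spatial size so the shape assertion below holds.
model = nn.Sequential(nn.Conv2d(3, 2 * channels, (3, 3), padding=1),
                      rev_model,
                      nn.Conv2d(2 * channels, classes, (3, 3), padding=1))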

# Use it like you would a regular PyTorch model
inp = torch.randn((16, 3, 224, 224))
out = model(inp)
assert out.size() == (16, 1000, 224, 224)

MomentumNet

MomentumNet is another recent paper that made significant advances in memory-efficient networks. It proposes using a momentum stream instead of a second model output (see the MomentumNet illustration). Implementing that with revlib requires writing a custom coupling operation (functionally analogous to MemCNN's) that merges the input and output streams.

import torch
from torch import nn
import revlib

channels = 64
depth = 16
momentum_ema_beta = 0.99


# Compute x2_{t+1} from x2_{t} and f(x1_{t}) by merging x2 and f(x1) in the forward pass.
def momentum_coupling_forward(other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


# Calculate x2_{t} from x2_{t+1} and f(x1_{t}) by manually computing the inverse of momentum_coupling_forward.
def momentum_coupling_inverse(output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


# Pass in coupling functions which will be used instead of x2_{t} + f(x1_{t}) and x2_{t+1} - f(x1_{t})
rev_model = revlib.ReversibleSequential(*[nn.Conv2d(channels, channels, (3, 3), padding=1) for _ in range(depth)],
                                        coupling_forward=momentum_coupling_forward,
                                        coupling_inverse=momentum_coupling_inverse)

inp = torch.randn((16, channels * 2, 224, 224))
out = rev_model(inp)
assert out.size() == (16, channels * 2, 224, 224)
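
A custom coupling pair only works if the inverse exactly undoes the forward function. Solving y = beta * x2 + (1 - beta) * f for x2 gives x2 = (y - (1 - beta) * f) / beta. The following minimal sketch checks that round trip on plain Python floats, so it runs without torch or revlib. The sample values are arbitrary assumptions, and the same arithmetic applies elementwise to tensors.

```python
import math

momentum_ema_beta = 0.99


def momentum_coupling_forward(other_stream, fn_out):
    # y = beta * x2 + (1 - beta) * f(x1)
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


def momentum_coupling_inverse(output, fn_out):
    # x2 = (y - (1 - beta) * f(x1)) / beta
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


# Arbitrary sample values standing in for x2_{t} and f(x1_{t})
x2, f_x1 = 0.5, -1.25
recovered = momentum_coupling_inverse(momentum_coupling_forward(x2, f_x1), f_x1)
assert math.isclose(recovered, x2, rel_tol=1e-9)
```

A check like this is cheap to run whenever a new coupling function is written, because an inexact inverse would silently corrupt the reconstructed activations.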

Reformer

Reformer uses RevNet with chunking and LSH-attention to efficiently train a transformer. Using revlib, standard implementations, such as lucidrains' Reformer, can be improved upon to use less memory. Below we're still using the basic building blocks from lucidrains' code to have a comparable model.

import torch
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHSelfAttention, Chunk, FeedForward, AbsolutePositionalEmbedding
import revlib


class Reformer(torch.nn.Module):
    def __init__(self, sequence_length: int, features: int, depth: int, heads: int, bucket_size: int = 64,
                 lsh_hash_count: int = 8, ff_chunks: int = 16, input_classes: int = 256, output_classes: int = 256):
        super(Reformer, self).__init__()
        self.token_embd = nn.Embedding(input_classes, features * 2)
        self.pos_embd = AbsolutePositionalEmbedding(features * 2, sequence_length)

        self.core = revlib.ReversibleSequential(*[nn.Sequential(nn.LayerNorm(features), layer) for _ in range(depth)
                                                 for layer in
                                                 [LSHSelfAttention(features, heads, bucket_size, lsh_hash_count),
                                                  Chunk(ff_chunks, FeedForward(features, activation=nn.GELU), 
                                                        along_dim=-2)]],
                                                split_dim=-1)
        self.out_norm = nn.LayerNorm(features * 2)
        self.out_linear = nn.Linear(features * 2, output_classes)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.out_linear(self.out_norm(self.core(self.token_embd(inp) + self.pos_embd(inp))))


sequence = 1024
classes = 16
model = Reformer(sequence, 256, 6, 8, output_classes=classes)
out = model(torch.ones((16, sequence), dtype=torch.long))
assert out.size() == (16, sequence, classes)

Explanation

Most other RevNet libraries, such as MemCNN and Revtorch, calculate both f() and g() in one go, creating one large computation. RevLib, on the other hand, brings Mesh TensorFlow's "reversible half residual and swap" to PyTorch. reversible_half_residual_and_swap computes only one of f() and g() per step and swaps the inputs and gradients. This way, the library only has to store one output, as it can recover the other during the backward pass.
Following Mesh TensorFlow's example, revlib also keeps x1 and x2 as separate tensors instead of concatenating and splitting them at every step, which reduces the cost of memory-bound operations.
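
As a rough illustration of the idea, the sketch below implements a "half residual and swap" step on plain Python floats. It is not revlib's code: the function names, the stand-in layers f and g, and the exact swap order are assumptions made for the example. The same arithmetic applies elementwise to tensors.

```python
import math


def f(x):
    # Hypothetical stand-in for a layer; any deterministic function works.
    return math.tanh(x)


def g(x):
    # A second hypothetical stand-in layer.
    return 0.5 * x + 1.0


def half_step_forward(x1, x2, fn):
    # Evaluate a single function, add it to the other stream, then swap.
    return x2 + fn(x1), x1


def half_step_inverse(y1, y2, fn):
    # y2 is the untouched x1, so fn(x1) can be recomputed and subtracted.
    x1 = y2
    x2 = y1 - fn(x1)
    return x1, x2


def forward(x1, x2, layers):
    for fn in layers:
        x1, x2 = half_step_forward(x1, x2, fn)
    return x1, x2


def backward(y1, y2, layers):
    # Walk the layers in reverse, reconstructing each step's inputs.
    for fn in reversed(layers):
        y1, y2 = half_step_inverse(y1, y2, fn)
    return y1, y2


layers = [f, g, f, g]
out = forward(0.3, -0.7, layers)
rec = backward(*out, layers)
assert math.isclose(rec[0], 0.3, abs_tol=1e-12)
assert math.isclose(rec[1], -0.7, abs_tol=1e-12)
```

Because each step's inputs can be rebuilt from its outputs, intermediate activations do not need to be kept for the backward pass. Only the final pair of streams is required.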

RevNet's memory consumption doesn't scale with its depth, so it's significantly more memory-efficient for deep models. One problem in most implementations was that two tensors needed to be stored in the output, quadrupling the required memory. The high memory consumption rendered RevNet nearly useless for small networks, such as BERT-base, with its twelve layers.
RevLib works around this problem by storing only one output and two inputs for each forward pass, giving a model as small as BERT a >2x improvement!

Even ignoring the dual-path structure, RevNet implementations used to be much slower than gradient checkpointing. RevLib, however, uses minimal coupling functions and adds no overhead between the items of a sequence, allowing it to train as fast as a comparable model with gradient checkpointing.
