InvTorch: Memory-Efficient Invertible Functions

When working with extremely deep neural networks, memory quickly becomes a concern: it fills up with the saved tensors that are needed for gradient computation. PyTorch provides a solution, checkpoint_sequential, which lets us group every few layers into a checkpointed segment. The forward pass is run in no_grad mode, only the inputs of every segment are saved in memory, and every other unreferenced tensor gets deallocated. In the backward pass, the forward pass of every segment, from the last to the first, is run again using its saved inputs to compute its gradients. Refer to the torch.utils.checkpoint documentation for more details.
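
For reference, here is a minimal example of PyTorch's built-in segment checkpointing (the layer sizes and segment count are arbitrary):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# 100 layers whose activations would normally all be kept for backward
model = nn.Sequential(*[nn.Linear(128, 128) for _ in range(100)])
x = torch.randn(32, 128, requires_grad=True)

# split into 4 segments; only each segment's inputs are saved and the
# rest is recomputed during the backward pass
y = checkpoint_sequential(model, 4, x)
y.sum().backward()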

This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible functions. Now, even the inputs of a segment can be released from memory and recomputed later using the inverse function in the backward pass. This is useful for extremely wide networks, where compute is traded for memory. However, there are a few considerations to keep in mind when working with invertible checkpoints and non-materialized tensors. Please read the limitations section below for specifics.

Installation

InvTorch has minimal dependencies. It only requires Python >=3.6 and PyTorch >=1.10.0.

pip install invtorch

Usage

We are interested in invtorch.nn.Module, which inherits from torch.nn.Module. Subclass it to implement your own invertible code.

import torch
from torch import nn

import invtorch.nn as inn


class InvertibleLinear(inn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):
        return inputs @ self.weight.T + self.bias

    def inverse(self, outputs):
        return (outputs - self.bias) @ self.weight.T.pinverse()


if __name__ == '__main__':
    x = torch.randn(10, 3)
    model = InvertibleLinear(3, 5)
    print('Is invertible:', model.check(x))

    y = model(x)
    print('Input was freed:', x.storage().size() == 0)

    y.backward(torch.randn_like(y))
    print('Input was restored:', x.storage().size() != 0)

forward()

You can immediately notice a few differences from a regular PyTorch module here. There is no longer a need to define forward(); it is replaced with function(). Additionally, it is necessary to define its inverse as inverse(). Both methods should accept and return only positional arguments, either as a tuple or as a single torch.Tensor. Furthermore, forward() accepts an optional keyword argument keep, which is an iterable of input tensors that shouldn't be deallocated.
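
For example, to keep the input of the model above in memory, one could pass it through keep (a sketch based on the description of keep):

x = torch.randn(10, 3)
y = model(x, keep=(x,))  # x won't be freed after the forward pass
print('Input was kept:', x.storage().size() != 0)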

function()

The first call to function() is always run in dry_mode (check for it with invtorch.in_dry_mode()). This is a novel mode that enables gradient graph construction but doesn't allow backward propagation. The second call happens during the backward pass, which is when the gradients will actually be computed (check for it with invtorch.in_backward()).
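
As an illustration (a minimal sketch; the doubling function is arbitrary), function() can branch on these two modes:

import invtorch
import invtorch.nn as inn


class Double(inn.Module):
    def function(self, inputs):
        # True on the first (dry) call, False on the recomputation
        print('dry mode:', invtorch.in_dry_mode())
        # True only when recomputed during the backward pass
        print('in backward:', invtorch.in_backward())
        return inputs * 2

    def inverse(self, outputs):
        return outputs / 2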

inverse()

You can verify your implementation of inverse() by calling check(). In some cases, switching to double precision is advised, as invertible functions can run into numerical instability in single precision. For some functions, a view of an input tensor is returned among the outputs. In such a case, this will be detected automatically and the input tensor will not be released from memory.
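
For instance, if check() fails in single precision, you could try it in double precision first:

x = torch.randn(10, 3, dtype=torch.float64)
model = InvertibleLinear(3, 5).double()
print('Is invertible:', model.check(x))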

num_outputs

Sometimes, inverse() requires auxiliary outputs from function() that are not needed as outputs of forward(). To specify how many outputs to keep, we can modify the num_outputs property (None keeps everything). Extra outputs will be hidden by process(), which is automatically called by forward(). Another property, num_inputs, is also defined; it helps in composing inverse calls, as in invtorch.nn.Sequential.inverse().
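
As a sketch of the idea (assuming num_outputs can be set per instance; the splitting function here is made up for illustration), function() could return an auxiliary tensor that only inverse() consumes:

class SplitHalves(inn.Module):
    def __init__(self):
        super().__init__()
        self.num_outputs = 1  # assumption: keep only the first output

    def function(self, inputs):
        half = inputs.shape[-1] // 2
        # the second half is auxiliary; inverse() needs it to reconstruct
        return inputs[..., :half], inputs[..., half:]

    def inverse(self, kept, auxiliary):
        return torch.cat([kept, auxiliary], dim=-1)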

reverse()

invtorch.nn.Module can be implemented to be reversible, i.e., forward() will call inverse() instead of function(). Not all invertible modules need to support reversibility. If you want to support it in your own module, you need to override the reversible property to return True. The module can be reversed by calling reverse(), and its state can be checked with the reversed property. To avoid confusion, Module also has call_function() and call_inverse(), which call the correct method based on the value of reversed.
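
A sketch of opting into reversibility, extending the InvertibleLinear example from above (square so that the pseudo-inverse is a true inverse):

class ReversibleLinear(InvertibleLinear):
    @property
    def reversible(self):
        return True  # declare that inverse() may serve as the forward pass


model = ReversibleLinear(3, 3)
model.reverse()        # forward() now calls inverse()
print(model.reversed)  # True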

checkpoint, invertible, and seed

invtorch.nn.Module has two flags that control its mode of operation: checkpoint and invertible. If checkpoint is set to False, or when working in no_grad mode, it acts exactly like a normal PyTorch module. Otherwise, the module is either invertible or a plain checkpoint depending on whether invertible is set to True or False, respectively. These flags can be changed at any time during operation without any repercussions. A third flag, seed, defaults to False; if set to True, it ensures that the recomputed forward pass runs with the same random number generator state on the devices of the input tensors.
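
For example, on the model from the usage section:

model.checkpoint = False  # behave exactly like a plain PyTorch module
model.checkpoint = True   # back to checkpointing
model.invertible = False  # checkpoint without inverse recomputation
model.seed = True         # replay the RNG state when recomputing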

invtorch.checkpoint()

PyTorch's checkpoint cannot track the requires_grad attribute of its output tensors since it runs in no_grad mode. InvTorch's checkpoint doesn't have this issue because it runs in dry_mode. In addition, it supports invertible functions.
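
A minimal sketch, assuming invtorch.checkpoint() mirrors the call signature of torch.utils.checkpoint.checkpoint():

import invtorch

def double(x):
    return x * 2

x = torch.randn(5, requires_grad=True)
y = invtorch.checkpoint(double, x)
print(y.requires_grad)  # correctly tracked, unlike PyTorch's checkpoint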

Limitations

There are a few caveats to consider, though. Invertible functions are hard to define without requiring more memory. Moreover, they are prone to numerical instabilities (e.g., multiplying by numbers close to zero). Even if we can get away with these fundamental problems, there are technical details to consider. There is no way of guarding against accessing the data in the input tensors after calling function() and before the backward pass. It is up to the user to ensure this; otherwise, it is possible to run into illegal memory access errors.

Think of residual connections as an example. In x + f(x), assuming f is an invertible checkpoint, x will be freed from memory before the sum is computed. We could use x.clone() + f(x) instead (not f(x) + x.clone()!), but then we keep a copy of x in memory. It is recommended to encapsulate the residual inside f itself or to use a simple checkpoint instead. Other alternatives exist, and you should study your case carefully before deciding to use this. For instance, check out torch.autograd.graph.saved_tensors_hooks() and torch.autograd.graph.save_on_cpu().
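
To spell out the residual example from the text (f stands for any invertible checkpoint):

y = x + f(x)          # unsafe: f(x) frees x before the sum is computed
y = x.clone() + f(x)  # safe: x is copied first, but the copy uses memory
y = f(x) + x.clone()  # still unsafe: f(x) frees x before clone() runs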
