
Patch convolution to avoid large GPU memory usage of Conv2D

Project description

Patch Conv: Patch convolution to avoid large GPU memory usage of Conv2D [Blog]

[Figure: Patch Conv overview, showing the memory usage of vanilla Conv2d and the patching scheme]

Background

In current generative models, we usually apply convolutions over large activations to generate high-resolution content. However, PyTorch tends to use excessive memory for these operations, potentially leading to out-of-memory errors even on 80GB A100 GPUs.

As shown in the figure above, the memory demand of standard PyTorch convolutions increases drastically once the input reaches 1B elements (channels × height × width). Notably, with a 7×7 kernel, an 80GB A100 GPU already triggers Out-of-Memory (OOM) errors. Inputs exceeding 2B elements can make even 3×3 convolutions exhaust all the memory, and that is just for a single layer! This memory bottleneck prevents users and the community from scaling up models to produce high-quality images.
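To get a feel for the scale, here is a back-of-envelope calculation of our own. It assumes an im2col-style lowering, which materializes each input element kernel_size² times; the algorithm cuDNN actually picks varies, so treat the numbers as illustrative only:

# Illustrative arithmetic, not a measurement
elems = 2_000_000_000  # 2B activation elements (channels * height * width)
bytes_per_elem = 2  # fp16
k = 7  # 7x7 kernel

input_gb = elems * bytes_per_elem / 2**30  # ~3.7 GB for the input alone
im2col_gb = input_gb * k * k  # ~183 GB if every element is replicated k*k times
print(f"input: {input_gb:.1f} GB, im2col buffer: {im2col_gb:.0f} GB")  # far beyond one 80GB A100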

To bypass this issue and reduce memory consumption, we propose a simple and effective solution: Patch Conv. As shown in the figure above, and similar to SIGE, Patch Conv first divides the input into several smaller patches along the height dimension while keeping some overlap between them. These patches are reorganized into the batch dimension and fed through the original convolution to produce output patches, which are concatenated to form the final output. Patch Conv reduces memory usage by over 2.4×, providing a viable workaround for the limitations of the current implementation.
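To make the scheme concrete, below is a minimal sketch of the patched convolution for the stride-1, "same"-padding case. The helper name patch_conv2d is ours, and it loops over the patches for clarity rather than folding them into the batch dimension as the library does:

import torch
import torch.nn.functional as F

def patch_conv2d(x, weight, bias=None, splits=4):
    # Sketch: run a stride-1, "same"-padding conv patch by patch along height.
    pad = weight.shape[-1] // 2  # padding of the equivalent full convolution
    H = x.shape[2]
    step = -(-H // splits)  # ceil(H / splits) rows per patch
    outs = []
    for start in range(0, H, step):
        end = min(start + step, H)
        # Take `pad` extra rows of overlap from the neighboring patches,
        # clamped at the image borders.
        patch = x[:, :, max(start - pad, 0):min(end + pad, H), :]
        # Zero-pad only the sides that touch a real border, so each patch
        # sees exactly the context the full convolution would.
        patch = F.pad(patch, (pad, pad, pad if start == 0 else 0, pad if end == H else 0))
        outs.append(F.conv2d(patch, weight, bias))  # no extra padding here
    return torch.cat(outs, dim=2)  # stitch output patches along height

# Sanity check against the full convolution:
x = torch.randn(1, 16, 64, 64)
w = torch.randn(32, 16, 3, 3)
b = torch.randn(32)
assert torch.allclose(F.conv2d(x, w, b, padding=1), patch_conv2d(x, w, b), atol=1e-5)

Each conv call now sees a much smaller input, which shrinks the intermediate buffers that dominate the memory cost.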

Installation

After installing PyTorch, you can install PatchConv from PyPI:

pip install patch_conv

or via GitHub:

pip install git+https://github.com/mit-han-lab/patch_conv.git

or locally for development:

git clone git@github.com:mit-han-lab/patch_conv.git
cd patch_conv
pip install -e .

Usage

All you need to do is call convert_model to wrap every Conv2d in your PyTorch model with our PatchConv. For example:

import torch

from patch_conv import convert_model

model = Model(...)  # Your PyTorch model
model = convert_model(model, splits=4)  # The only modification you need to make

with torch.no_grad():
    model(...)  # Run the model in the original way
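Under the hood, a conversion pass like this plausibly just walks the module tree and swaps every nn.Conv2d for a patched wrapper. The sketch below is our guess at the shape of that logic, not the library's actual code, and PatchConvWrapper is a hypothetical stand-in:

import torch.nn as nn

class PatchConvWrapper(nn.Module):
    # Hypothetical stand-in; a real wrapper would run the convolution
    # patch by patch, as in the sketch in the Background section.
    def __init__(self, conv: nn.Conv2d, splits: int = 4):
        super().__init__()
        self.conv, self.splits = conv, splits

    def forward(self, x):
        return self.conv(x)  # placeholder: the real version patches along height

def convert_conv2d(module: nn.Module, splits: int = 4) -> nn.Module:
    # Recursively replace every nn.Conv2d child with the wrapper.
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, PatchConvWrapper(child, splits))
        else:
            convert_conv2d(child, splits)
    return module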

Performance

[Figure: memory usage and inference latency of Patch Conv vs. vanilla Conv2d]

Patch Conv reduces memory consumption by over 2.4× across various kernel sizes and input resolutions, at the cost of marginally slower inference than vanilla convolution.
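To reproduce the comparison on your own GPU, a rough measurement along these lines should work. The layer and input sizes here are illustrative, and we assume convert_model recurses over submodules as the usage above suggests:

import torch
import torch.nn as nn
from patch_conv import convert_model

def peak_mem_gb(model, x):
    # Peak GPU memory for a single forward pass.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model(x)
    return torch.cuda.max_memory_allocated() / 2**30

model = nn.Sequential(nn.Conv2d(16, 16, 7, padding=3)).cuda().half()
x = torch.randn(1, 16, 4096, 4096, device="cuda", dtype=torch.half)

print(f"vanilla: {peak_mem_gb(model, x):.1f} GB")
print(f"patched: {peak_mem_gb(convert_model(model, splits=4), x):.1f} GB")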

Related Projects

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

patch_conv-0.0.1b0.tar.gz (4.7 kB)

Uploaded Source

Built Distribution

patch_conv-0.0.1b0-py3-none-any.whl (5.2 kB)

Uploaded Python 3

File details

Details for the file patch_conv-0.0.1b0.tar.gz.

File metadata

  • Download URL: patch_conv-0.0.1b0.tar.gz
  • Upload date:
  • Size: 4.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for patch_conv-0.0.1b0.tar.gz

Algorithm    Hash digest
SHA256       3c114d79a71b68e6d99254970d29c4fb7521d020483141dc932fc46b859a6b0c
MD5          c04a4f4f13f5183a7e6b83513a71b3ab
BLAKE2b-256  9bae555a6c2ff8c73f3208c335c46c081dbb71d4287e2b956a64d6f33f3fd10c


File details

Details for the file patch_conv-0.0.1b0-py3-none-any.whl.

File metadata

File hashes

Hashes for patch_conv-0.0.1b0-py3-none-any.whl

Algorithm    Hash digest
SHA256       202baeda14b33c027c46e75b89739c19625410bce987d1d5023d43601d161325
MD5          96c3c07283810a8530d90081ee8b8c02
BLAKE2b-256  3bf28440ad6cb266c40a679bc92fb9c9d9d06d9a851df18336dcbfe6b17e2b2e

