Patch Conv: Patch convolution to avoid large GPU memory usage of Conv2D
Background
In current generative models, we usually apply convolutions over large activations to generate high-resolution content. However, PyTorch tends to use excessive memory for these operations, potentially leading to memory shortages even on 80GB A100 GPUs.
As shown in the figure below, the memory demand of a standard PyTorch convolution increases drastically once the input size (channel×height×width) reaches 1B elements. Notably, with a 7×7 kernel, even an 80GB A100 GPU triggers Out-of-Memory (OOM) errors. Inputs exceeding 2B elements can cause even 3×3 convolutions to exhaust all the memory, and that is just a single layer! This memory bottleneck prevents users and the community from scaling up models to produce high-quality images.
To bypass this issue and reduce memory consumption, we propose a simple and effective solution -- Patch Conv. As shown in the above figure, similar to SIGE, Patch Conv first divides the input into several smaller patches along the height dimension, keeping some overlap between them. These patches are reorganized into the batch dimension and fed through the original convolution to produce output patches, which are finally concatenated to form the final output. Patch Conv reduces memory usage by over 2.4×, providing a viable workaround for the limitations of current implementations.
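The splitting scheme described above can be sketched in plain PyTorch. This is a simplified illustration of the idea, not the library's actual implementation: it assumes stride 1, an odd square kernel, "same" padding, and a height divisible by `splits`.

```python
import torch
import torch.nn.functional as F

def patch_conv2d(x, weight, splits=4):
    """Illustrative patch-wise conv2d: split along height with overlap,
    convolve each patch, then concatenate the trimmed outputs."""
    kh = weight.shape[2]          # assumes an odd, square kernel
    pad = kh // 2                 # "same" padding
    n, c, h, w = x.shape
    step = h // splits            # assumes h is divisible by splits
    outs = []
    for i in range(splits):
        # Extend each patch by `pad` rows on both sides so the kernel's
        # receptive field at patch borders sees real pixels, not zeros.
        top = max(i * step - pad, 0)
        bottom = min((i + 1) * step + pad, h)
        patch = x[:, :, top:bottom, :]
        out = F.conv2d(patch, weight, padding=pad)
        # Trim the rows that belong to the overlap region.
        trim_top = i * step - top
        outs.append(out[:, :, trim_top:trim_top + step, :])
    return torch.cat(outs, dim=2)
```

For inputs meeting these assumptions, the result matches a single `F.conv2d` call on the full tensor, while each individual convolution only ever sees one patch.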
Installation
After installing PyTorch, you can install Patch Conv from PyPI:

```shell
pip install patch_conv
```

or via GitHub:

```shell
pip install git+https://github.com/mit-han-lab/patch_conv.git
```

or locally for development:

```shell
git clone git@github.com:mit-han-lab/patch_conv.git
cd patch_conv
pip install -e .
```
Usage
All you need to do is use `convert_model` to wrap all the `Conv2d` layers in your PyTorch model with our `PatchConv`. For example:

```python
import torch

from patch_conv import convert_model

model = Model(...)  # Your PyTorch model
model = convert_model(model, splits=4)  # The only modification you need to make

with torch.no_grad():
    model(...)  # Run the model in the original way
```
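Conceptually, `convert_model` needs to visit every `Conv2d` in the module tree and swap it for the patched version. A minimal sketch of that traversal is below; `PatchedConv2d` is a hypothetical stand-in, since the real `PatchConv` wrapper additionally performs the split/merge logic in its forward pass.

```python
import torch
import torch.nn as nn

class PatchedConv2d(nn.Module):
    """Hypothetical stand-in for the library's PatchConv wrapper."""
    def __init__(self, conv, splits=4):
        super().__init__()
        self.conv = conv
        self.splits = splits

    def forward(self, x):
        # The real wrapper would split x along height, convolve the
        # patches in the batch dimension, and concatenate the results.
        return self.conv(x)

def convert_model_sketch(module, splits=4):
    """Recursively replace every nn.Conv2d child with the wrapper."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, PatchedConv2d(child, splits))
        else:
            convert_model_sketch(child, splits)
    return module
```

Because the wrapper is itself an `nn.Module` holding the original `Conv2d`, the converted model keeps the same parameters and can be run exactly as before.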
Performance
Patch Conv significantly reduces memory consumption, by over 2.4× across various kernel sizes and input resolutions, at the cost of marginally slower inference compared to vanilla convolution.
Related Projects
- MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning, Lin et al., NeurIPS 2021
- Efficient Spatially Sparse Inference for Conditional GANs and Diffusion Models, Li et al., NeurIPS 2022
- DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models, Li et al., CVPR 2024