
Torch Pconv

Faster and more memory efficient implementation of the Partial Convolution 2D layer in PyTorch equivalent to the standard Nvidia implementation.

This implementation has numerous advantages:

  1. It is strictly equivalent in computation to the reference implementation by Nvidia. I wrote unit tests to verify this throughout development.
  2. It is commented and more readable.
  3. It is faster and more memory efficient, which means you can fit more layers on smaller GPUs. That's a good thing considering today's GPU prices.
  4. It's a PyPI-published library. You can pip install it instead of copy/pasting source code, and you get (free) bugfixes when someone notices a bug in the implementation.
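For reference, the operation computed by a partial convolution can be sketched naively as follows. This is a single-channel NumPy illustration of the Liu et al. formulation (convolve only over valid pixels, renormalize by the valid-pixel count, shrink the mask), not the library's actual vectorized implementation:

```python
import numpy as np

def partial_conv2d(x, mask, w, b=0.0):
    """Naive single-channel partial convolution (stride 1, no padding).

    For each window: convolve over valid (mask == 1) pixels only,
    renormalize by window_size / valid_count, and mark the output
    position valid if the window saw at least one valid pixel.
    """
    kh, kw = w.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    new_mask = np.zeros((oh, ow))
    window_size = kh * kw
    for i in range(oh):
        for j in range(ow):
            xw = x[i:i + kh, j:j + kw]
            mw = mask[i:i + kh, j:j + kw]
            valid = mw.sum()
            if valid > 0:
                # renormalize so the output magnitude does not depend
                # on how many pixels in the window were masked out
                out[i, j] = (w * xw * mw).sum() * (window_size / valid) + b
                new_mask[i, j] = 1.0
    return out, new_mask

x = np.arange(16, dtype=float).reshape(4, 4)
mask = np.ones((4, 4))
mask[:2, :2] = 0.0                 # a "hole" in the top-left corner
w = np.ones((3, 3)) / 9.0          # averaging kernel
out, shrunk = partial_conv2d(x, mask, w)
# with this averaging kernel, out[1, 1] is the mean of the 8 valid
# pixels in its window; every window here contains some valid pixel,
# so the shrunk mask is all ones
```

The mask "shrinking" is what makes the layer useful for inpainting: as you stack layers, the hole in the mask progressively closes.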


Getting started

pip3 install torch_pconv

Usage

import torch
from torch_pconv import PConv2d

images = torch.rand(32, 3, 256, 256)
masks = (torch.rand(32, 256, 256) > 0.5).to(torch.float32)

pconv = PConv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=7,
    stride=1,
    padding=2,
    dilation=2,
    bias=True
)

output, shrunk_masks = pconv(images, masks)
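The output spatial size follows standard convolution arithmetic (assuming PConv2d matches `nn.Conv2d` shape semantics here, which the equivalence claim implies). A quick sketch of the formula applied to the parameters above:

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard convolution output-size formula (same as nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# kernel_size=7, padding=2, dilation=2 on a 256-pixel side:
side = conv_out_size(256, kernel=7, stride=1, padding=2, dilation=2)  # 248
```

So `output` should be of shape (32, 64, 248, 248), and `shrunk_masks` of shape (32, 248, 248).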

Performance improvement

Test

You can find the reference implementation by Nvidia here.

I tested their implementation against mine on the following configuration:

| Parameter          | Value |
|--------------------|-------|
| in_channels        | 64    |
| out_channels       | 128   |
| kernel_size        | 9     |
| stride             | 1     |
| padding            | 3     |
| bias               | True  |
| input height/width | 256   |

The goal here was to build the most computationally expensive partial convolution operator possible, so that the performance difference stands out clearly.

I measured both the forward and the backward pass, in case one consumes more memory than the other.
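Such peak-memory numbers can be collected with PyTorch's built-in CUDA allocator counters. A minimal sketch of one way to do it (an assumption about the methodology, not the benchmark script actually used):

```python
import torch

def peak_cuda_bytes(fn):
    """Run fn() and return the peak CUDA memory it allocated, in bytes.

    Returns None when no GPU is available: this counter only tracks
    CUDA allocations, so the measurement is meaningless on CPU.
    """
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()
```

To measure "Forward + Backward", `fn` would run the layer and then call `.backward()` on a scalar reduction of its output.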

Results

Total memory cost (in bytes)

|                    | torch_pconv   | Nvidia® (Guilin) |
|--------------------|---------------|------------------|
| Forward only       | 813 466 624   | 4 228 120 576    |
| Backward only      | 1 588 201 480 | 1 588 201 480    |
| Forward + Backward | 2 405 797 640 | 6 084 757 512    |

Development

To install the latest version from GitHub, run:

git clone git@github.com:DesignStripe/torch_pconv.git torch_pconv
cd torch_pconv
pip3 install -U .
