Torchvision+ Deformable Convolutional Networks

This package contains PyTorch implementations of the Deformable Convolution operation (the commonly used torchvision.ops.deform_conv2d) proposed in https://arxiv.org/abs/1811.11168, and of the Transposed Deformable Convolution proposed in https://arxiv.org/abs/2210.09446 (currently without interpolation kernel scaling). It also supports their 1D and 3D equivalents, which are not available in torchvision (hence the name).
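To make the operation concrete, here is a minimal pure-Python sketch of a single-batch, single-channel modulated deformable 1D convolution. It is not tvdcn's C++/Cuda implementation: it assumes linear interpolation at fractional sampling positions and treats samples outside the input as zero (similar in spirit to torchvision's bilinear sampling), and `linear_sample` and `naive_deform_conv1d` are hypothetical names used only for illustration.

```python
import math
from typing import List, Optional, Sequence


def linear_sample(x: Sequence[float], pos: float) -> float:
    """Linearly interpolate x at a fractional position.

    Neighbours that fall outside the signal are treated as zero (assumption).
    """
    lo = math.floor(pos)
    frac = pos - lo

    def at(i: int) -> float:
        return float(x[i]) if 0 <= i < len(x) else 0.0

    return (1.0 - frac) * at(lo) + frac * at(lo + 1)


def naive_deform_conv1d(x: Sequence[float],
                        weight: Sequence[float],
                        offset: Sequence[Sequence[float]],
                        mask: Optional[Sequence[Sequence[float]]] = None,
                        stride: int = 1,
                        padding: int = 0,
                        dilation: int = 1) -> List[float]:
    """Single-batch, single-channel (modulated) deformable 1D convolution.

    offset[o][k]: learned shift of kernel tap k at output position o.
    mask[o][k]:   modulation scalar for that tap (None means all ones, i.e. v1).
    """
    k_size = len(weight)
    out_len = (len(x) + 2 * padding - dilation * (k_size - 1) - 1) // stride + 1
    out: List[float] = []
    for o in range(out_len):
        acc = 0.0
        for k in range(k_size):
            # Regular grid position of tap k, displaced by the learned offset.
            pos = o * stride - padding + k * dilation + offset[o][k]
            m = 1.0 if mask is None else mask[o][k]
            acc += weight[k] * m * linear_sample(x, pos)
        out.append(acc)
    return out
```

With zero offsets and no mask this reduces to an ordinary convolution: `naive_deform_conv1d([1, 2, 3, 4, 5], [1, 0, -1], [[0.0] * 3] * 3)` returns `[-2.0, -2.0, -2.0]`. Shifting every tap by 0.5 makes each tap sample halfway between two neighbouring input values instead.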

Highlights

  • Supported operators (all implemented in C++/Cuda):

    • tvdcn.ops.deform_conv1d
    • tvdcn.ops.deform_conv2d (at least 10% faster than torchvision.ops.deform_conv2d in the forward pass on our Quadro RTX 5000, according to our benchmark)
    • tvdcn.ops.deform_conv3d
    • tvdcn.ops.deform_conv_transpose1d
    • tvdcn.ops.deform_conv_transpose2d
    • tvdcn.ops.deform_conv_transpose3d
  • Supplementary operators for the mask activation proposed in https://arxiv.org/abs/2211.05778:

    • tvdcn.ops.mask_softmax1d
    • tvdcn.ops.mask_softmax2d
    • tvdcn.ops.mask_softmax3d
  • Both offset and mask can be turned off, and can be applied in separate groups.

  • nn.Module wrappers are implemented for all these operators, and everything is @torch.jit.script-able! See Usage.

Note: ONNX export is not a priority for us, but if you need it, check this repo: https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter.

Requirements

  • torch>=2.1.0 (torch>=1.9.0 if installed from source)

Installation

From PyPI:

tvdcn provides some prebuilt wheels on PyPI. Run this command to install:

pip install tvdcn

Since PyTorch is migrating to Cuda 12, our Linux and Windows wheels are built with Cuda 12.1 and are not compatible with PyTorch builds that use older Cuda versions.

                  Linux/Windows                          MacOS
Python version:   3.8-3.11                               3.8-3.11
PyTorch version:  torch==2.1.0                           torch==2.1.0
Cuda version:     12.1                                   -
GPU CCs:          5.0,6.0,6.1,7.0,7.5,8.0,8.6,9.0+PTX    -

When the Cuda versions of torch and tvdcn do not match, you will see an error like this:

RuntimeError: Detected that PyTorch and Extension were compiled with different CUDA versions.
PyTorch has CUDA Version=11.8 and Extension has CUDA Version=12.1.
Please reinstall the Extension that matches your PyTorch install.

If you see the following error instead, there are other issues related to Python, PyTorch, device architecture, etc. In that case, follow the instructions to build from source; all the steps are easy.

RuntimeError: Couldn't load custom C++ ops. Recompile C++ extension with:
     python setup.py build_ext --inplace

From Source:

For installing from source, you need a C++ compiler (gcc/msvc) and a Cuda compiler (nvcc) with C++17 features enabled. Clone this repo and execute the following command:

pip install .

Or just compile the binary for inplace usage:

python setup.py build_ext --inplace

A binary (.so file on Unix, .pyd file on Windows) should be compiled inside the tvdcn folder. To check whether the installation was successful, run:

import tvdcn

print('Library loaded successfully:', tvdcn.has_ops())
print('Compiled with Cuda:', tvdcn.with_cuda())

Note: We use soft Cuda version compatibility checking between the built binary and the installed PyTorch, which means only the major versions need to match. However, we suggest building the binaries with the same Cuda version as the installed PyTorch's to prevent any possible conflicts.
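The soft check rule can be illustrated as follows. This is a sketch of the rule only, not tvdcn's code, and `cuda_major` and `soft_cuda_compatible` are hypothetical names. In a real environment the PyTorch side would come from `torch.version.cuda`, which is `None` for CPU-only builds.

```python
from typing import Optional


def cuda_major(version: Optional[str]) -> Optional[int]:
    """Major part of a Cuda version string such as '12.1'; None for CPU-only builds."""
    if version is None:
        return None
    return int(version.split(".")[0])


def soft_cuda_compatible(torch_cuda: Optional[str], ext_cuda: Optional[str]) -> bool:
    """Soft check: only the major Cuda versions have to agree."""
    torch_major, ext_major = cuda_major(torch_cuda), cuda_major(ext_cuda)
    # Nothing to compare if either side was built without Cuda.
    if torch_major is None or ext_major is None:
        return False
    return torch_major == ext_major
```

Under this rule, a wheel built with Cuda 12.1 is accepted by a PyTorch built with Cuda 12.4, while the 11.8 vs 12.1 pair from the error message above is rejected.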

Usage

Operators:

Functionally, the package offers 6 deformable convolution functions (listed in Highlights) that are very similar to torchvision.ops.deform_conv2d. However, the order of parameters is slightly different, so be careful (compare the two snippets below).

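torchvision:

import torch, torchvision

input = torch.rand(4, 3, 10, 10)           # (batch, in_channels, H, W)
kh, kw = 3, 3
weight = torch.rand(5, 3, kh, kw)          # (out_channels, in_channels, kh, kw)
offset = torch.rand(4, 2 * kh * kw, 8, 8)  # 2 displacements per kernel point
mask = torch.rand(4, kh * kw, 8, 8)        # 1 modulation scalar per kernel point
bias = torch.rand(5)

# torchvision order: input, offset, weight, bias, ..., mask
output = torchvision.ops.deform_conv2d(input, offset, weight, bias,
                                       stride=(1, 1),
                                       padding=(0, 0),
                                       dilation=(1, 1),
                                       mask=mask)
print(output)

tvdcn:

import torch, tvdcn

input = torch.rand(4, 3, 10, 10)           # (batch, in_channels, H, W)
kh, kw = 3, 3
weight = torch.rand(5, 3, kh, kw)          # (out_channels, in_channels, kh, kw)
offset = torch.rand(4, 2 * kh * kw, 8, 8)  # 2 displacements per kernel point
mask = torch.rand(4, kh * kw, 8, 8)        # 1 modulation scalar per kernel point
bias = torch.rand(5)

# tvdcn order: input, weight, offset, mask, bias, ...
output = tvdcn.ops.deform_conv2d(input, weight, offset, mask, bias,
                                 stride=(1, 1),
                                 padding=(0, 0),
                                 dilation=(1, 1),
                                 groups=1)
print(output)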
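When porting code written against torchvision, a thin adapter that keeps torchvision's argument order can prevent silently swapping offset and weight. This is a minimal sketch: `deform_conv2d_tv_order` is a hypothetical helper, and the op is passed in as a parameter so the reordering can be checked without a compiled build (in practice you would pass `tvdcn.ops.deform_conv2d`).

```python
from typing import Any, Callable, Optional


def deform_conv2d_tv_order(op: Callable[..., Any], input: Any, offset: Any, weight: Any,
                           bias: Optional[Any] = None, stride=(1, 1), padding=(0, 0),
                           dilation=(1, 1), mask: Optional[Any] = None) -> Any:
    """Take arguments in torchvision.ops.deform_conv2d's order and forward them
    to an op with tvdcn's order: (input, weight, offset, mask, bias, ...)."""
    return op(input, weight, offset, mask, bias,
              stride=stride, padding=padding, dilation=dilation)
```

For example, `deform_conv2d_tv_order(tvdcn.ops.deform_conv2d, input, offset, weight, bias, mask=mask)` should match the tvdcn call in the comparison above.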

Specifically, the signatures of deform_conv2d and deform_conv_transpose2d look like this:

def deform_conv2d(
        input: Tensor,
        weight: Tensor,
        offset: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1) -> Tensor:
    ...


def deform_conv_transpose2d(
        input: Tensor,
        weight: Tensor,
        offset: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        output_padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1) -> Tensor:
    ...

If offset=None and mask=None, the executed operators are identical to conventional convolution.
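Because the operators reduce to conventional (transposed) convolution in that case, it is reasonable to assume their output sizes follow the same per-dimension formulas as torch.nn.Conv2d and torch.nn.ConvTranspose2d. The helpers below are a sketch under that assumption, and their names are hypothetical. For deform_conv2d, the offset and mask spatial dimensions must equal these output dimensions, as in the 10 -> 8 example above.

```python
def conv_out_size(n: int, kernel: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length of a regular convolution along one spatial dimension."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out_size(n: int, kernel: int, stride: int = 1, padding: int = 0,
                            output_padding: int = 0, dilation: int = 1) -> int:
    """Output length of a transposed convolution along one spatial dimension."""
    return (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
```

For instance, `conv_out_size(10, 3)` gives 8 and `conv_out_size(64, 3)` gives 62, matching the offset and mask sizes used in the examples, while output_padding only enters the transposed formula.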

Neural Network Layers:

The nn.Module wrappers are:

  • tvdcn.ops.DeformConv1d
  • tvdcn.ops.DeformConv2d
  • tvdcn.ops.DeformConv3d
  • tvdcn.ops.DeformConvTranspose1d
  • tvdcn.ops.DeformConvTranspose2d
  • tvdcn.ops.DeformConvTranspose3d

They are subclasses of torch.nn.modules.conv._ConvNd, but their forward method takes offset and, optionally, mask as extra inputs. For example:

import torch

from tvdcn import DeformConv2d

input = torch.rand(2, 3, 64, 64)
offset = torch.rand(2, 2 * 3 * 3, 62, 62)
# if mask is None, the original deformable convolution without modulation (v1) is performed
mask = torch.rand(2, 1 * 3 * 3, 62, 62)

conv = DeformConv2d(3, 16, kernel_size=(3, 3))

output = conv(input, offset, mask)
print(output.shape)

Additionally, following many other implementations, we also provide packed wrappers:

  • tvdcn.ops.PackedDeformConv1d
  • tvdcn.ops.PackedDeformConv2d
  • tvdcn.ops.PackedDeformConv3d
  • tvdcn.ops.PackedDeformConvTranspose1d
  • tvdcn.ops.PackedDeformConvTranspose2d
  • tvdcn.ops.PackedDeformConvTranspose3d

These are easy-to-use classes that contain ordinary convolution layers with appropriate hyperparameters to generate the offset (and the mask, if initialized with modulated=True), at the cost of less customization. The only tunable hyperparameters that affect these supplementary conv layers are offset_groups and mask_groups, which are decoupled from groups but behave somewhat similarly to it.
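The channel counts involved can be sketched as follows. The 2D case matches the examples above (2 * 3 * 3 offset channels and 1 * 3 * 3 mask channels for a 3x3 kernel); the generalisation to offset_groups, mask_groups and 1D/3D kernels is an assumption following torchvision's convention, and `offset_mask_shapes` is a hypothetical helper, not part of tvdcn.

```python
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def offset_mask_shapes(batch: int,
                       out_spatial: Sequence[int],
                       kernel_size: Sequence[int],
                       offset_groups: int = 1,
                       mask_groups: int = 1) -> Tuple[Shape, Shape]:
    """Expected (offset, mask) shapes for an N-d deformable convolution.

    Assumption: one displacement per spatial axis per kernel point per
    offset group, and one modulation scalar per kernel point per mask group.
    """
    ndim = len(kernel_size)     # 1, 2 or 3 spatial dimensions
    points = prod(kernel_size)  # number of sampling points in the kernel
    offset_shape = (batch, offset_groups * ndim * points, *out_spatial)
    mask_shape = (batch, mask_groups * points, *out_spatial)
    return offset_shape, mask_shape
```

For a packed module, the channel entries correspond to the number of output channels its internal offset and mask convolutions would need to produce.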

To use the softmax activation for mask proposed in Deformable Convolution v3, set mask_activation='softmax'. offset_activation and mask_activation also accept any nn.Module.
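Conceptually, the softmax mask activation normalises each group's mask values over the kernel's sampling points at every output location, so they sum to 1 (as in DCNv3), rather than gating each value independently with a sigmoid (as in DCNv2). Below is a pure-Python sketch for a single output location. It assumes the mask channels are laid out as mask_groups consecutive chunks of kernel points; that layout and the name `mask_softmax` are hypothetical, whereas tvdcn.ops.mask_softmax1d/2d/3d operate on whole tensors.

```python
import math
from typing import List, Sequence


def mask_softmax(mask_logits: Sequence[float], mask_groups: int) -> List[float]:
    """Softmax-normalise the mask values at ONE output location, per group.

    Assumed layout: mask_groups consecutive chunks, each holding one logit
    per kernel point (len(mask_logits) == mask_groups * kernel_points).
    """
    if mask_groups <= 0 or len(mask_logits) % mask_groups != 0:
        raise ValueError("mask length must be divisible by mask_groups")
    k = len(mask_logits) // mask_groups
    out: List[float] = []
    for g in range(mask_groups):
        chunk = mask_logits[g * k:(g + 1) * k]
        m = max(chunk)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in chunk]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out
```

With all-zero logits and a 3x3 kernel, every kernel point in each group receives a weight of 1/9.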

import torch

from tvdcn import PackedDeformConv1d

input = torch.rand(2, 3, 128)

conv = PackedDeformConv1d(3, 16,
                          kernel_size=5,
                          modulated=True,
                          mask_activation='softmax')
# jit scripting
scripted_conv = torch.jit.script(conv)
print(scripted_conv)

output = scripted_conv(input)
print(output.shape)

Note: For transposed packed modules, we generate offset and mask with pointwise convolutions, as we haven't found a better way to do it.

Check the examples folder; you may find something helpful there.

Acknowledgements

This for-fun project is directly modified and extended from torchvision.ops.deform_conv2d.

License

The code is released under the MIT license. See LICENSE.txt for details.

Download files

Source Distribution

  • tvdcn-0.5.0.tar.gz (80.8 kB)

Built Distributions

  • tvdcn-0.5.0-cp311-cp311-win_amd64.whl (8.8 MB): CPython 3.11, Windows x86-64
  • tvdcn-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (32.2 MB): CPython 3.11, manylinux glibc 2.17+ x86-64
  • tvdcn-0.5.0-cp311-cp311-macosx_10_9_x86_64.whl (479.5 kB): CPython 3.11, macOS 10.9+ x86-64
  • tvdcn-0.5.0-cp310-cp310-win_amd64.whl (8.8 MB): CPython 3.10, Windows x86-64
  • tvdcn-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (32.2 MB): CPython 3.10, manylinux glibc 2.17+ x86-64
  • tvdcn-0.5.0-cp310-cp310-macosx_10_9_x86_64.whl (479.5 kB): CPython 3.10, macOS 10.9+ x86-64
  • tvdcn-0.5.0-cp39-cp39-win_amd64.whl (8.8 MB): CPython 3.9, Windows x86-64
  • tvdcn-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (32.2 MB): CPython 3.9, manylinux glibc 2.17+ x86-64
  • tvdcn-0.5.0-cp39-cp39-macosx_10_9_x86_64.whl (479.5 kB): CPython 3.9, macOS 10.9+ x86-64
  • tvdcn-0.5.0-cp38-cp38-win_amd64.whl (8.8 MB): CPython 3.8, Windows x86-64
  • tvdcn-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (32.2 MB): CPython 3.8, manylinux glibc 2.17+ x86-64
  • tvdcn-0.5.0-cp38-cp38-macosx_10_9_x86_64.whl (479.5 kB): CPython 3.8, macOS 10.9+ x86-64
