
Deformable Convolution 2D for PyTorch on Apple Silicon (MPS)


MPS Deformable Convolution

Deformable Convolution 2D for PyTorch on Apple Silicon (M1/M2/M3/M4).

Drop-in replacement for torchvision.ops.deform_conv2d that actually works on MPS.

Why?

Deformable convolutions are widely used in vision models:

  • Detection: DCN-backbone detectors in mmdetection (Deformable DETR uses the closely related deformable attention)
  • Video: BasicVSR++, EDVR, optical flow models
  • Segmentation: Mask R-CNN with DCN backbones

But torchvision only ships CPU and CUDA kernels for deform_conv2d, with no MPS kernel. On an Apple Silicon GPU you get an error like:

NotImplementedError: deform_conv2d not implemented for MPS

This package provides a native Metal implementation.

Installation

pip install mps-deform-conv

Or from source:

git clone https://github.com/mpsops/mps-deform-conv
cd mps-deform-conv
pip install -e .

Quick Start

Basic Usage

import torch
from mps_deform_conv import deform_conv2d

# Input: (batch, channels, height, width)
input = torch.randn(1, 64, 32, 32, device='mps')

# Weight: (out_channels, in_channels, kernel_h, kernel_w)
weight = torch.randn(64, 64, 3, 3, device='mps')

# Offset: (batch, 2 * kernel_h * kernel_w, out_h, out_w)
# 2 values (dy, dx) for each position in the 3x3 kernel
offset = torch.randn(1, 2*9, 32, 32, device='mps')

# Run deformable convolution
output = deform_conv2d(input, offset, weight, padding=(1, 1))
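The offset tensor's spatial size has to match the convolution output, which is 32x32 here only because a 3x3 kernel with stride 1 and padding=(1, 1) preserves the input size. A minimal sketch of that bookkeeping follows. `conv_out_size` and `deform_conv2d_shapes` are hypothetical helpers, not part of mps_deform_conv, and the channel counts assume the torchvision-style layout from the API reference (2 offset channels and 1 mask channel per kernel tap and offset group).

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard convolution output size along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def deform_conv2d_shapes(n, h, w, kernel=(3, 3), stride=(1, 1),
                         padding=(0, 0), dilation=(1, 1), offset_groups=1):
    """Return (offset_shape, mask_shape) for an input of size (n, C, h, w).

    Hypothetical helper. Assumes the torchvision-style layout:
    2 offset channels (dy, dx) and 1 mask channel per kernel tap and group.
    """
    kh, kw = kernel
    out_h = conv_out_size(h, kh, stride[0], padding[0], dilation[0])
    out_w = conv_out_size(w, kw, stride[1], padding[1], dilation[1])
    taps = kh * kw * offset_groups
    return (n, 2 * taps, out_h, out_w), (n, taps, out_h, out_w)


# The Basic Usage example: 3x3 kernel, stride 1, padding 1 on a 32x32 input.
offset_shape, mask_shape = deform_conv2d_shapes(1, 32, 32, padding=(1, 1))
print(offset_shape, mask_shape)
```

For example, switching to stride=(2, 2) halves the required offset size to 16x16.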

DeformConv2d Module

from mps_deform_conv import DeformConv2d

# Create layer
conv = DeformConv2d(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    padding=1
).to('mps')

# Forward pass requires input and offset
x = torch.randn(1, 64, 32, 32, device='mps')
offset = torch.randn(1, 2*9, 32, 32, device='mps')
output = conv(x, offset)
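Both the functional and module forms expect two offset channels per kernel tap. Assuming the package follows torchvision's channel layout (consistent with the "(dy, dx)" comment in Basic Usage), the sketch below shows which offset channels move a given tap. `offset_channel_index` is a hypothetical helper for illustration only.

```python
def offset_channel_index(group, i, j, kernel_h, kernel_w):
    """Return (dy_channel, dx_channel) for kernel tap (i, j) in offset group `group`.

    Assumed layout (torchvision convention, not confirmed for this package):
    each offset group owns a block of 2 * kernel_h * kernel_w channels, and
    inside it tap k = i * kernel_w + j uses channel 2k for dy and 2k + 1 for dx.
    """
    if not (0 <= i < kernel_h and 0 <= j < kernel_w):
        raise ValueError("tap (i, j) lies outside the kernel")
    k = i * kernel_w + j
    base = group * 2 * kernel_h * kernel_w
    return base + 2 * k, base + 2 * k + 1


# Centre tap of a 3x3 kernel in the first (and only) offset group.
print(offset_channel_index(0, 1, 1, 3, 3))
```

Under this assumption, for a 3x3 kernel the centre tap is driven by channels 8 (dy) and 9 (dx) of the 18-channel offset tensor.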

ModulatedDeformConv2d (DCNv2)

Includes its own offset and mask predictor, so it can replace a regular conv layer directly:

from mps_deform_conv import ModulatedDeformConv2d

# DCNv2 with internal offset/mask prediction
conv = ModulatedDeformConv2d(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    padding=1
).to('mps')

# Just pass input - offsets learned internally
x = torch.randn(1, 64, 32, 32, device='mps')
output = conv(x)
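ModulatedDeformConv2d follows DCNv2, which adds a learned modulation scalar per kernel tap on top of the offsets. In the DCNv2 formulation that scalar is a sigmoid of the predictor output, so each tap's contribution is scaled into (0, 1). The plain-Python sketch below shows that per-location arithmetic; whether this package's internal predictor uses exactly this parameterization is an assumption, and `modulated_response` is a hypothetical name.

```python
import math


def sigmoid(z):
    """Logistic function, mapping any real logit into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def modulated_response(samples, weights, mask_logits):
    """One DCNv2 output value: sum over taps of w_k * x(p0 + p_k + dp_k) * m_k.

    samples:     values already bilinearly sampled at the offset positions, one per tap
    weights:     kernel weights w_k, one per tap
    mask_logits: raw predictor outputs; m_k = sigmoid(logit) scales each tap into (0, 1)
    """
    if not (len(samples) == len(weights) == len(mask_logits)):
        raise ValueError("expected one sample, weight and mask logit per kernel tap")
    return sum(w * x * sigmoid(m) for x, w, m in zip(samples, weights, mask_logits))


# Two taps, neutral logits: each contribution is halved.
print(modulated_response([1.0, 2.0], [1.0, 1.0], [0.0, 0.0]))
```

With every mask logit at 0 each tap is halved; large positive logits recover plain (DCNv1-style) deformable convolution, and large negative ones switch a tap off.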

API Reference

deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), mask=None)

Functional interface matching torchvision.ops.deform_conv2d.

| Parameter | Type | Description |
|---|---|---|
| input | Tensor | Input tensor (N, C_in, H, W) |
| offset | Tensor | Offset tensor (N, 2*K*K*groups, H_out, W_out) |
| weight | Tensor | Weight tensor (C_out, C_in/groups, K, K) |
| bias | Tensor | Optional bias (C_out,) |
| stride | tuple | Convolution stride (default: (1, 1)) |
| padding | tuple | Input padding (default: (0, 0)) |
| dilation | tuple | Kernel dilation (default: (1, 1)) |
| mask | Tensor | Optional DCNv2 mask (N, K*K*groups, H_out, W_out) |
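The table uses "groups" in both the weight and the offset/mask shapes. torchvision treats these as two separate quantities: weight groups inferred from C_in / weight.shape[1], and offset groups inferred from the offset channel count. The sketch below checks the table's shape relations that way; `check_deform_args` is hypothetical, and whether mps_deform_conv lets the two group counts differ is an assumption.

```python
def check_deform_args(input_shape, offset_shape, weight_shape, mask_shape=None):
    """Validate the shape relations from the API table.

    Returns (groups, offset_groups). Hypothetical helper: groups come from
    C_in / weight.shape[1] and offset groups from the offset channel count,
    as in torchvision (assumed, not confirmed, for mps_deform_conv).
    """
    n, c_in, _, _ = input_shape
    c_out, c_per_group, kh, kw = weight_shape
    if c_in % c_per_group != 0:
        raise ValueError("C_in must be divisible by weight.shape[1]")
    groups = c_in // c_per_group
    if c_out % groups != 0:
        raise ValueError("C_out must be divisible by groups")

    taps = kh * kw
    if offset_shape[0] != n or offset_shape[1] == 0 or offset_shape[1] % (2 * taps) != 0:
        raise ValueError("offset must be (N, 2*K_h*K_w*offset_groups, H_out, W_out)")
    offset_groups = offset_shape[1] // (2 * taps)
    if c_in % offset_groups != 0:
        raise ValueError("C_in must be divisible by offset_groups")

    if mask_shape is not None:
        # One modulation scalar per tap and offset group, same spatial size as offset.
        expected = (n, taps * offset_groups, offset_shape[2], offset_shape[3])
        if tuple(mask_shape) != expected:
            raise ValueError(f"mask must have shape {expected}, got {tuple(mask_shape)}")
    return groups, offset_groups


# Shapes from the Basic Usage example: no grouping, no mask.
print(check_deform_args((1, 64, 32, 32), (1, 18, 32, 32), (64, 64, 3, 3)))
```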

DeformConv2d

Module wrapping deform_conv2d. Takes input and offset in forward.

ModulatedDeformConv2d

Self-contained DCNv2 module with internal offset prediction. Takes only input in forward.

How It Works

Standard convolution samples on a fixed grid:

[•] [•] [•]
[•] [x] [•]
[•] [•] [•]

Deformable convolution learns offsets to sample from arbitrary positions:

    [•]
[•]     [•]
    [x]     [•]
[•]     [•]
    [•]

This lets the network adapt its receptive field to the input content, which helps when detecting objects at different scales or handling geometric transformations such as rotation and non-rigid deformation.
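The sampling rule behind the diagram is y(p0) = Σ_k w_k · x(p0 + p_k + Δp_k), where p_k are the fixed grid positions and Δp_k the learned offsets. Because p0 + p_k + Δp_k is usually fractional, x is read with bilinear interpolation. Below is a slow, pure-Python reference for a single channel with stride 1; the function names are illustrative and not part of mps_deform_conv, and treating out-of-image pixels as zero is assumed to match torchvision's behavior. With all offsets zero it reduces to an ordinary convolution.

```python
import math


def bilinear(img, y, x):
    """Sample a 2D list at fractional (y, x); pixels outside the image count as 0."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    ly, lx = y - y0, x - x0
    total = 0.0
    # Blend the four neighbouring pixels, skipping any that fall outside the image.
    for yy, wy in ((y0, 1 - ly), (y0 + 1, ly)):
        for xx, wx in ((x0, 1 - lx), (x0 + 1, lx)):
            if 0 <= yy < h and 0 <= xx < w:
                total += wy * wx * img[yy][xx]
    return total


def deform_conv2d_single(img, weight, offsets, padding=1):
    """Single-channel, stride-1 deformable convolution (illustrative reference).

    weight:  K x K kernel as a list of lists
    offsets: offsets[oy][ox][k] = (dy, dx) for tap k = i * K + j at output (oy, ox)
    Tap (i, j) of output (oy, ox) samples x at (oy - padding + i + dy, ox - padding + j + dx).
    """
    k = len(weight)
    h, w = len(img), len(img[0])
    out_h, out_w = h + 2 * padding - k + 1, w + 2 * padding - k + 1
    out = [[0.0] * out_w for _ in range(out_h)]
    for oy in range(out_h):
        for ox in range(out_w):
            acc = 0.0
            for i in range(k):
                for j in range(k):
                    dy, dx = offsets[oy][ox][i * k + j]
                    acc += weight[i][j] * bilinear(img, oy - padding + i + dy, ox - padding + j + dx)
            out[oy][ox] = acc
    return out


img = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
identity = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
# Every tap shifted half a pixel to the right.
half_right = [[[(0.0, 0.5)] * 9 for _ in range(3)] for _ in range(3)]
print(deform_conv2d_single(img, identity, half_right)[0])
```

The half-pixel shift makes each output a blend of two horizontal neighbours, which is exactly the kind of fractional sampling a fixed-grid convolution cannot express.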

Compatibility

  • PyTorch: 2.0+
  • macOS: 12.0+ (Monterey)
  • Hardware: Apple Silicon (M1/M2/M3/M4)

Features

  • Full forward and backward pass (training supported)
  • Gradients verified against torchvision (< 0.00001 error)
  • fp32 and fp16 supported
  • Grouped convolutions supported

Credits

License

MIT


Download files


Source Distribution

mps_deform_conv-0.1.1.tar.gz (17.0 kB)

Uploaded Source

Built Distribution


mps_deform_conv-0.1.1-cp314-cp314-macosx_15_0_arm64.whl (99.9 kB)

Uploaded CPython 3.14, macOS 15.0+, ARM64

File details

Details for the file mps_deform_conv-0.1.1.tar.gz.

File metadata

  • Download URL: mps_deform_conv-0.1.1.tar.gz
  • Size: 17.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for mps_deform_conv-0.1.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3f7ae1f36fcce990a7d73a748e707209af662c2ea3567f80e57d646f151c9ef9 |
| MD5 | 7538b0b5f9a88e25a1a0b214a27aed53 |
| BLAKE2b-256 | 7443f211b4bc52b6f6ef1f77f59c26399c6572b343f4358bafd51957b31d8592 |


File details

Details for the file mps_deform_conv-0.1.1-cp314-cp314-macosx_15_0_arm64.whl.

File hashes

Hashes for mps_deform_conv-0.1.1-cp314-cp314-macosx_15_0_arm64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 313bfefb7e0a8ba96391a0b29b5a39169c957a631ea2fc570b8df434515e114e |
| MD5 | 87d205a36adecb1205a86c6e4df2ab1e |
| BLAKE2b-256 | 20e3c9768c0fa670986c5f80c2ec10ee2dfe6ae4d242e25c3d645c98275803cc |

