
3D Convolution for Apple Silicon (MPS)


MPS Conv3D

3D Convolution for Apple Silicon (M1/M2/M3/M4).

Drop-in replacement for torch.nn.functional.conv3d on MPS.

Why?

3D convolutions are a core building block in many video models:

  • Synchformer: Audio-visual synchronization
  • I3D: Video classification
  • SlowFast: Action recognition
  • C3D: Video feature extraction
  • MMAudio: Audio generation from video

But PyTorch's MPS backend doesn't support 3D convolutions:

NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS

This package provides a native Metal implementation.

Installation

pip install mps-conv3d

Or from source:

git clone https://github.com/mpsops/mps-conv3d
cd mps-conv3d
pip install -e .

Quick Start

Patch All Conv3D Operations (Recommended)

import torch
import torch.nn.functional as F

from mps_conv3d import patch_conv3d

# Patch once, at the start of your script, before any conv3d calls
patch_conv3d()

# From here on, F.conv3d on MPS tensors uses the Metal kernel
x = torch.randn(1, 3, 16, 112, 112, device='mps')
w = torch.randn(64, 3, 3, 7, 7, device='mps')
out = F.conv3d(x, w, padding=(1, 3, 3))

Direct Usage

import torch
from mps_conv3d import conv3d

x = torch.randn(1, 3, 16, 112, 112, device='mps')
w = torch.randn(64, 3, 3, 7, 7, device='mps')

out = conv3d(x, w, stride=1, padding=(1, 3, 3))

Conv3d Module

import torch
from mps_conv3d import Conv3d

conv = Conv3d(
    in_channels=3,
    out_channels=64,
    kernel_size=(3, 7, 7),
    stride=(1, 2, 2),
    padding=(1, 3, 3)
).to('mps')

x = torch.randn(1, 3, 16, 112, 112, device='mps')
out = conv(x)
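Assuming Conv3d computes output sizes the same way as torch.nn.Conv3d, the stride of (1, 2, 2) keeps the temporal depth and halves the spatial resolution. Per axis, with input size $S_{in}$, kernel size $k$, stride $s$, padding $p$ and dilation $d$ (all dilations are 1 here):

```latex
S_{out} = \left\lfloor \frac{S_{in} + 2p - d\,(k-1) - 1}{s} \right\rfloor + 1

\begin{aligned}
D_{out} &= \left\lfloor \frac{16 + 2\cdot 1 - 1\cdot(3-1) - 1}{1} \right\rfloor + 1 = 15 + 1 = 16 \\
H_{out} = W_{out} &= \left\lfloor \frac{112 + 2\cdot 3 - 1\cdot(7-1) - 1}{2} \right\rfloor + 1
                   = \left\lfloor \frac{111}{2} \right\rfloor + 1 = 56
\end{aligned}
```

So under that assumption `out` has shape (1, 64, 16, 56, 56).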

API Reference

conv3d(input, weight, bias, stride, padding, dilation, groups)

Same signature as torch.nn.functional.conv3d.

Parameter  Type       Description
input      Tensor     Input of shape (N, C_in, D, H, W)
weight     Tensor     Weights of shape (C_out, C_in/groups, kD, kH, kW)
bias       Tensor     Optional bias of shape (C_out,)
stride     int/tuple  Stride of the convolution along (D, H, W)
padding    int/tuple  Implicit zero padding added to both sides of D, H and W
dilation   int/tuple  Spacing between kernel elements
groups     int        Number of channel groups; C_in and C_out must both be divisible by it
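To make the parameter semantics concrete, here is a naive pure-Python sketch of what a conv3d call computes under PyTorch's conventions: cross-correlation (no kernel flip), implicit zero padding, dilation as spacing between kernel taps, and grouped channels. It is not this package's Metal kernel, and conv3d_reference is a hypothetical name; it is only meant as a slow ground truth for tiny hand-checked cases.

```python
from typing import Sequence, Tuple, Union

IntOr3 = Union[int, Sequence[int]]


def _triple(v: IntOr3) -> Tuple[int, ...]:
    """Expand an int to (v, v, v); check that a sequence has exactly 3 entries."""
    t = (v, v, v) if isinstance(v, int) else tuple(v)
    if len(t) != 3:
        raise ValueError(f"expected an int or 3 values, got {v!r}")
    return t


def conv3d_reference(x, w, bias=None, stride: IntOr3 = 1, padding: IntOr3 = 0,
                     dilation: IntOr3 = 1, groups: int = 1):
    """Hypothetical naive conv3d on nested lists.

    x: [N][C_in][D][H][W], w: [C_out][C_in // groups][kD][kH][kW], bias: [C_out] or None.
    """
    sd, sh, sw = _triple(stride)
    pd, ph, pw = _triple(padding)
    dd, dh, dw = _triple(dilation)
    n_batch, c_in = len(x), len(x[0])
    D, H, W = len(x[0][0]), len(x[0][0][0]), len(x[0][0][0][0])
    c_out, c_in_g = len(w), len(w[0])
    kd, kh, kw = len(w[0][0]), len(w[0][0][0]), len(w[0][0][0][0])
    if c_in % groups or c_out % groups or c_in_g != c_in // groups:
        raise ValueError("channel counts do not match groups")
    c_out_g = c_out // groups

    # Output extent per axis: floor((size + 2p - d(k - 1) - 1) / s) + 1
    od = (D + 2 * pd - dd * (kd - 1) - 1) // sd + 1
    oh = (H + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    ow = (W + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    if min(od, oh, ow) <= 0:
        raise ValueError("kernel does not fit in the padded input")

    out = [[[[[0.0] * ow for _ in range(oh)] for _ in range(od)]
            for _ in range(c_out)] for _ in range(n_batch)]
    for n in range(n_batch):
        for co in range(c_out):
            g = co // c_out_g  # group this output channel belongs to
            for z in range(od):
                for y in range(oh):
                    for xo in range(ow):
                        acc = float(bias[co]) if bias is not None else 0.0
                        for ci in range(c_in_g):
                            src = x[n][g * c_in_g + ci]  # only channels of group g
                            for a in range(kd):
                                iz = z * sd - pd + a * dd
                                if not 0 <= iz < D:
                                    continue  # tap lands in zero padding
                                for b in range(kh):
                                    iy = y * sh - ph + b * dh
                                    if not 0 <= iy < H:
                                        continue
                                    for c in range(kw):
                                        ix = xo * sw - pw + c * dw
                                        if 0 <= ix < W:
                                            acc += src[iz][iy][ix] * w[co][ci][a][b][c]
                        out[n][co][z][y][xo] = acc
    return out
```

With torch installed, `t.cpu().tolist()` turns a small tensor into the nested lists this function expects, so a few-element case can be compared against the MPS result.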

patch_conv3d()

Monkey-patches torch.nn.functional.conv3d so that calls with MPS tensors use this package's Metal implementation.

unpatch_conv3d()

Restores the original torch.nn.functional.conv3d.
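A patch/unpatch pair like this is commonly built by swapping a module attribute and remembering the original. The sketch below shows that general pattern with stand-in objects instead of torch; it is not mps_conv3d's code, and the names patch, unpatch and PATCH_FLAG are hypothetical. It assumes routing is decided by the input's is_mps flag (torch.Tensor exposes is_mps), which may differ from how the package actually decides.

```python
from types import SimpleNamespace

PATCH_FLAG = "_routes_mps_to_metal"  # hypothetical marker attribute


def patch(namespace, metal_impl):
    """Replace namespace.conv3d with a wrapper that sends MPS inputs to metal_impl."""
    original = namespace.conv3d
    if getattr(original, PATCH_FLAG, False):
        return  # already patched; don't wrap the wrapper

    def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # Inputs without a truthy is_mps attribute take the original path
        impl = metal_impl if getattr(input, "is_mps", False) else original
        return impl(input, weight, bias, stride, padding, dilation, groups)

    setattr(conv3d, PATCH_FLAG, True)
    conv3d.__wrapped__ = original  # remembered so unpatch can restore it
    namespace.conv3d = conv3d


def unpatch(namespace):
    """Restore the function that patch() wrapped, if a patch is installed."""
    current = namespace.conv3d
    if getattr(current, PATCH_FLAG, False):
        namespace.conv3d = current.__wrapped__


# Demo with stand-ins instead of torch: a namespace and tensor-like objects
fake_F = SimpleNamespace(conv3d=lambda *args: "original")
mps_tensor = SimpleNamespace(is_mps=True)
cpu_tensor = SimpleNamespace(is_mps=False)

patch(fake_F, lambda *args: "metal")
print(fake_F.conv3d(mps_tensor, None), fake_F.conv3d(cpu_tensor, None))  # metal original
unpatch(fake_F)
print(fake_F.conv3d(mps_tensor, None))  # original
```

Because this kind of patch replaces an attribute, it only affects code that looks up F.conv3d at call time (torch.nn.Conv3d does). A reference captured earlier with `from torch.nn.functional import conv3d` keeps pointing at the unpatched function, which is one reason to patch at the start of the script.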

Compatibility

  • PyTorch: 2.0+
  • macOS: 12.0+ (Monterey)
  • Hardware: Apple Silicon (M1/M2/M3/M4)

Features

  • Full forward and backward pass (training supported)
  • fp32 and fp16 supported
  • Groups and dilation supported
  • Drop-in compatible with the PyTorch API

License

MIT



Download files

Download the file for your platform.

Source Distribution

mps_conv3d-0.1.1.tar.gz (9.8 kB)

Uploaded Source

Built Distribution


mps_conv3d-0.1.1-cp314-cp314-macosx_15_0_arm64.whl (87.7 kB)

Uploaded for CPython 3.14, macOS 15.0+ ARM64

File details

Details for the file mps_conv3d-0.1.1.tar.gz.

File metadata

  • Download URL: mps_conv3d-0.1.1.tar.gz
  • Upload date:
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for mps_conv3d-0.1.1.tar.gz
Algorithm Hash digest
SHA256 5b171ba3df3f5fb7ed56b4805191985930b2728d3c3a05679700f9dc2f0ac88d
MD5 d36ff5e87a3f3b2ff9bdf7eed268ab17
BLAKE2b-256 cca77ed4583e44ea6b55147c00be3873237bc9ad87e2f2c2e9d9af3434932efb


File details

Details for the file mps_conv3d-0.1.1-cp314-cp314-macosx_15_0_arm64.whl.

File metadata

File hashes

Hashes for mps_conv3d-0.1.1-cp314-cp314-macosx_15_0_arm64.whl
Algorithm Hash digest
SHA256 9b4a1edc6f445f4c75f11f104769a00ddf8c9e80105c3c67203bcbfd21e722f8
MD5 8facf577851e45333cc4b42f04fe352b
BLAKE2b-256 c6048c64d8105056604dee3efa9f59dff7747934b7f0a78d1637b2d62abcc0be

