3D Convolution for Apple Silicon (MPS)
MPS Conv3D
3D Convolution for Apple Silicon (M1/M2/M3/M4).
Drop-in replacement for torch.nn.functional.conv3d on MPS.
Why?
3D convolutions are essential for video models:
- Synchformer: Audio-visual synchronization
- I3D: Video classification
- SlowFast: Action recognition
- C3D: Video feature extraction
- MMAudio: Audio generation from video
But PyTorch's MPS backend doesn't support 3D convolutions:
```
NotImplementedError: aten::slow_conv3d_forward is not implemented for MPS
```
This package provides a native Metal implementation.
Installation
```bash
pip install mps-conv3d
```
Or from source:
```bash
git clone https://github.com/mpsops/mps-conv3d
cd mps-conv3d
pip install -e .
```
Quick Start
Patch All Conv3D Operations (Recommended)
```python
import torch
import torch.nn.functional as F

from mps_conv3d import patch_conv3d

# Patch once, at the start of your script
patch_conv3d()

# F.conv3d calls on MPS tensors now use the Metal implementation
x = torch.randn(1, 3, 16, 112, 112, device='mps')
w = torch.randn(64, 3, 3, 7, 7, device='mps')
out = F.conv3d(x, w, padding=(1, 3, 3))  # shape (1, 64, 16, 112, 112)
```
Direct Usage
```python
import torch
from mps_conv3d import conv3d

x = torch.randn(1, 3, 16, 112, 112, device='mps')
w = torch.randn(64, 3, 3, 7, 7, device='mps')
out = conv3d(x, w, stride=1, padding=(1, 3, 3))  # shape (1, 64, 16, 112, 112)
```
Conv3d Module
```python
import torch
from mps_conv3d import Conv3d

conv = Conv3d(
    in_channels=3,
    out_channels=64,
    kernel_size=(3, 7, 7),
    stride=(1, 2, 2),
    padding=(1, 3, 3),
).to('mps')

x = torch.randn(1, 3, 16, 112, 112, device='mps')
out = conv(x)  # shape (1, 64, 16, 56, 56): stride 2 halves H and W
```
API Reference
conv3d(input, weight, bias, stride, padding, dilation, groups)
Same signature as torch.nn.functional.conv3d.
| Parameter | Type | Description |
|---|---|---|
| input | Tensor | Input tensor (N, C_in, D, H, W) |
| weight | Tensor | Weight tensor (C_out, C_in/groups, kD, kH, kW) |
| bias | Tensor | Optional bias (C_out,) |
| stride | int/tuple | Stride of convolution |
| padding | int/tuple | Padding added to input |
| dilation | int/tuple | Dilation of kernel |
| groups | int | Number of groups |
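To make the parameters in the table concrete, here is a deliberately naive pure-Python reference on nested lists. It is a sketch under the assumption that this package matches PyTorch's conv3d semantics (cross-correlation with implicit zero padding, channels split into `groups` independent blocks), as the "same signature" note suggests; reference_conv3d is a hypothetical helper, far too slow for real use, but handy as a mental model or as an oracle on tiny inputs.

```python
import itertools


def _triple(v):
    """Expand an int into (v, v, v); pass tuples through."""
    return (v, v, v) if isinstance(v, int) else tuple(v)


def reference_conv3d(x, w, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Naive 3D cross-correlation following torch.nn.functional.conv3d conventions.

    x: [N][C_in][D][H][W], w: [C_out][C_in // groups][kD][kH][kW], bias: [C_out].
    """
    s, p, d = _triple(stride), _triple(padding), _triple(dilation)
    n_batch, c_in = len(x), len(x[0])
    in_dims = (len(x[0][0]), len(x[0][0][0]), len(x[0][0][0][0]))
    c_out, c_in_per_group = len(w), len(w[0])
    k = (len(w[0][0]), len(w[0][0][0]), len(w[0][0][0][0]))
    if c_in % groups or c_out % groups or c_in // groups != c_in_per_group:
        raise ValueError("channel counts are inconsistent with groups")
    c_out_per_group = c_out // groups
    out_dims = tuple(
        (in_dims[i] + 2 * p[i] - d[i] * (k[i] - 1) - 1) // s[i] + 1
        for i in range(3)
    )

    # Zero-initialised output laid out as [N][C_out][D_out][H_out][W_out]
    out = [[[[[0.0] * out_dims[2] for _ in range(out_dims[1])]
              for _ in range(out_dims[0])]
             for _ in range(c_out)]
            for _ in range(n_batch)]
    for b, oc in itertools.product(range(n_batch), range(c_out)):
        g = oc // c_out_per_group        # group this output channel belongs to
        ic_start = g * c_in_per_group    # first input channel of that group
        for od, oh, ow in itertools.product(*(range(m) for m in out_dims)):
            acc = bias[oc] if bias is not None else 0.0
            for ic, kd, kh, kw in itertools.product(
                    range(c_in_per_group), range(k[0]), range(k[1]), range(k[2])):
                # Input coordinate after applying stride, padding and dilation
                z = (od * s[0] - p[0] + kd * d[0],
                     oh * s[1] - p[1] + kh * d[1],
                     ow * s[2] - p[2] + kw * d[2])
                # Positions that fall in the padding contribute zero
                if all(0 <= z[i] < in_dims[i] for i in range(3)):
                    acc += x[b][ic_start + ic][z[0]][z[1]][z[2]] * w[oc][ic][kd][kh][kw]
            out[b][oc][od][oh][ow] = acc
    return out


# Tiny example: 5-wide input, 2-tap all-ones kernel, dilation 2 along W
x = [[[[[0.0, 1.0, 2.0, 3.0, 4.0]]]]]
w = [[[[[1.0, 1.0]]]]]
print(reference_conv3d(x, w, dilation=(1, 1, 2))[0][0][0][0])  # [2.0, 4.0, 6.0]
```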
patch_conv3d()
Monkey-patches torch.nn.functional.conv3d so that calls on MPS tensors use this package's Metal implementation.
unpatch_conv3d()
Restores the original torch.nn.functional.conv3d.
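One common way to build a patch like patch_conv3d() is to swap the module attribute for a wrapper that dispatches on the tensor's device. The sketch below assumes that design; it is not the package's source. It uses stand-ins so it runs without PyTorch: `F` plays the role of torch.nn.functional, metal_conv3d and fake_tensor are hypothetical, and tensors only need the `.device.type` attribute that real torch tensors expose (`"mps"` for MPS tensors).

```python
import functools
import types

# Hypothetical stand-in for torch.nn.functional with its original conv3d
F = types.SimpleNamespace(conv3d=lambda x, *args, **kwargs: "original")


def metal_conv3d(x, *args, **kwargs):
    """Placeholder for a Metal-backed kernel (hypothetical)."""
    return "metal"


_saved_conv3d = None  # the function that was in place before patching


def patch_conv3d_sketch():
    """Route conv3d calls on MPS tensors to metal_conv3d; leave others alone."""
    global _saved_conv3d
    if _saved_conv3d is not None:
        return  # already patched: don't wrap the wrapper
    original = F.conv3d

    @functools.wraps(original)
    def dispatch(x, *args, **kwargs):
        if x.device.type == "mps":
            return metal_conv3d(x, *args, **kwargs)
        return original(x, *args, **kwargs)

    _saved_conv3d = original
    F.conv3d = dispatch


def unpatch_conv3d_sketch():
    """Put the saved function back, if a patch is active."""
    global _saved_conv3d
    if _saved_conv3d is not None:
        F.conv3d = _saved_conv3d
        _saved_conv3d = None


def fake_tensor(device_type):
    """Object with just enough of a tensor's interface for the dispatcher."""
    return types.SimpleNamespace(device=types.SimpleNamespace(type=device_type))


patch_conv3d_sketch()
print(F.conv3d(fake_tensor("mps")))  # metal
print(F.conv3d(fake_tensor("cpu")))  # original
unpatch_conv3d_sketch()
print(F.conv3d(fake_tensor("mps")))  # original
```

Because the wrapper replaces the attribute on the module, code that looks up F.conv3d at call time picks it up, while a reference captured earlier (for example via `from torch.nn.functional import conv3d` before patching) keeps pointing at the original function.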
Compatibility
- PyTorch: 2.0+
- macOS: 12.0+ (Monterey)
- Hardware: Apple Silicon (M1/M2/M3/M4)
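A quick pre-flight check of the macOS and hardware requirements above can be written with the standard library alone. This is a sketch, not part of mps_conv3d: meets_listed_requirements is a hypothetical helper, and it checks only the macOS version and CPU architecture, not the PyTorch version.

```python
import platform


def _version_tuple(version):
    """Turn '14.5' or '12.0.1' into a tuple of ints; '' becomes ()."""
    return tuple(int(part) for part in version.split(".") if part)


def meets_listed_requirements(mac_version=None, machine=None, min_macos=(12, 0)):
    """Check macOS >= min_macos on an arm64 (Apple Silicon) Python.

    With no arguments, reads the running system; pass values explicitly
    to check another configuration.
    """
    if mac_version is None:
        mac_version = platform.mac_ver()[0]  # '' on non-macOS systems
    if machine is None:
        machine = platform.machine()
    version = _version_tuple(mac_version)
    if not version:
        return False
    # Pad so that '12' compares equal to (12, 0)
    padded = version + (0,) * max(0, len(min_macos) - len(version))
    return padded >= tuple(min_macos) and machine == "arm64"


print(meets_listed_requirements())
```

If PyTorch is installed, torch.backends.mps.is_available() additionally reports whether the MPS device can actually be used. An x86_64 Python running under Rosetta reports "x86_64" from platform.machine(), so the check fails there even on Apple Silicon hardware.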
Features
- Full forward and backward pass (training supported)
- fp32 and fp16 supported
- Groups and dilation supported
- Drop-in compatible with PyTorch API
License
MIT
Download files
File details
Details for the file mps_conv3d-0.2.1.tar.gz.
File metadata
- Download URL: mps_conv3d-0.2.1.tar.gz
- Upload date:
- Size: 13.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 46e384b5683742d1cac968e1ae515adb90f4205a61303af281763d730e24d6a5 |
| MD5 | 39df6cbd15574c37f0a76706160ae9ef |
| BLAKE2b-256 | ac8155edad21c1f90eaa69b4975505b271e0403b6e53e388b5098113228621c9 |
File details
Details for the file mps_conv3d-0.2.1-cp314-cp314-macosx_15_0_arm64.whl.
File metadata
- Download URL: mps_conv3d-0.2.1-cp314-cp314-macosx_15_0_arm64.whl
- Upload date:
- Size: 93.0 kB
- Tags: CPython 3.14, macOS 15.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cfbbade8e158bf006bc6020d37cb8b34980f58f83134350fd1f899594e2f4521 |
| MD5 | 5cf65c75cad1946a0bdb55947660ca5c |
| BLAKE2b-256 | 4186fcb930f89a914c0afe3f125e8d20ae1594d346edc50d5611d979b7b9f27a |
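To check a downloaded file against the published SHA256 digests, a short standard-library sketch is enough. sha256_of_file and verify_download are hypothetical helper names; the digests are the SHA256 values from the tables above.

```python
import hashlib
import os

# SHA256 digests as published in the file-hash tables
EXPECTED_SHA256 = {
    "mps_conv3d-0.2.1.tar.gz":
        "46e384b5683742d1cac968e1ae515adb90f4205a61303af281763d730e24d6a5",
    "mps_conv3d-0.2.1-cp314-cp314-macosx_15_0_arm64.whl":
        "cfbbade8e158bf006bc6020d37cb8b34980f58f83134350fd1f899594e2f4521",
}


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream the file through SHA-256 in 1 MiB chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256):
    """Return True only if the file name is known and its SHA-256 matches."""
    name = os.path.basename(path)
    if name not in expected:
        return False  # no published digest to compare against
    return sha256_of_file(path) == expected[name]


# verify_download("downloads/mps_conv3d-0.2.1.tar.gz") is True for an intact file
```

pip's hash-checking mode is an alternative: a requirements line such as `mps-conv3d==0.2.1 --hash=sha256:<digest>` makes pip reject a mismatching download, though that mode requires every installed package, including dependencies, to be pinned with a hash.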