Correlation layer for optical flow on Apple Silicon (MPS)
MPS Correlation
Correlation layer for optical flow on Apple Silicon (M1/M2/M3/M4).
Drop-in replacement for spatial-correlation-sampler and mmcv's correlation op.
Why?
Correlation layers are essential for optical flow estimation:
- RAFT: State-of-the-art optical flow
- PWC-Net: Efficient optical flow
- FlowNet/FlowNet2: Classic deep optical flow
But existing implementations are CUDA-only. On a Mac you get:
NotImplementedError: correlation not implemented for MPS
This package provides a native Metal implementation.
Installation
pip install mps-correlation
Or from source:
git clone https://github.com/mpsops/mps-correlation
cd mps-correlation
pip install -e .
Quick Start
Basic Usage
import torch
from mps_correlation import correlation
# Two feature maps from consecutive frames
fmap1 = torch.randn(1, 256, 64, 64, device='mps')
fmap2 = torch.randn(1, 256, 64, 64, device='mps')
# Compute correlation volume
corr = correlation(
fmap1, fmap2,
kernel_size=1,
max_displacement=4,
stride1=1,
stride2=1,
pad_size=4
)
# Output: (1, 81, 64, 64) - 81 = (2*4+1)^2 displacement channels
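The 81 channels and the unchanged 64×64 grid follow directly from the parameters. A minimal sketch of that arithmetic, assuming the op uses the same output-size convention as FlowNet2's correlation layer (the helper `correlation_output_shape` is hypothetical and not part of `mps_correlation`):

```python
import math


def correlation_output_shape(n, h, w, kernel_size=1, max_displacement=4,
                             stride1=1, stride2=1, pad_size=4):
    """Hypothetical helper: expected (N, C_out, H_out, W_out) of the op.

    Assumes FlowNet2's convention: a border of max_displacement plus the
    kernel radius must stay inside the padded input, output positions are
    taken every stride1 pixels, and displacements every stride2 pixels.
    """
    kernel_radius = (kernel_size - 1) // 2
    border = max_displacement + kernel_radius
    out_h = math.ceil((h + 2 * pad_size - 2 * border) / stride1)
    out_w = math.ceil((w + 2 * pad_size - 2 * border) / stride1)
    grid_radius = max_displacement // stride2   # sampled displacements per side
    out_c = (2 * grid_radius + 1) ** 2          # one channel per (dx, dy) pair
    return (n, out_c, out_h, out_w)


print(correlation_output_shape(1, 64, 64))  # (1, 81, 64, 64)
```

With pad_size equal to max_displacement and kernel_size=1 the spatial size is preserved. Under the same assumption, FlowNetC-like settings (max_displacement=20, stride2=2, pad_size=20) give 21² = 441 channels.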
Correlation Module
from mps_correlation import Correlation
corr_layer = Correlation(
kernel_size=1,
max_displacement=4,
stride1=1,
stride2=1,
pad_size=4
)
corr = corr_layer(fmap1, fmap2)
RAFT-style All-Pairs Correlation
from mps_correlation import CorrBlock
# Build correlation pyramid
corr_block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
# Lookup at specific coordinates
coords = torch.zeros(1, 2, 64, 64, device='mps')  # (N, 2, H, W): (x, y) lookup position for each pixel
corr_features = corr_block(coords)
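Assuming `CorrBlock` mirrors the original RAFT design, it computes an all-pairs volume (the dot product of every fmap1 pixel with every fmap2 pixel, scaled by 1/sqrt(C) in RAFT), average-pools the fmap2 axes by 2 for each extra level, and at level i samples a (2·radius+1)² window around coords / 2^i. The pure-Python sketch below only tracks shapes and sampling positions; both helpers are hypothetical.

```python
def raft_pyramid_layout(h, w, num_levels=4, radius=4):
    """Hypothetical helper: per-level (H2, W2) of the all-pairs volume and
    the channel count of the lookup output, assuming RAFT's design."""
    # Each level after the first halves the fmap2 axes (2x2 average pooling, floor).
    sizes = [(h // 2**i, w // 2**i) for i in range(num_levels)]
    # Every level contributes one (2r+1) x (2r+1) window; windows are concatenated.
    channels = num_levels * (2 * radius + 1) ** 2
    return sizes, channels


def lookup_window(x, y, level, radius=4):
    """Hypothetical helper: positions sampled at one pyramid level for a pixel
    whose current (x, y) estimate is given in full-resolution units."""
    cx, cy = x / 2**level, y / 2**level  # scale the centre to this level
    return [(cx + dx, cy + dy)
            for dy in range(-radius, radius + 1)
            for dx in range(-radius, radius + 1)]


sizes, channels = raft_pyramid_layout(64, 64)
print(sizes)     # [(64, 64), (32, 32), (16, 16), (8, 8)]
print(channels)  # 324
```

For the call above this implies `corr_features` has 4 × 81 = 324 channels. RAFT samples these positions bilinearly, so fractional centres are expected.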
API Reference
correlation(input1, input2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)
| Parameter | Type | Description |
|---|---|---|
| input1 | Tensor | First feature map (N, C, H, W) |
| input2 | Tensor | Second feature map (N, C, H, W) |
| kernel_size | int | Size of correlation kernel (default: 1) |
| max_displacement | int | Maximum displacement to search (default: 4) |
| stride1 | int | Stride for input1 (default: 1) |
| stride2 | int | Stride for displacement (default: 1) |
| pad_size | int | Padding size (default: 4) |
| is_multiply | bool | Use multiplication (True) or subtraction (False) |
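`max_displacement` and `stride2` together decide which offsets are searched, and therefore how many output channels there are. A sketch assuming FlowNet-style semantics; the helper name and the channel ordering are assumptions, not documented behaviour of `mps_correlation`:

```python
def displacement_channels(max_displacement=4, stride2=1):
    """Hypothetical helper: the (dx, dy) offset behind each output channel.

    Assumes FlowNet-style semantics: only multiples of stride2 inside
    [-max_displacement, max_displacement] are evaluated, and channels are
    ordered with dy varying slowest and dx fastest.
    """
    grid_radius = max_displacement // stride2
    offsets = [k * stride2 for k in range(-grid_radius, grid_radius + 1)]
    return [(dx, dy) for dy in offsets for dx in offsets]


chans = displacement_channels(max_displacement=4, stride2=2)
print(len(chans))  # 25
print(chans[12])   # (0, 0): the zero-displacement channel is the centre one
```

Raising stride2 keeps the same search range with fewer channels, at the cost of skipping intermediate offsets.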
CorrBlock
RAFT-style correlation block with pyramid and lookup.
How It Works
Correlation computes similarity between patches at different displacements:
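For each output position (x, y) (shown for kernel_size = 1, stride1 = stride2 = 1):
  For each displacement (dx, dy) with -max_disp <= dx, dy <= max_disp:
    corr[dx, dy, y, x] = sum over channels c of fmap1[c, y, x] * fmap2[c, y + dy, x + dx]
This creates a 4D cost volume over (dx, dy, y, x) for each batch element. The two displacement axes are flattened into (2*max_disp + 1)^2 output channels (81 for max_disp = 4), and optical flow networks use this volume to estimate motion.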
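The loop above can be written out directly as a slow reference for tiny inputs. This is a naive pure-Python sketch for one batch element, assuming kernel_size = 1, unit strides, pad_size = max_disp with zero padding, no normalization, and FlowNet-style channel ordering (dy slowest). The function `correlation_reference` is hypothetical and says nothing about how the Metal kernel is organized.

```python
def correlation_reference(fmap1, fmap2, max_disp):
    """Naive correlation; fmap1/fmap2 are nested lists indexed [c][y][x].

    Returns out[k][y][x] with k = (dy + max_disp) * (2*max_disp + 1) + (dx + max_disp).
    Positions of fmap2 outside the image contribute zero (zero padding).
    """
    C, H, W = len(fmap1), len(fmap1[0]), len(fmap1[0][0])
    D = 2 * max_disp + 1
    out = [[[0.0] * W for _ in range(H)] for _ in range(D * D)]
    for dy in range(-max_disp, max_disp + 1):
        for dx in range(-max_disp, max_disp + 1):
            k = (dy + max_disp) * D + (dx + max_disp)
            for y in range(H):
                for x in range(W):
                    y2, x2 = y + dy, x + dx
                    if 0 <= y2 < H and 0 <= x2 < W:
                        # Dot product over channels between the two positions.
                        out[k][y][x] = sum(fmap1[c][y][x] * fmap2[c][y2][x2]
                                           for c in range(C))
    return out


# fmap2 is fmap1 shifted one pixel to the right, so the match is at dx=+1, dy=0.
f1 = [[[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
f2 = [[[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
out = correlation_reference(f1, f2, max_disp=1)
scores = [out[k][0][0] for k in range(9)]
print(scores.index(max(scores)))  # 5 == (0 + 1) * 3 + (1 + 1)
```

If the MPS op normalizes its output (FlowNet2's CUDA layer, for example, divides by kernel_size²·C), scale the reference accordingly before comparing.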
Compatibility
- PyTorch: 2.0+
- macOS: 12.0+ (Monterey)
- Hardware: Apple Silicon (M1/M2/M3/M4)
Features
- Full forward and backward pass (training supported)
- fp32 and fp16 supported
- Compatible with RAFT, PWC-Net, FlowNet architectures
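Because correlation is bilinear in its two inputs, the backward pass is the same loop with the roles swapped: the gradient for fmap1 at a position is the upstream gradient times fmap2 at the displaced position, and vice versa. Below is a pure-Python sketch under the same assumptions as a naive reference (kernel_size = 1, unit strides, zero padding, no normalization), with a finite-difference check. All function names are hypothetical and nothing here calls `mps_correlation`. A CPU check like this is handy because `torch.autograd.gradcheck` is meant for float64 inputs, which the MPS backend does not support.

```python
import copy
import random


def corr_forward(f1, f2, d):
    """out[k][y][x] = sum_c f1[c][y][x] * f2[c][y+dy][x+dx], zero outside the image."""
    C, H, W = len(f1), len(f1[0]), len(f1[0][0])
    D = 2 * d + 1
    out = [[[0.0] * W for _ in range(H)] for _ in range(D * D)]
    for dy in range(-d, d + 1):
        for dx in range(-d, d + 1):
            k = (dy + d) * D + (dx + d)
            for y in range(H):
                for x in range(W):
                    if 0 <= y + dy < H and 0 <= x + dx < W:
                        out[k][y][x] = sum(f1[c][y][x] * f2[c][y + dy][x + dx]
                                           for c in range(C))
    return out


def corr_backward(f1, f2, d, grad_out):
    """Gradients of sum(grad_out * corr_forward(f1, f2, d)) w.r.t. f1 and f2."""
    C, H, W = len(f1), len(f1[0]), len(f1[0][0])
    D = 2 * d + 1
    g1 = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    g2 = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    for dy in range(-d, d + 1):
        for dx in range(-d, d + 1):
            k = (dy + d) * D + (dx + d)
            for y in range(H):
                for x in range(W):
                    y2, x2 = y + dy, x + dx
                    if 0 <= y2 < H and 0 <= x2 < W:
                        g = grad_out[k][y][x]
                        for c in range(C):
                            # Bilinear op: each input's gradient is the upstream
                            # gradient times the *other* input's value.
                            g1[c][y][x] += g * f2[c][y2][x2]
                            g2[c][y2][x2] += g * f1[c][y][x]
    return g1, g2


def max_grad_error(C=2, H=3, W=4, d=1, eps=1e-6, seed=0):
    """Largest gap between corr_backward and central finite differences."""
    rng = random.Random(seed)

    def rand(n):
        return [[[rng.uniform(-1, 1) for _ in range(W)] for _ in range(H)]
                for _ in range(n)]

    f1, f2, gout = rand(C), rand(C), rand((2 * d + 1) ** 2)

    def loss(a, b):
        out = corr_forward(a, b, d)
        return sum(o * g for ok, gk in zip(out, gout)
                   for orow, grow in zip(ok, gk) for o, g in zip(orow, grow))

    g1, g2 = corr_backward(f1, f2, d, gout)
    worst = 0.0
    for which, analytic in ((0, g1), (1, g2)):
        for c in range(C):
            for y in range(H):
                for x in range(W):
                    plus, minus = copy.deepcopy([f1, f2]), copy.deepcopy([f1, f2])
                    plus[which][c][y][x] += eps
                    minus[which][c][y][x] -= eps
                    numeric = (loss(*plus) - loss(*minus)) / (2 * eps)
                    worst = max(worst, abs(numeric - analytic[c][y][x]))
    return worst


print(max_grad_error() < 1e-6)  # True
```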
Credits
- spatial-correlation-sampler - Reference implementation
- RAFT - State-of-the-art optical flow
- PWC-Net - Efficient optical flow
License
MIT
Download files
File details
Details for the file mps_correlation-0.2.1.tar.gz.
File metadata
- Download URL: mps_correlation-0.2.1.tar.gz
- Upload date:
- Size: 14.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2474be51b294c7810b4196eb1d774d0fde4cda9d8d3116a8de293a03275ab761 |
| MD5 | 6cba6e098b34473dc70a82b8e1f35b47 |
| BLAKE2b-256 | c0d07457a3af16cb5b0fbbe6ff5c42b2b1c7bee644d9f4310d88adfa794b7c19 |
File details
Details for the file mps_correlation-0.2.1-cp314-cp314-macosx_15_0_arm64.whl.
File metadata
- Download URL: mps_correlation-0.2.1-cp314-cp314-macosx_15_0_arm64.whl
- Upload date:
- Size: 88.8 kB
- Tags: CPython 3.14, macOS 15.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9f715884e7ecdec190bc5378ba6d23d99bfaba795516691d88e3480a79257292 |
| MD5 | eb561e518d295e2701b7e88690bc2d55 |
| BLAKE2b-256 | 72e04da6fb9b8b6901f4a19bc819054345c625ca01947cafc5be8075c5718dbb |