
Correlation layer for optical flow on Apple Silicon (MPS)


MPS Correlation

Correlation layer for optical flow on Apple Silicon (M1/M2/M3/M4).

Drop-in replacement for spatial-correlation-sampler and mmcv's correlation op.

Why?

Correlation layers are essential for optical flow estimation:

  • RAFT: State-of-the-art optical flow
  • PWC-Net: Efficient optical flow
  • FlowNet/FlowNet2: Classic deep optical flow

However, existing implementations target CUDA and have no MPS kernel. On a Mac you get:

NotImplementedError: correlation not implemented for MPS

This package provides a native Metal implementation.

Installation

pip install mps-correlation

Or from source:

git clone https://github.com/mpsops/mps-correlation
cd mps-correlation
pip install -e .

Quick Start

Basic Usage

import torch
from mps_correlation import correlation

# Two feature maps from consecutive frames
fmap1 = torch.randn(1, 256, 64, 64, device='mps')
fmap2 = torch.randn(1, 256, 64, 64, device='mps')

# Compute correlation volume
corr = correlation(
    fmap1, fmap2,
    kernel_size=1,
    max_displacement=4,
    stride1=1,
    stride2=1,
    pad_size=4
)
# Output: (1, 81, 64, 64) - 81 = (2*4+1)^2 displacement channels
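Each of the 81 output channels corresponds to one (dx, dy) displacement. The channel order is not documented here, so the sketch below assumes row-major order with dy as the outer index (the convention of FlowNet2-style CUDA kernels); verify it against real outputs before relying on it. The helper names are hypothetical and not part of mps_correlation. With either loop order, channel 40 is the zero displacement and channel 0 is (-4, -4).

```python
def channel_to_displacement(channel, max_displacement=4, stride2=1):
    """Map an output channel index to its (dx, dy) displacement in pixels.

    Assumption: row-major ordering with dy as the outer index.
    """
    radius = max_displacement // stride2          # displacement steps per side
    width = 2 * radius + 1                        # grid side length (9 for max_displacement=4)
    if not 0 <= channel < width * width:
        raise ValueError(f"channel must be in [0, {width * width})")
    row, col = divmod(channel, width)             # row -> dy, col -> dx
    return ((col - radius) * stride2, (row - radius) * stride2)


def displacement_to_channel(dx, dy, max_displacement=4, stride2=1):
    """Inverse of channel_to_displacement under the same ordering assumption."""
    radius = max_displacement // stride2
    width = 2 * radius + 1
    if dx % stride2 or dy % stride2:
        raise ValueError("displacements must be multiples of stride2")
    col, row = dx // stride2 + radius, dy // stride2 + radius
    if not (0 <= col < width and 0 <= row < width):
        raise ValueError("displacement exceeds max_displacement")
    return row * width + col


print(channel_to_displacement(40))  # centre channel: (0, 0)
print(channel_to_displacement(0))   # first channel: (-4, -4)
```

For example, `corr[:, displacement_to_channel(1, 0)]` would select the score map for a one-pixel shift to the right, provided the ordering assumption holds.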

Correlation Module

from mps_correlation import Correlation

corr_layer = Correlation(
    kernel_size=1,
    max_displacement=4,
    stride1=1,
    stride2=1,
    pad_size=4
)

corr = corr_layer(fmap1, fmap2)  # fmap1, fmap2 as defined in Basic Usage

RAFT-style All-Pairs Correlation

from mps_correlation import CorrBlock

# Build correlation pyramid
corr_block = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

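# Look up correlation features at per-pixel (x, y) coordinates, shape (N, 2, H, W).
# Zeros are only a placeholder; in RAFT these are the pixel grid plus the current flow estimate.
coords = torch.zeros(1, 2, 64, 64, device='mps')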
corr_features = corr_block(coords)
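The RAFT paper builds its correlation pyramid by taking the dot product between every pixel of fmap1 and every pixel of fmap2, scaling by 1/sqrt(C), and then repeatedly average-pooling only the fmap2 dimensions. The pure-Python sketch below mirrors that construction on tiny nested lists so the shapes are easy to follow. It is not this package's Metal kernel; the function names are hypothetical, and whether CorrBlock applies the same 1/sqrt(C) scaling is an assumption.

```python
import math


def all_pairs_correlation(fmap1, fmap2):
    """Return corr[y1][x1][y2][x2] = <fmap1[:, y1, x1], fmap2[:, y2, x2]> / sqrt(C).

    fmap1, fmap2: nested lists shaped [C][H][W] for a single batch element.
    """
    c = len(fmap1)
    h1, w1 = len(fmap1[0]), len(fmap1[0][0])
    h2, w2 = len(fmap2[0]), len(fmap2[0][0])
    scale = 1.0 / math.sqrt(c)  # RAFT-style normalisation (assumed for CorrBlock)
    return [[[[scale * sum(fmap1[k][y1][x1] * fmap2[k][y2][x2] for k in range(c))
               for x2 in range(w2)]
              for y2 in range(h2)]
             for x1 in range(w1)]
            for y1 in range(h1)]


def avg_pool_2x2(grid):
    """2x2 average pooling with stride 2; an odd trailing row/column is dropped."""
    h, w = len(grid) // 2, len(grid[0]) // 2
    return [[(grid[2 * y][2 * x] + grid[2 * y][2 * x + 1]
              + grid[2 * y + 1][2 * x] + grid[2 * y + 1][2 * x + 1]) / 4.0
             for x in range(w)]
            for y in range(h)]


def build_pyramid(fmap1, fmap2, num_levels=4):
    """Level 0 is the full volume; each later level halves only the fmap2 dims."""
    corr = all_pairs_correlation(fmap1, fmap2)
    pyramid = [corr]
    for _ in range(num_levels - 1):
        corr = [[avg_pool_2x2(window) for window in row] for row in corr]
        pyramid.append(corr)
    return pyramid


# Tiny example: one channel, 8x8 maps, fmap1 all ones, fmap2[y][x] = x + y.
fmap1 = [[[1.0] * 8 for _ in range(8)]]
fmap2 = [[[float(x + y) for x in range(8)] for y in range(8)]]
pyramid = build_pyramid(fmap1, fmap2, num_levels=4)
print([len(level[0][0]) for level in pyramid])  # fmap2 side length per level: [8, 4, 2, 1]
```

Pooling only the fmap2 side keeps one correlation map per fmap1 pixel at every level, which is what lets a lookup cover large displacements cheaply at the coarse levels.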

API Reference

correlation(input1, input2, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)

Parameter         Type    Description
input1            Tensor  First feature map, shape (N, C, H, W)
input2            Tensor  Second feature map, shape (N, C, H, W)
kernel_size       int     Side length of the patch compared at each position (default: 1)
max_displacement  int     Maximum displacement to search, in pixels (default: 4)
stride1           int     Stride over positions in input1 (default: 1)
stride2           int     Step between sampled displacements (default: 1)
pad_size          int     Padding added to each side of the inputs (default: 4)
is_multiply       bool    Compare features by multiplication (True) or subtraction (False)
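The table does not state how these parameters determine the output shape. The parameter names mostly match the Correlation layer in NVIDIA's flownet2-pytorch, so the sketch below assumes this package uses the same size formulas; it reproduces the (1, 81, 64, 64) shape from Basic Usage, but treat it as an assumption to check against real outputs. The function name is hypothetical.

```python
def correlation_output_shape(input_shape, kernel_size=1, max_displacement=4,
                             stride1=1, stride2=1, pad_size=4):
    """Predict (N, channels, H_out, W_out) for a FlowNet2-style correlation.

    Assumption: mps_correlation uses the same size formulas as flownet2-pytorch.
    """
    n, _channels, h, w = input_shape  # C is summed over, so it does not appear in the output
    kernel_radius = (kernel_size - 1) // 2
    border = max_displacement + kernel_radius  # margin needed so every window fits

    def out_size(size):
        span = size + 2 * pad_size - 2 * border
        if span <= 0:
            raise ValueError("pad_size is too small for max_displacement and kernel_size")
        return -(-span // stride1)  # integer ceil(span / stride1)

    grid_width = 2 * (max_displacement // stride2) + 1  # sampled displacements per axis
    return (n, grid_width * grid_width, out_size(h), out_size(w))


print(correlation_output_shape((1, 256, 64, 64)))  # (1, 81, 64, 64), as in Basic Usage
```

With pad_size equal to max_displacement (and kernel_size=1, stride1=1), the spatial size is preserved; smaller padding shrinks the output by 2 * (max_displacement - pad_size).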

CorrBlock

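RAFT-style all-pairs correlation block: it builds a correlation pyramid with num_levels levels from fmap1 and fmap2, and calling it with coordinates looks up correlation values within the given radius at each level.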
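In RAFT, the lookup step samples each pyramid level in a (2r+1) x (2r+1) window around the query point, dividing the coordinates by 2**level and using bilinear interpolation with zeros outside the map. The pure-Python sketch below shows that step for a single pixel of image 1; the function names are hypothetical, and the order of values within each window is an assumption rather than CorrBlock's documented layout. With 4 levels and radius 4, this sketch yields 4 * 81 = 324 values per pixel.

```python
import math


def bilinear(grid, x, y):
    """Sample a 2D list at real-valued pixel coords (x, y); outside cells count as 0."""
    h, w = len(grid), len(grid[0])
    x0, y0 = math.floor(x), math.floor(y)
    fx, fy = x - x0, y - y0

    def at(xi, yi):
        return grid[yi][xi] if 0 <= xi < w and 0 <= yi < h else 0.0

    return ((1 - fx) * (1 - fy) * at(x0, y0) + fx * (1 - fy) * at(x0 + 1, y0)
            + (1 - fx) * fy * at(x0, y0 + 1) + fx * fy * at(x0 + 1, y0 + 1))


def lookup(level_maps, x, y, radius=4):
    """Concatenate a (2*radius+1)^2 window from every pyramid level.

    level_maps: one 2D correlation map per level for a single pixel of image 1,
    level 0 at full resolution. (x, y) is the query point in level-0 pixels.
    """
    features = []
    for level, grid in enumerate(level_maps):
        cx, cy = x / 2 ** level, y / 2 ** level  # coarser levels use scaled coordinates
        for dy in range(-radius, radius + 1):
            for dx in range(-radius, radius + 1):
                features.append(bilinear(grid, cx + dx, cy + dy))
    return features


level0 = [[float(x + 10 * y) for x in range(4)] for y in range(4)]
level1 = [[0.0, 1.0], [10.0, 11.0]]
window = lookup([level0, level1], x=1.0, y=2.0, radius=1)
print(len(window))  # 2 levels * 3 * 3 = 18
```

Because the window has a fixed size in pixels at every level, the coarse levels cover much larger displacements in level-0 pixels at no extra cost.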

How It Works

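Correlation measures the similarity between each feature vector in fmap1 and displaced feature vectors in fmap2 (single pixels when kernel_size=1, patches for larger kernels):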

For each position (x, y) in output:
    For each displacement (dx, dy) in [-max_disp, max_disp]:
        corr[x, y, dx, dy] = sum(fmap1[x, y, :] * fmap2[x+dx, y+dy, :])

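The result is a 4D cost volume (two spatial and two displacement dimensions), which optical flow networks use to estimate motion. The correlation op flattens the two displacement dimensions into channels, which is why the Basic Usage example returns (2*4+1)^2 = 81 channels.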
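The pseudocode above can be made runnable as a naive reference for the kernel_size=1, stride1=stride2=1, is_multiply=True case. This pure-Python sketch (the function name is hypothetical) treats displacements that leave fmap2 as zero, like zero padding with pad_size=max_displacement, and assumes the dy-outer channel order. It does not normalise the sum; some implementations, such as FlowNet2's, divide by the number of summed elements, so compare it against mps_correlation only up to such a scale factor. It is meant for checking small inputs, not for speed.

```python
def correlation_reference(fmap1, fmap2, max_displacement=4):
    """Naive cost volume for kernel_size=1, stride1=stride2=1, is_multiply=True.

    fmap1, fmap2: nested lists shaped [C][H][W] for a single batch element.
    Returns a list of (2*d+1)**2 planes, each H x W, one per (dx, dy).
    Displacements that leave fmap2 contribute 0 (zero padding).
    """
    c, h, w = len(fmap1), len(fmap1[0]), len(fmap1[0][0])
    d = max_displacement
    volume = []
    for dy in range(-d, d + 1):          # outer index: vertical shift (assumed order)
        for dx in range(-d, d + 1):      # inner index: horizontal shift
            plane = [[0.0] * w for _ in range(h)]
            for y in range(h):
                for x in range(w):
                    y2, x2 = y + dy, x + dx
                    if 0 <= y2 < h and 0 <= x2 < w:
                        # dot product over channels, as in the pseudocode's sum(...)
                        plane[y][x] = sum(fmap1[k][y][x] * fmap2[k][y2][x2] for k in range(c))
            volume.append(plane)
    return volume


# 1 channel, 3x3 maps: fmap1 is all ones, fmap2[y][x] = 3*y + x.
fmap1 = [[[1.0] * 3 for _ in range(3)]]
fmap2 = [[[float(3 * y + x) for x in range(3)] for y in range(3)]]
volume = correlation_reference(fmap1, fmap2, max_displacement=1)
print(len(volume))  # (2*1+1)**2 = 9 displacement planes
```

If a feature moves one pixel to the right between frames, the plane for (dx, dy) = (1, 0) holds the largest score at that feature's position, which is the signal a flow network learns to read.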

Compatibility

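  • PyTorch: 2.0+
  • macOS: 12.3+ (Monterey; PyTorch's MPS backend requires macOS 12.3 or later)
  • Hardware: Apple Silicon (M1/M2/M3/M4)
  • Prebuilt wheel (0.1.1): CPython 3.14, macOS 15.0+, ARM64 only; other Python or macOS versions install from the source distribution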

Features

  • Full forward and backward pass (training supported)
  • fp32 and fp16 supported
  • Compatible with RAFT, PWC-Net, FlowNet architectures

Credits

License

MIT



Download files


Source Distribution

mps_correlation-0.1.1.tar.gz (10.5 kB)

Built Distribution


mps_correlation-0.1.1-cp314-cp314-macosx_15_0_arm64.whl (85.4 kB): CPython 3.14, macOS 15.0+, ARM64

File details

Details for the file mps_correlation-0.1.1.tar.gz.

File metadata

  • Download URL: mps_correlation-0.1.1.tar.gz
  • Upload date:
  • Size: 10.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for mps_correlation-0.1.1.tar.gz
Algorithm Hash digest
SHA256 fe3722e82c2eb2ed5aea21ba6aec5c3b07e73b052adf3eb8475b7b58c59d8a95
MD5 cd7b3792e78b7bebffdbd98371c730c4
BLAKE2b-256 5eb76bee9812b14b741f555f3dac0b77632175880eefe2071e70f8c1d054e5f9


File details

Details for the file mps_correlation-0.1.1-cp314-cp314-macosx_15_0_arm64.whl.

File metadata

File hashes

Hashes for mps_correlation-0.1.1-cp314-cp314-macosx_15_0_arm64.whl
Algorithm Hash digest
SHA256 0e55791d724a03490072811b19c7115563605e8639f40efc023145d607a93341
MD5 c2a1e3e8e7c4ee2fd06ef0125c205881
BLAKE2b-256 d1f10fc8646088133ddb2ee4c9263f733fd71c2c153978c19351d46f8625edd6

