
Fast torch-compatible STFT and ISTFT on Apple MPS via custom Metal kernels


mps-spectro

Drop-in torch.stft / torch.istft replacements for Apple Silicon, plus mel frontends built on top of them — 1.4–3x faster on MPS via custom Metal kernels.

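import torch
from mps_spectro import stft, istft

x = torch.randn(2, 44100, device="mps")  # [batch, samples]
T = x.shape[-1]
window = torch.hann_window(2048, device="mps")

# before
spec = torch.stft(x, n_fft=2048, hop_length=512, window=window, center=True, return_complex=True)
y = torch.istft(spec, n_fft=2048, hop_length=512, window=window, center=True, length=T)

# after
spec = stft(x, n_fft=2048, hop_length=512, window=window, center=True)
y = istft(spec, n_fft=2048, hop_length=512, window=window, center=True, length=T)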
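Because both APIs are meant to follow torch.stft semantics, the output shape can be predicted with the usual torch.stft frame arithmetic. The helper below is a hypothetical sketch (stft_output_shape is not part of mps-spectro), assuming a one-sided complex output of shape [batch, n_fft // 2 + 1, n_frames] and center padding of n_fft // 2 samples on each side. It also shows why istft needs length=T: several signal lengths map to the same frame count.

```python
def stft_output_shape(batch: int, length: int, n_fft: int, hop_length: int,
                      center: bool = True) -> tuple[int, int, int]:
    """Predict the [batch, freq_bins, frames] shape of a one-sided complex STFT.

    Hypothetical helper: assumes torch.stft-style framing, which mps-spectro
    states it mirrors.
    """
    n_freqs = n_fft // 2 + 1  # one-sided spectrum: DC up to Nyquist
    # center=True pads n_fft // 2 samples on both ends before framing.
    padded = length + 2 * (n_fft // 2) if center else length
    if padded < n_fft:
        raise ValueError("signal is shorter than one frame")
    n_frames = 1 + (padded - n_fft) // hop_length
    return (batch, n_freqs, n_frames)


# 160,000 and 160,255 samples produce the same number of frames,
# so the original length cannot be recovered without length=T.
print(stft_output_shape(4, 160_000, 2048, 512))  # (4, 1025, 313)
print(stft_output_shape(4, 160_255, 2048, 512))  # (4, 1025, 313)
```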

Log-mel frontend for AMT / ASR-style models:

from mps_spectro import LogMelSpectrogramTransform

frontend = LogMelSpectrogramTransform(
    sample_rate=16000,
    n_fft=2048,
    hop_length=512,
    n_mels=256,
    f_min=30.0,
    f_max=8000.0,
    pad_mode="constant",
    power=1.0,
    norm="slaney",
    mel_scale="htk",
    log_amin=1e-5,
    log_mode="clamp",
)

mel = frontend(x)

Dynamic frontends for pitch models, with per-call keyshift / speed and optional externally supplied mel filterbanks:

from mps_spectro import DynamicMelSpectrogramTransform

frontend = DynamicMelSpectrogramTransform(
    sample_rate=16000,
    n_fft=1024,
    hop_length=160,
    win_length=1024,
    output_scale="log",
    log_amin=1e-5,
    mel_basis=external_mel_basis,  # optional [n_mels, n_freqs]
)

mel = frontend(x, keyshift=3, speed=1.2)
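This page does not spell out how keyshift and speed are applied. A common convention in RMVPE/RVC-style mel wrappers, assumed here to be what the dynamic frontends reproduce for parity, is to scale n_fft and win_length by 2 ** (keyshift / 12), scale hop_length by speed, then crop or zero-pad the magnitude back to the original n_fft // 2 + 1 bins while rescaling by win_length / new_win_length. The sketch below computes those quantities in plain Python; dynamic_stft_params and fit_bins are hypothetical names, and mps-spectro's exact rounding rules may differ.

```python
def dynamic_stft_params(n_fft: int, win_length: int, hop_length: int,
                        keyshift: float = 0.0, speed: float = 1.0):
    """Per-call STFT sizes for a shift of `keyshift` semitones (assumed convention)."""
    factor = 2.0 ** (keyshift / 12.0)        # frequency ratio for the shift
    n_fft_new = int(round(n_fft * factor))   # longer FFT: bin k maps to a lower Hz
    win_length_new = int(round(win_length * factor))
    hop_length_new = int(round(hop_length * speed))
    mag_scale = win_length / win_length_new  # compensates for the window-length gain
    return n_fft_new, win_length_new, hop_length_new, mag_scale


def fit_bins(magnitude_frame: list[float], n_fft: int, mag_scale: float) -> list[float]:
    """Crop or zero-pad one magnitude frame to the original n_fft // 2 + 1 bins."""
    size = n_fft // 2 + 1
    frame = magnitude_frame[:size] + [0.0] * max(0, size - len(magnitude_frame))
    return [m * mag_scale for m in frame]


params = dynamic_stft_params(1024, 1024, 160, keyshift=3, speed=1.2)
print(params[:3])  # (1218, 1218, 192)
```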

The 0.3.0 release expands mps-spectro from fast STFT/iSTFT plus fixed mel frontends into a broader, shared spectral-frontend package:

  • standard mel frontends for log, linear, dB, and compat-style outputs
  • dynamic frontends for pitch models with per-call keyshift / speed
  • optional external mel filterbank injection for exact project parity
  • a parity-oriented dynamic STFT mode for cases where exact legacy-wrapper behavior matters more than the fastest low-level path

Drop-in compatible with python-audio-separator (MDX, Roformer, Demucs): 1.3–1.4x faster STFT and 1.9–2.0x faster iSTFT on stereo 44.1 kHz audio, with no model changes. See the benchmarks below.

Install

pip install mps-spectro

Features

  • PyTorch-compatible STFT/ISTFT semantics (same parameters as torch.stft / torch.istft)
  • PyTorch-native mel frontends built on top of the same spectral core
  • Dynamic mel/spectrogram frontends for pitch models with keyshift, speed, and optional external mel bases
  • Fused overlap-add with optimized Metal compute shaders
  • Autograd support with custom Metal backward kernels
  • torch.compile compatible (aot_eager backend) via torch.library custom ops
  • Pure Python — no C++ build step, no Xcode CLI tools

Validated downstream use cases

The current package surface has been benchmarked and parity-checked in several real consumer projects:

  • mamba_amt: log-mel frontend replacement on MPS
  • python-audio-separator: shared STFT/iSTFT compatibility layer
  • LinkSeg: compat mel frontend replacing project-local frontend code
  • SongFormer-mps: shared dB mel frontends for MusicFM and MuQ
  • RVMPE: dynamic mel frontend with per-call keyshift / speed
  • torchfcpe: dynamic spectrogram path for the MPS mel frontend patch

The most important takeaway is that mps-spectro now covers both:

  • fixed frontend replacements for torchaudio-style mel paths
  • dynamic frontend building blocks for pitch models that previously needed project-local MPS STFT patches

Autograd

Both stft and istft support PyTorch autograd when inputs have requires_grad=True:

import torch
from mps_spectro import stft, istft

x = torch.randn(4, 16000, device="mps", requires_grad=True)

spec = stft(x, n_fft=1024, hop_length=256)
y = istft(spec, n_fft=1024, hop_length=256, center=True, length=16000)

loss = y.pow(2).sum()
loss.backward()
print(x.grad.shape)  # torch.Size([4, 16000])

When requires_grad=False (the default), there is no autograd overhead: the original Metal kernel path is used directly. Backward passes use custom Metal kernels, so gradient computation also runs on the GPU. Window gradients are not computed (the window's gradient is None), since windows are almost always frozen in practice.
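One reason dedicated backward kernels are a natural fit is that the core linear steps have simple adjoints: the gradient of overlap-add is frame extraction (a strided gather), and the gradient of framing is overlap-add. The pure-Python sketch below is illustrative only, not mps-spectro's kernel code; it checks that identity with the standard dot-product test, <OLA(F), y> = <F, frame(y)>.

```python
import random


def overlap_add(frames: list[list[float]], hop: int) -> list[float]:
    """Sum frames into one signal, each shifted by `hop` samples (ISTFT-style OLA)."""
    n_fft = len(frames[0])
    out = [0.0] * (n_fft + hop * (len(frames) - 1))
    for t, frame in enumerate(frames):
        for k, v in enumerate(frame):
            out[t * hop + k] += v
    return out


def extract_frames(y: list[float], n_fft: int, hop: int, n_frames: int) -> list[list[float]]:
    """Gather overlapping frames from y (STFT-style framing), the adjoint of overlap_add."""
    return [[y[t * hop + k] for k in range(n_fft)] for t in range(n_frames)]


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


rng = random.Random(0)
n_fft, hop, n_frames = 16, 4, 6
F = [[rng.uniform(-1, 1) for _ in range(n_fft)] for _ in range(n_frames)]
y = [rng.uniform(-1, 1) for _ in range(n_fft + hop * (n_frames - 1))]

lhs = dot(overlap_add(F, hop), y)
rhs = sum(dot(f, g) for f, g in zip(F, extract_frames(y, n_fft, hop, n_frames)))
print(abs(lhs - rhs) < 1e-9)  # True
```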

torch.compile

Custom ops are registered via torch.library with Meta (FakeTensor) kernels, so torch.compile can trace through both forward and backward:

import torch
from mps_spectro import stft

@torch.compile(backend="aot_eager")
def f(x):
    return stft(x, n_fft=2048, hop_length=512)

f(torch.randn(4, 160000, device="mps"))  # works

ISTFT extras

istft also supports:

  • torch_like=True -- raise on NOLA (nonzero overlap-add) violations, as torch.istft does
  • safety="auto"|"always"|"off" -- NOLA envelope safety checking
  • kernel_dtype="float32"|"float16"|"mixed" -- Metal kernel precision
  • kernel_layout="auto"|"native"|"transposed" -- memory layout selection
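NOLA is the condition that the sum of squared, hop-shifted synthesis windows never vanishes over the output region, because ISTFT divides by that envelope. The sketch below approximates the check torch.istft performs (to our understanding, it raises when the envelope minimum over the output region is below roughly 1e-11). The helper names are hypothetical, and this is not how mps-spectro's safety modes are implemented.

```python
import math


def hann_window(n: int) -> list[float]:
    """Periodic Hann window, matching torch.hann_window's default (periodic=True)."""
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * k / n) for k in range(n)]


def window_envelope(window: list[float], hop: int, n_frames: int) -> list[float]:
    """Sum of squared windows placed every `hop` samples."""
    n_fft = len(window)
    env = [0.0] * (n_fft + hop * (n_frames - 1))
    for t in range(n_frames):
        for k, w in enumerate(window):
            env[t * hop + k] += w * w
    return env


def satisfies_nola(window: list[float], hop: int, n_frames: int,
                   center: bool = True, tol: float = 1e-11) -> bool:
    env = window_envelope(window, hop, n_frames)
    # With center=True, n_fft // 2 padded samples are trimmed from each end first.
    trim = len(window) // 2 if center else 0
    region = env[trim:len(env) - trim]
    return min(region) >= tol


hann = hann_window(16)
print(satisfies_nola(hann, hop=4, n_frames=10))                # True
print(satisfies_nola(hann, hop=4, n_frames=10, center=False))  # False: w[0] == 0
print(satisfies_nola(hann, hop=16, n_frames=10))               # False: no overlap
```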

Benchmarks

Apple M4 Max, macOS 26.3, PyTorch 2.10.0, 20 iterations (5 warmup).

STFT Forward

Config               | torch (MPS) | mps_spectro | Speedup
B=4 T=160k nfft=1024 | 0.51 ms     | 0.35 ms     | 1.5x
B=4 T=160k nfft=2048 | 0.53 ms     | 0.31 ms     | 1.7x
B=8 T=160k nfft=1024 | 0.78 ms     | 0.46 ms     | 1.7x
B=4 T=1.3M nfft=1024 | 1.93 ms     | 1.38 ms     | 1.4x

ISTFT Forward

Config               | torch (MPS) | mps_spectro | Speedup
B=4 T=160k nfft=1024 | 1.10 ms     | 0.34 ms     | 3.2x
B=8 T=160k nfft=1024 | 1.70 ms     | 0.63 ms     | 2.7x
B=4 T=1.3M nfft=1024 | 6.01 ms     | 2.30 ms     | 2.6x
B=1 T=1.3M nfft=1024 | 1.76 ms     | 0.61 ms     | 2.9x

STFT Forward + Backward

Config               | torch (MPS) | mps_spectro | Speedup
B=4 T=160k nfft=1024 | 1.51 ms     | 1.05 ms     | 1.4x
B=8 T=160k nfft=1024 | 2.96 ms     | 2.08 ms     | 1.4x
B=4 T=1.3M nfft=1024 | 12.75 ms    | 9.73 ms     | 1.3x
B=1 T=1.3M nfft=1024 | 2.95 ms     | 2.16 ms     | 1.4x

ISTFT Forward + Backward

Config               | torch (MPS) | mps_spectro | Speedup
B=4 T=160k nfft=1024 | 1.91 ms     | 0.98 ms     | 1.9x
B=8 T=160k nfft=1024 | 2.95 ms     | 1.62 ms     | 1.8x
B=4 T=1.3M nfft=1024 | 9.95 ms     | 5.71 ms     | 1.7x
B=1 T=1.3M nfft=1024 | 2.95 ms     | 1.56 ms     | 1.9x

Roundtrip (STFT -> ISTFT) Forward + Backward

Config               | torch (MPS) | mps_spectro | Speedup
B=4 T=160k nfft=1024 | 2.52 ms     | 1.47 ms     | 1.7x
B=8 T=160k nfft=1024 | 4.71 ms     | 2.55 ms     | 1.8x
B=4 T=1.3M nfft=1024 | 18.42 ms    | 11.07 ms    | 1.7x
B=1 T=1.3M nfft=1024 | 4.60 ms     | 2.39 ms     | 1.9x

To reproduce:

  • Full suite: python scripts/benchmark.py
  • Dispatch overhead profile: python scripts/benchmark.py --dispatch-profile
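MPS work is queued asynchronously, so wall-clock timings are only meaningful when the device is synchronized around each measured call. The stdlib sketch below shows a warmup-then-median pattern with the iteration counts quoted above. It is a hypothetical harness, not scripts/benchmark.py, and this page does not state whether the tables report medians or means. On MPS you would pass sync=torch.mps.synchronize; the speedup column is then torch time divided by mps_spectro time.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], iters: int = 20, warmup: int = 5,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall time of fn() in milliseconds."""
    def barrier() -> None:
        if sync is not None:
            sync()  # e.g. torch.mps.synchronize, to wait for queued GPU work

    for _ in range(warmup):  # first calls may include shader compilation
        fn()
    barrier()
    times_ms = []
    for _ in range(iters):
        barrier()
        start = time.perf_counter()
        fn()
        barrier()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


calls = []
median_ms = benchmark(lambda: calls.append(1))
print(len(calls), median_ms >= 0.0)  # 25 True
```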

STFT/iSTFT in audio-separator workloads

python-audio-separator uses torch.stft/torch.istft in its MDX, Roformer, and Demucs model pipelines. We swapped in mps_spectro via a compatibility layer and measured the STFT/iSTFT portion of each pipeline with two real stereo 44.1 kHz tracks (267s and 195s). Apple M4 Max, PyTorch 2.10.0, 20 iterations, 5 warmup, 5s cooldown.

Model (config)                 | STFT speedup | iSTFT speedup
MDX (n_fft=2048, hop=512)      | 1.40x        | 2.03x
Roformer (n_fft=2048, hop=512) | 1.40x        | 2.01x
Demucs (n_fft=4096, hop=1024)  | 1.28x        | 1.87x

Note: total separation wall time is dominated by model inference, so the end-to-end speedup is modest. The gains above apply to the STFT/iSTFT calls themselves.

To reproduce: python scripts/benchmark_audio_separator.py

Numerical parity. Output stems are perceptually identical. The maximum per-sample float32 difference is ≤ 1.83 × 10⁻⁴ (≤ 6 int16 LSBs) for the Roformer and MDX models, and 4.27 × 10⁻⁴ for hdemucs_mmi:

Model                   | Max abs diff (float32) | SNR (dB) | Int16 sample match
BS-Roformer-SW (6-stem) | 3.05e-05               | 91–100   | ≥ 99.98%
Mel-Roformer Karaoke    | 3.05e-05               | 89–91    | ≥ 99.84%
MDX-NET Inst HQ 5       | 1.83e-04               | 55–64    | ≥ 99%*
hdemucs_mmi (shifts=0)  | 4.27e-04               | 44–52    | ≥ 71%

* MDX int16 diffs are symmetric ±1 LSB rounding noise with zero bias and max ±6 LSBs.
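The parity columns can be computed from two aligned float32 stems in [-1, 1]. The sketch below uses one plausible set of definitions, which are assumptions since the page gives no formulas: SNR in dB as reference energy over difference energy, differences expressed in int16 LSBs by scaling by 32768, and the match rate as the fraction of samples that quantize to the same int16 value. With that scaling, 1.83e-4 corresponds to about 6 LSBs, consistent with the note above.

```python
import math


def snr_db(reference: list[float], estimate: list[float]) -> float:
    """Signal-to-noise ratio of `estimate` against `reference`, in dB."""
    signal = sum(r * r for r in reference)
    noise = sum((r - e) ** 2 for r, e in zip(reference, estimate))
    return math.inf if noise == 0.0 else 10.0 * math.log10(signal / noise)


def to_int16(x: float) -> int:
    """Quantize a float sample to int16 (assumed 32768 scaling, clipped)."""
    return max(-32768, min(32767, round(x * 32768)))


def int16_match_rate(a: list[float], b: list[float]) -> float:
    """Fraction of samples that quantize to the same int16 value."""
    same = sum(to_int16(x) == to_int16(y) for x, y in zip(a, b))
    return same / len(a)


print(round(1.83e-4 * 32768, 2))  # 6.0, i.e. about 6 int16 LSBs
ref = [0.5, -0.5, 0.25, -0.25]
est = [v + 1e-3 for v in ref]
print(round(snr_db(ref, est), 1))  # 51.9
```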

Log-mel frontend in mamba_amt

On the mamba_amt log-mel frontend configuration (16 kHz, n_fft=2048, hop=512, n_mels=256, pad_mode="constant", power=1.0, norm="slaney", mel_scale="htk"), the new LogMelSpectrogramTransform was about 2.44x faster than torchaudio.transforms.MelSpectrogram on MPS while staying numerically tight:

  • torchaudio median: 0.00397 s
  • mps-spectro median: 0.00163 s
  • speedup: 2.44x
  • max abs diff: 1.14e-4
  • mean abs diff: 6.33e-6

Dynamic frontends in RVMPE and torchfcpe

The new dynamic frontend APIs were validated against the prior project-local MPS paths:

  • RVMPE dynamic mel frontend:

    • old median: 1.436 ms
    • new shared path: 1.182 ms
    • speedup: 1.21x
    • parity: max abs 4.77e-07, mean abs 1.90e-08
  • torchfcpe dynamic spectrogram path:

    • old median: 3.347 ms
    • new shared path: 3.249 ms
    • speedup: 1.03x
    • parity on realistic mel-style positive filterbanks stayed effectively exact, with max abs 3.58e-07 and mean abs 1.79e-08 after log compression

Using MLX instead of PyTorch?

See mlx-spectro — same idea, built natively on MLX with even faster kernels (2–8x vs torch).

How it works

Metal shader source is compiled at runtime via torch.mps.compile_shader (pure Python, no C++ build step).

  1. STFT: a tiled Metal kernel loads overlapping signal chunks into threadgroup shared memory (~3x data reuse for typical n_fft/hop ratios) and applies reflect-padding and windowing in one pass; torch.fft.rfft then computes the FFT.
  2. ISTFT: torch.fft.irfft runs on MPS, then a single fused Metal kernel performs the synthesis-window multiply, overlap-add, and envelope normalization.
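The two pipelines above can be sketched in plain Python with the FFT step left out: an rfft followed by a matching irfft is the identity on real frames, so omitting both keeps the round trip exact. This is a readability sketch under that assumption, not the tiled Metal kernels, and the function names are hypothetical. analyze reflect-pads by n_fft // 2 and windows each frame in one pass; synthesize applies the synthesis window, overlap-adds, divides by the squared-window envelope, and trims the padding.

```python
import math


def hann_window(n: int) -> list[float]:
    """Periodic Hann window."""
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * k / n) for k in range(n)]


def analyze(x: list[float], window: list[float], hop: int) -> list[list[float]]:
    """Reflect-pad by n_fft // 2, then cut and window frames (FFT omitted)."""
    n_fft, pad, T = len(window), len(window) // 2, len(x)
    if T <= pad:
        raise ValueError("reflect padding needs len(x) > n_fft // 2")
    # Reflect excludes the edge sample: x[pad..1] on the left, x[T-2..] on the right.
    padded = [x[i] for i in range(pad, 0, -1)] + x + [x[T - 2 - i] for i in range(pad)]
    n_frames = 1 + (len(padded) - n_fft) // hop
    return [[padded[t * hop + k] * window[k] for k in range(n_fft)]
            for t in range(n_frames)]


def synthesize(frames: list[list[float]], window: list[float], hop: int,
               length: int) -> list[float]:
    """Synthesis-window multiply, overlap-add, envelope normalization, then trim."""
    n_fft, pad = len(window), len(window) // 2
    total = n_fft + hop * (len(frames) - 1)
    out, env = [0.0] * total, [0.0] * total
    for t, frame in enumerate(frames):
        for k in range(n_fft):
            out[t * hop + k] += frame[k] * window[k]
            env[t * hop + k] += window[k] * window[k]
    return [o / e for o, e in zip(out[pad:pad + length], env[pad:pad + length])]


x = [math.sin(0.3 * n) + 0.1 * n for n in range(50)]
w = hann_window(16)
y = synthesize(analyze(x, w, hop=4), w, hop=4, length=len(x))
print(max(abs(a - b) for a, b in zip(x, y)) < 1e-12)  # True
```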

Requirements

  • macOS with Apple Silicon (MPS)
  • Python 3.12+
  • PyTorch 2.10+

Tests

pip install -e ".[dev]"
pytest

License

MIT



Download files


Source Distribution

mps_spectro-0.3.0.tar.gz (32.8 kB)

Uploaded Source

Built Distribution


mps_spectro-0.3.0-py3-none-any.whl (23.8 kB)

Uploaded Python 3

File details

Details for the file mps_spectro-0.3.0.tar.gz.

File metadata

  • Download URL: mps_spectro-0.3.0.tar.gz
  • Upload date:
  • Size: 32.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for mps_spectro-0.3.0.tar.gz
Algorithm Hash digest
SHA256 0ad04786be368d1b319750446650883de121e0d7515ec51823496d7899dfb81b
MD5 257380c913a76cb74237831b0da49379
BLAKE2b-256 e115462bec5dae3873f20d24ad5390abdeb41ba923bfa3948353dbd570aaf3d5


Provenance

The following attestation bundles were made for mps_spectro-0.3.0.tar.gz:

Publisher: release-pypi.yml on ssmall256/mps-spectro

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file mps_spectro-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: mps_spectro-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 23.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for mps_spectro-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 15d746f0dab6388039dc07bbd4da9b9b162776f9b13de2a9afeef3bcdf807a31
MD5 b0f7cf2e456663f0d38caef9cbf2dee7
BLAKE2b-256 009977ef59b93c4c9250d974a2675d9075c712c17be5b66b31471a32a2c533ae


Provenance

The following attestation bundles were made for mps_spectro-0.3.0-py3-none-any.whl:

Publisher: release-pypi.yml on ssmall256/mps-spectro

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
