Fast torch-compatible STFT and ISTFT on Apple MPS via custom Metal kernels
mps-spectro
Drop-in torch.stft / torch.istft replacements for Apple Silicon, plus mel frontends built on top of them — 1.4–3x faster on MPS via custom Metal kernels.
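# before
import torch
x = torch.randn(2, 441000, device="mps")  # [batch, samples]
T = x.shape[-1]
window = torch.hann_window(2048, device="mps")
spec = torch.stft(x, n_fft=2048, hop_length=512, window=window, center=True, return_complex=True)
y = torch.istft(spec, n_fft=2048, hop_length=512, window=window, center=True, length=T)
# after
from mps_spectro import stft, istft
spec = stft(x, n_fft=2048, hop_length=512, window=window, center=True)
y = istft(spec, n_fft=2048, hop_length=512, window=window, center=True, length=T)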
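Because the calls mirror torch.stft/torch.istft, output shapes should follow torch's rules. Below is a minimal pure-Python sketch of that shape arithmetic, assuming torch's convention of a one-sided, batched complex output laid out as [batch, n_fft // 2 + 1, n_frames]. The helper name stft_output_shape is hypothetical and is not part of mps_spectro.

```python
def stft_output_shape(batch, length, n_fft, hop_length, center=True):
    """Hypothetical helper: shape of a one-sided, batched STFT (torch.stft convention)."""
    # center=True pads n_fft // 2 samples on each side before framing.
    padded = length + 2 * (n_fft // 2) if center else length
    if padded < n_fft:
        raise ValueError("input is shorter than n_fft")
    n_frames = 1 + (padded - n_fft) // hop_length
    n_freqs = n_fft // 2 + 1  # one-sided (rfft) bins
    return (batch, n_freqs, n_frames)


# The example above: n_fft=2048, hop_length=512 on 10 s of 44.1 kHz audio.
print(stft_output_shape(2, 441000, 2048, 512))  # (2, 1025, 862)
```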
Log-mel frontend for AMT / ASR-style models:
from mps_spectro import LogMelSpectrogramTransform
frontend = LogMelSpectrogramTransform(
sample_rate=16000,
n_fft=2048,
hop_length=512,
n_mels=256,
f_min=30.0,
f_max=8000.0,
pad_mode="constant",
power=1.0,
norm="slaney",
mel_scale="htk",
log_amin=1e-5,
log_mode="clamp",
)
mel = frontend(x)
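The norm="slaney" and mel_scale="htk" options refer to standard filterbank conventions. For reference, here is a pure-Python sketch of how a triangular HTK-scale filterbank with Slaney area normalization is commonly built. It follows the recipe used by torchaudio.functional.melscale_fbanks, not mps_spectro's own code, and mel_filterbank is a hypothetical name. Rows are returned as [n_mels][n_freqs], which is the layout documented for the optional mel_basis argument.

```python
import math


def hz_to_mel_htk(f):
    """HTK mel scale: m = 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz_htk(m):
    """Inverse of hz_to_mel_htk."""
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_filterbank(sample_rate, n_fft, n_mels, f_min, f_max):
    """Hypothetical helper: HTK-scale triangular filters with Slaney normalization."""
    n_freqs = n_fft // 2 + 1
    # Frequency (Hz) of each one-sided DFT bin, from 0 to sample_rate / 2.
    bin_hz = [i * (sample_rate / 2) / (n_freqs - 1) for i in range(n_freqs)]
    # n_mels + 2 edges, equally spaced in mel, converted back to Hz.
    m_lo, m_hi = hz_to_mel_htk(f_min), hz_to_mel_htk(f_max)
    edges = [mel_to_hz_htk(m_lo + (m_hi - m_lo) * j / (n_mels + 1)) for j in range(n_mels + 2)]
    fb = []
    for m in range(n_mels):
        lo, mid, hi = edges[m], edges[m + 1], edges[m + 2]
        enorm = 2.0 / (hi - lo)  # Slaney: scale each triangle to unit area in Hz
        row = []
        for f in bin_hz:
            rising = (f - lo) / (mid - lo)
            falling = (hi - f) / (hi - mid)
            row.append(max(0.0, min(rising, falling)) * enorm)
        fb.append(row)
    return fb  # [n_mels][n_freqs]


fb = mel_filterbank(16000, 2048, 256, 30.0, 8000.0)
print(len(fb), len(fb[0]))  # 256 1025
```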
Dynamic frontends for pitch models, with per-call keyshift and speed and optional externally supplied mel filterbanks:
from mps_spectro import DynamicMelSpectrogramTransform
frontend = DynamicMelSpectrogramTransform(
sample_rate=16000,
n_fft=1024,
hop_length=160,
win_length=1024,
output_scale="log",
log_amin=1e-5,
mel_basis=external_mel_basis, # optional [n_mels, n_freqs]
)
mel = frontend(x, keyshift=3, speed=1.2)
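Per-call keyshift and speed change the STFT geometry on the fly. The sketch below shows the convention used by RVC/RMVPE-style mel extractors. In that convention, keyshift is measured in semitones and scales n_fft and win_length by 2^(keyshift/12), while speed scales hop_length. The resulting bins are then cropped or zero-padded back to the base n_fft // 2 + 1 and rescaled by win_length / new_win_length. It is an assumption that DynamicMelSpectrogramTransform follows exactly this mapping, and dynamic_stft_params and align_bins are hypothetical helpers.

```python
def dynamic_stft_params(n_fft, win_length, hop_length, keyshift=0.0, speed=1.0):
    """Hypothetical helper: STFT sizes for a pitch-shifted / time-scaled call."""
    factor = 2.0 ** (keyshift / 12.0)  # semitones -> frequency ratio (assumed convention)
    new_n_fft = int(round(n_fft * factor))
    new_win = int(round(win_length * factor))
    new_hop = int(round(hop_length * speed))
    return new_n_fft, new_win, new_hop


def align_bins(magnitude_frame, n_fft, win_length, new_win):
    """Hypothetical helper: crop or zero-pad one frame back to n_fft // 2 + 1 bins."""
    size = n_fft // 2 + 1
    frame = list(magnitude_frame[:size]) + [0.0] * max(0, size - len(magnitude_frame))
    # Compensate for the different energy of the resized window.
    scale = win_length / new_win
    return [v * scale for v in frame]


print(dynamic_stft_params(1024, 1024, 160, keyshift=3, speed=1.2))  # (1218, 1218, 192)
```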
The 0.3.0 release expands mps-spectro from fast STFT/iSTFT plus fixed mel frontends into a broader shared spectral frontend package:
- standard mel frontends for log, linear, dB, and compat-style outputs
- dynamic frontends for pitch models with per-call keyshift/speed
- optional external mel filterbank injection for exact project parity
- a parity-oriented dynamic STFT mode for cases where exact legacy wrapper behavior matters more than the lowest-level fast path
Drop-in compatible with python-audio-separator (MDX, Roformer, Demucs): roughly 1.3–1.4x faster STFT and 1.9–2.0x faster iSTFT on stereo 44.1 kHz audio, with no model changes. See benchmarks below.
Install
pip install mps-spectro
Features
- PyTorch-compatible STFT/ISTFT semantics (same parameters as torch.stft/torch.istft)
- PyTorch-native mel frontends built on top of the same spectral core
- Dynamic mel/spectrogram frontends for pitch models with keyshift, speed, and optional external mel bases
- Fused overlap-add with optimized Metal compute shaders
- Autograd support with custom Metal backward kernels
- torch.compile compatible (aot_eager backend) via torch.library custom ops
- Pure Python: no C++ build step, no Xcode CLI tools
Validated downstream use cases
The current package surface has been benchmarked and parity-checked in several real consumer projects:
- mamba_amt: log-mel frontend replacement on MPS
- python-audio-separator: shared STFT/iSTFT compatibility layer
- LinkSeg: compat mel frontend replacing project-local frontend code
- SongFormer-mps: shared dB mel frontends for MusicFM and MuQ
- RVMPE: dynamic mel frontend with per-call keyshift/speed
- torchfcpe: dynamic spectrogram path for the MPS mel frontend patch
The most important takeaway is that mps-spectro now covers both:
- fixed frontend replacements for torchaudio-style mel paths
- dynamic frontend building blocks for pitch models that previously needed project-local MPS STFT patches
Autograd
Both stft and istft support PyTorch autograd when inputs have requires_grad=True:
import torch
from mps_spectro import stft, istft

x = torch.randn(4, 16000, device="mps", requires_grad=True)
spec = stft(x, n_fft=1024, hop_length=256)
y = istft(spec, n_fft=1024, hop_length=256, center=True, length=16000)
loss = y.pow(2).sum()
loss.backward()
print(x.grad.shape)  # torch.Size([4, 16000])
When requires_grad=False (the default), there is zero overhead: the original Metal kernel path is used directly. Backward passes use custom Metal kernels for GPU-accelerated gradient computation. Window gradients are not computed (the window's gradient is None), since windows are almost always frozen in practice.
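To build intuition for what an STFT backward pass computes: the STFT is linear in the signal, so the gradient with respect to x is the adjoint applied to the incoming gradient. In practice that means a conjugate (inverse-direction) DFT of each frame's gradient, multiplied by the window and overlap-added. The pure-Python sketch below checks this identity with a dot-product test on a simplified two-sided, uncentered STFT. It illustrates the math only. It is not mps_spectro's Metal backward, and the helper names are hypothetical.

```python
import cmath
import math
import random


def stft_two_sided(x, window, hop):
    """Uncentered, two-sided STFT: DFT of each windowed frame."""
    n = len(window)
    out = []
    for start in range(0, len(x) - n + 1, hop):
        seg = [x[start + i] * window[i] for i in range(n)]
        out.append([sum(seg[i] * cmath.exp(-2j * math.pi * k * i / n) for i in range(n))
                    for k in range(n)])
    return out


def stft_vjp(grad_out, window, hop, length):
    """Gradient w.r.t. a real input: Re(S^H g), i.e. conjugate DFT, window, overlap-add."""
    n = len(window)
    grad_x = [0.0] * length
    for t, g in enumerate(grad_out):
        for i in range(n):
            acc = sum(g[k] * cmath.exp(2j * math.pi * k * i / n) for k in range(n))
            grad_x[t * hop + i] += window[i] * acc.real
    return grad_x


random.seed(0)
n, hop, length = 8, 2, 32
window = [0.5 - 0.5 * math.cos(2 * math.pi * i / n) for i in range(n)]  # periodic Hann
x = [random.gauss(0, 1) for _ in range(length)]
spec = stft_two_sided(x, window, hop)
g = [[complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(n)] for _ in spec]

# Dot-product (adjoint) test: Re<g, S x> must equal <Re(S^H g), x> for real x.
lhs = sum((gk.conjugate() * sk).real for gt, st in zip(g, spec) for gk, sk in zip(gt, st))
rhs = sum(a * b for a, b in zip(stft_vjp(g, window, hop, length), x))
print(abs(lhs - rhs) < 1e-9)  # True
```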
torch.compile
Custom ops are registered via torch.library with Meta (FakeTensor) kernels, so torch.compile can trace through both forward and backward:
import torch
from mps_spectro import stft

@torch.compile(backend="aot_eager")
def f(x):
    return stft(x, n_fft=2048, hop_length=512)

f(torch.randn(4, 160000, device="mps"))  # works
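A Meta/FakeTensor kernel only has to report output shapes and dtypes; it does not compute values. For istft, that reduces to torch.istft's output-length rule, sketched here in pure Python. The name istft_output_length is hypothetical, and mps_spectro's actual fake kernels are not shown in this README.

```python
def istft_output_length(n_frames, n_fft, hop_length, center=True, length=None):
    """Hypothetical helper: number of output samples torch.istft would produce."""
    if length is not None:
        return length  # an explicit length wins (output is trimmed or zero-padded)
    full = n_fft + hop_length * (n_frames - 1)  # overlap-add span
    # center=True removes the n_fft // 2 samples of padding added on each side.
    return full - 2 * (n_fft // 2) if center else full


print(istft_output_length(63, 1024, 256))  # 15872
```

This is why the examples pass length=T. Without it, a 16000-sample input (63 frames at n_fft=1024, hop_length=256) comes back as 15872 samples.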
ISTFT extras
istft also supports:
- torch_like=True: raise on NOLA violations like torch.istft
- safety="auto"|"always"|"off": NOLA envelope safety checking
- kernel_dtype="float32"|"float16"|"mixed": Metal kernel precision
- kernel_layout="auto"|"native"|"transposed": memory layout selection
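NOLA (nonzero overlap-add) is the condition that the overlap-added squared window never vanishes inside the output. It matters because ISTFT divides by that envelope. Below is a pure-Python sketch of the check, assuming torch.istft's convention of testing the center-trimmed region against a 1e-11 floor. nola_ok is a hypothetical name, and this README does not describe how safety="auto" decides when to run its check.

```python
import math


def nola_ok(window, hop, n_frames, center=True, eps=1e-11):
    """Hypothetical helper: True if the squared-window envelope stays above eps."""
    n = len(window)
    env = [0.0] * (n + hop * (n_frames - 1))
    for t in range(n_frames):
        for i, w in enumerate(window):
            env[t * hop + i] += w * w
    # With center=True the first and last n // 2 samples are padding and get trimmed.
    start, end = (n // 2, len(env) - n // 2) if center else (0, len(env))
    return min(abs(v) for v in env[start:end]) > eps


hann = [0.5 - 0.5 * math.cos(2 * math.pi * i / 8) for i in range(8)]  # periodic Hann, n_fft=8
print(nola_ok(hann, hop=2, n_frames=10))  # True: 75% overlap
print(nola_ok(hann, hop=8, n_frames=4))   # False: no overlap, envelope hits w[0] = 0
```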
Benchmarks
Apple M4 Max, macOS 26.3, PyTorch 2.10.0, 20 iterations (5 warmup).
STFT Forward
| Config | torch MPS | mps_spectro | Speedup |
|---|---|---|---|
| B=4 T=160k nfft=1024 | 0.51 ms | 0.35 ms | 1.5x |
| B=4 T=160k nfft=2048 | 0.53 ms | 0.31 ms | 1.7x |
| B=8 T=160k nfft=1024 | 0.78 ms | 0.46 ms | 1.7x |
| B=4 T=1.3M nfft=1024 | 1.93 ms | 1.38 ms | 1.4x |
ISTFT Forward
| Config | torch MPS | mps_spectro | Speedup |
|---|---|---|---|
| B=4 T=160k nfft=1024 | 1.10 ms | 0.34 ms | 3.2x |
| B=8 T=160k nfft=1024 | 1.70 ms | 0.63 ms | 2.7x |
| B=4 T=1.3M nfft=1024 | 6.01 ms | 2.30 ms | 2.6x |
| B=1 T=1.3M nfft=1024 | 1.76 ms | 0.61 ms | 2.9x |
STFT Forward + Backward
| Config | torch MPS | mps_spectro | Speedup |
|---|---|---|---|
| B=4 T=160k nfft=1024 | 1.51 ms | 1.05 ms | 1.4x |
| B=8 T=160k nfft=1024 | 2.96 ms | 2.08 ms | 1.4x |
| B=4 T=1.3M nfft=1024 | 12.75 ms | 9.73 ms | 1.3x |
| B=1 T=1.3M nfft=1024 | 2.95 ms | 2.16 ms | 1.4x |
ISTFT Forward + Backward
| Config | torch MPS | mps_spectro | Speedup |
|---|---|---|---|
| B=4 T=160k nfft=1024 | 1.91 ms | 0.98 ms | 1.9x |
| B=8 T=160k nfft=1024 | 2.95 ms | 1.62 ms | 1.8x |
| B=4 T=1.3M nfft=1024 | 9.95 ms | 5.71 ms | 1.7x |
| B=1 T=1.3M nfft=1024 | 2.95 ms | 1.56 ms | 1.9x |
Roundtrip (STFT -> ISTFT) Forward + Backward
| Config | torch MPS | mps_spectro | Speedup |
|---|---|---|---|
| B=4 T=160k nfft=1024 | 2.52 ms | 1.47 ms | 1.7x |
| B=8 T=160k nfft=1024 | 4.71 ms | 2.55 ms | 1.8x |
| B=4 T=1.3M nfft=1024 | 18.42 ms | 11.07 ms | 1.7x |
| B=1 T=1.3M nfft=1024 | 4.60 ms | 2.39 ms | 1.9x |
To reproduce:
- Full suite: python scripts/benchmark.py
- Dispatch overhead profile: python scripts/benchmark.py --dispatch-profile
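The tables above come from 20 timed iterations after 5 warmup runs. Here is a minimal, framework-agnostic sketch of that kind of harness. It reports the median and assumes the caller passes a sync callable (on MPS, torch.mps.synchronize) so that queued GPU work finishes before the clock stops. bench and speedup are hypothetical names, and this is not scripts/benchmark.py.

```python
import statistics
import time


def bench(fn, iters=20, warmup=5, sync=lambda: None):
    """Hypothetical helper: median wall time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    sync()  # drain any queued asynchronous work before timing
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        sync()  # e.g. torch.mps.synchronize on Apple GPUs
        times.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(times)


def speedup(baseline_ms, candidate_ms):
    return baseline_ms / candidate_ms


# Speedups in the tables are baseline / candidate, e.g. the first ISTFT row:
print(f"{speedup(1.10, 0.34):.1f}x")  # 3.2x
```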
STFT/iSTFT in audio-separator workloads
python-audio-separator uses torch.stft/torch.istft in its MDX, Roformer, and Demucs model pipelines. We swapped in mps_spectro via a compatibility layer and measured the STFT/iSTFT portion of each pipeline with two real stereo 44.1 kHz tracks (267s and 195s). Apple M4 Max, PyTorch 2.10.0, 20 iterations, 5 warmup, 5s cooldown.
| Model config | STFT speedup | iSTFT speedup |
|---|---|---|
| MDX (n_fft=2048, hop=512) | 1.40x | 2.03x |
| Roformer (n_fft=2048, hop=512) | 1.40x | 2.01x |
| Demucs (n_fft=4096, hop=1024) | 1.28x | 1.87x |
Note: total separation wall time is dominated by model inference, so E2E speedup is modest. The gains above apply to the STFT/iSTFT calls themselves.
To reproduce: python scripts/benchmark_audio_separator.py
Numerical parity. Output stems are perceptually identical. For the Roformer and MDX-NET models, the maximum per-sample float32 difference is ≤ 1.83 × 10⁻⁴ (≤ 6 int16 LSBs); hdemucs_mmi reaches 4.27 × 10⁻⁴ (about 14 LSBs):
| Model | Max abs diff (float32) | SNR (dB) | Int16 sample match |
|---|---|---|---|
| BS-Roformer-SW (6-stem) | 3.05e-05 | 91 – 100 | ≥ 99.98% |
| Mel-Roformer Karaoke | 3.05e-05 | 89 – 91 | ≥ 99.84% |
| MDX-NET Inst HQ 5 | 1.83e-04 | 55 – 64 | ≥ 99%* |
| hdemucs_mmi (shifts=0) | 4.27e-04 | 44 – 52 | ≥ 71% |
* MDX int16 differences are symmetric rounding noise with zero bias: mostly ±1 LSB, with a maximum of ±6 LSBs.
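The parity columns can be computed from two float waveforms in [-1, 1]. Below is a pure-Python sketch, assuming int16 conversion by clipping, scaling by 32767, and rounding. The exact quantization used for the table is not stated, and parity_metrics is a hypothetical name. At that scale one int16 LSB is about 3.05e-05, so 1.83e-04 corresponds to about 6 LSBs.

```python
import math


def to_int16(x):
    # Assumed quantization: clip to [-1, 1], scale by 32767, round to nearest.
    return int(round(max(-1.0, min(1.0, x)) * 32767))


def parity_metrics(ref, cand):
    """Hypothetical helper: (max abs diff, SNR in dB, fraction of equal int16 samples)."""
    diffs = [r - c for r, c in zip(ref, cand)]
    max_abs = max(abs(d) for d in diffs)
    signal = sum(r * r for r in ref)
    noise = sum(d * d for d in diffs)
    snr_db = math.inf if noise == 0 else 10.0 * math.log10(signal / noise)
    match = sum(to_int16(r) == to_int16(c) for r, c in zip(ref, cand)) / len(ref)
    return max_abs, snr_db, match


ref = [0.5 * math.sin(2 * math.pi * 440 * n / 44100) for n in range(1000)]
cand = [v + 1e-5 for v in ref]  # a small constant offset stands in for kernel differences
print(parity_metrics(ref, cand))
```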
Log-mel frontend in mamba_amt
On the mamba_amt log-mel frontend configuration (16 kHz, n_fft=2048, hop=512, n_mels=256, pad_mode="constant", power=1.0, norm="slaney", mel_scale="htk"), the new LogMelSpectrogramTransform was about 2.44x faster than torchaudio.transforms.MelSpectrogram on MPS while staying numerically tight:
- torchaudio median: 0.00397 s
- mps-spectro median: 0.00163 s
- speedup: 2.44x
- max abs diff: 1.14e-4
- mean abs diff: 6.33e-6
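Once the STFT and the filterbank are available, the remaining steps of such a log-mel frontend are a power, a matrix product, and a clamped log. Below is a pure-Python sketch on tiny hand-written inputs. It assumes that log_mode="clamp" means log(max(mel, log_amin)) with a natural log; the README does not spell out the base or the exact formula, and log_mel is a hypothetical name.

```python
import math


def log_mel(spec_frame, fb, power=1.0, log_amin=1e-5):
    """Hypothetical helper: one STFT frame (complex bins) -> log-mel vector."""
    mag = [abs(z) ** power for z in spec_frame]  # power=1.0 -> magnitude, 2.0 -> power
    mel = [sum(w * m for w, m in zip(row, mag)) for row in fb]  # fb: [n_mels][n_freqs]
    return [math.log(max(v, log_amin)) for v in mel]  # clamp avoids log(0)


frame = [3 + 4j, 0j, 1j]  # 3 frequency bins
fb = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]  # 3 toy mel filters
print(log_mel(frame, fb))
```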
Dynamic frontends in RVMPE and torchfcpe
The new dynamic frontend APIs were validated against the prior project-local MPS paths:
- RVMPE dynamic mel frontend:
  - old median: 1.436 ms
  - new shared path: 1.182 ms
  - speedup: 1.21x
  - parity: max abs 4.77e-07, mean abs 1.90e-08
- torchfcpe dynamic spectrogram path:
  - old median: 3.347 ms
  - new shared path: 3.249 ms
  - speedup: 1.03x
  - parity on realistic mel-style positive filterbanks stayed effectively exact, with max abs 3.58e-07 and mean abs 1.79e-08 after log compression
Using MLX instead of PyTorch?
See mlx-spectro — same idea, built natively on MLX with even faster kernels (2–8x vs torch).
How it works
Metal shader source is compiled at runtime via torch.mps.compile_shader (pure Python, no C++ build step).
- STFT: a tiled Metal kernel loads overlapping signal chunks into threadgroup shared memory (~3x data reuse for typical n_fft/hop ratios), applies reflect-padding and windowing in one pass, then uses torch.fft.rfft for the FFT
- ISTFT: torch.fft.irfft on MPS, then a fused Metal kernel for synthesis-window multiply + overlap-add + envelope normalization
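The data flow just described (reflect-pad, window, rfft; then irfft, window, overlap-add, divide by the squared-window envelope) can be written as a slow pure-Python reference. This is a sketch of the math the kernels implement, not the kernels themselves. It uses a naive O(n^2) DFT instead of torch.fft and no tiling or threadgroup memory, and ref_stft and ref_istft are hypothetical names.

```python
import cmath
import math
import random


def ref_stft(x, window, hop):
    """center=True STFT: reflect-pad by n // 2, window each frame, one-sided DFT."""
    n, pad = len(window), len(window) // 2
    xp = x[pad:0:-1] + x + x[-2:-pad - 2:-1]  # reflect padding (edge sample not repeated)
    frames = []
    for start in range(0, len(xp) - n + 1, hop):
        seg = [xp[start + i] * window[i] for i in range(n)]
        frames.append([sum(seg[i] * cmath.exp(-2j * math.pi * k * i / n) for i in range(n))
                       for k in range(n // 2 + 1)])
    return frames


def ref_istft(frames, window, hop, length):
    """irfft each frame, multiply by the window, overlap-add, normalize by sum of w^2."""
    n, pad = len(window), len(window) // 2
    total = n + hop * (len(frames) - 1)
    y, env = [0.0] * total, [0.0] * total
    for t, half in enumerate(frames):
        # Rebuild the full Hermitian spectrum, then take the real inverse DFT.
        full = half + [half[n - k].conjugate() for k in range(n // 2 + 1, n)]
        for i in range(n):
            s = sum(full[k] * cmath.exp(2j * math.pi * k * i / n) for k in range(n)).real / n
            y[t * hop + i] += s * window[i]
            env[t * hop + i] += window[i] ** 2
    # Envelope normalization, then drop the n // 2 padding on the left.
    return [y[p] / env[p] for p in range(pad, pad + length)]


random.seed(1)
x = [random.uniform(-1.0, 1.0) for _ in range(40)]
w = [0.5 - 0.5 * math.cos(2 * math.pi * i / 8) for i in range(8)]  # periodic Hann, n_fft=8
y = ref_istft(ref_stft(x, w, hop=2), w, hop=2, length=len(x))
print(max(abs(a - b) for a, b in zip(x, y)) < 1e-9)  # True
```

The division by env is where the NOLA condition matters. If the envelope is zero anywhere in the kept region, the reconstruction is undefined.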
Requirements
- macOS with Apple Silicon (MPS)
- Python 3.12+
- PyTorch 2.10+
Tests
pip install -e ".[dev]"
pytest
License
MIT
Download files
File details
Details for the file mps_spectro-0.3.0.tar.gz.
File metadata
- Download URL: mps_spectro-0.3.0.tar.gz
- Upload date:
- Size: 32.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0ad04786be368d1b319750446650883de121e0d7515ec51823496d7899dfb81b |
| MD5 | 257380c913a76cb74237831b0da49379 |
| BLAKE2b-256 | e115462bec5dae3873f20d24ad5390abdeb41ba923bfa3948353dbd570aaf3d5 |
Provenance
The following attestation bundles were made for mps_spectro-0.3.0.tar.gz:
Publisher: release-pypi.yml on ssmall256/mps-spectro
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: mps_spectro-0.3.0.tar.gz
- Subject digest: 0ad04786be368d1b319750446650883de121e0d7515ec51823496d7899dfb81b
- Sigstore transparency entry: 1096868804
- Sigstore integration time:
- Permalink: ssmall256/mps-spectro@31340a494ea02a19707873c0362f96576a3bb884
- Branch / Tag: refs/heads/main
- Owner: https://github.com/ssmall256
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release-pypi.yml@31340a494ea02a19707873c0362f96576a3bb884
- Trigger Event: workflow_dispatch
File details
Details for the file mps_spectro-0.3.0-py3-none-any.whl.
File metadata
- Download URL: mps_spectro-0.3.0-py3-none-any.whl
- Upload date:
- Size: 23.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 15d746f0dab6388039dc07bbd4da9b9b162776f9b13de2a9afeef3bcdf807a31 |
| MD5 | b0f7cf2e456663f0d38caef9cbf2dee7 |
| BLAKE2b-256 | 009977ef59b93c4c9250d974a2675d9075c712c17be5b66b31471a32a2c533ae |
Provenance
The following attestation bundles were made for mps_spectro-0.3.0-py3-none-any.whl:
Publisher: release-pypi.yml on ssmall256/mps-spectro
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: mps_spectro-0.3.0-py3-none-any.whl
- Subject digest: 15d746f0dab6388039dc07bbd4da9b9b162776f9b13de2a9afeef3bcdf807a31
- Sigstore transparency entry: 1096868809
- Sigstore integration time:
- Permalink: ssmall256/mps-spectro@31340a494ea02a19707873c0362f96576a3bb884
- Branch / Tag: refs/heads/main
- Owner: https://github.com/ssmall256
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release-pypi.yml@31340a494ea02a19707873c0362f96576a3bb884
- Trigger Event: workflow_dispatch