
Unofficial PyTorch port of MULE (SF-NFNet-F0), the SiriusXM/Pandora music-audio embedding model.


mule-torch


An unofficial PyTorch port of MULE (Musicset Unsupervised Large Embedding), the SF-NFNet-F0 music-audio representation model from SiriusXM/Pandora:

Supervised and Unsupervised Learning of Audio Representations for Music Understanding, M. C. McCallum, F. Korzeniowski, S. Oramas, F. Gouyon, A. F. Ehmann. ISMIR 2022. https://arxiv.org/abs/2210.03799

This is not a re-training. It re-implements the SF-NFNet-F0 architecture in pure PyTorch and transfers the pretrained weights from the original Keras model (model.keras), with outputs verified to match the original to within floating-point precision (see Verified parity). The result is a clean nn.Module that is batched, GPU-native, and ONNX-exportable, none of which the original TensorFlow/SCOOCH/Analysis pipeline supports.

Disclaimer

This is an independent, community port from TensorFlow to PyTorch. It is not affiliated with, endorsed by, or maintained by SiriusXM, Pandora, or the original authors. The original model, weights, and configurations come from PandoraMedia/music-audio-representations. All credit for the model goes to the original authors — please cite their paper.

Install

pip install mule-torch
# or, latest from source:
pip install git+https://github.com/matteospanio/mule-torch

The installed package is a pure torch library (torch, numpy, safetensors, huggingface_hub). It does not pull in TensorFlow — the conversion/verification tooling lives in standalone uv scripts (see below).

Usage

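import torch
from mule_torch import MuleModel

model = MuleModel.from_pretrained()           # downloads weights from the Hugging Face Hub
waveform = 0.1 * torch.randn(2, 10 * 16000)   # placeholder audio: (B, T) float, 16 kHz mono, in [-1, 1]
with torch.no_grad():
    emb = model(waveform)                     # (B, T) -> (B, 1728)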

Input is a 16 kHz mono waveform with samples in [-1, 1]. The model computes a 96-band log-mel spectrogram, cuts it into 96×300 slices spaced about 2 s apart, runs the SF-NFNet-F0 backbone on every slice, and mean-pools the per-slice 1728-d embeddings into one vector per clip. This matches the mule_embedding_timeline.yml configuration followed by a timeline average.
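The slicing and pooling steps can be sketched in a few lines of torch. This is a minimal illustration, not the package's slice_mel: the 200-frame hop is an assumption (it works out to roughly 2 s only if the mel hop is 10 ms, i.e. 160 samples at 16 kHz), padding of clips shorter than one slice is ignored, and `slice_and_pool_sketch` and its `backbone` argument are hypothetical names.

```python
import torch


def slice_and_pool_sketch(mel: torch.Tensor, backbone, win: int = 300, hop: int = 200) -> torch.Tensor:
    """Hypothetical helper: cut (B, n_mels, frames) log-mels into slices, embed them, mean-pool.

    `hop=200` frames is an assumed value; clips shorter than `win` frames are not handled.
    """
    n_mels = mel.shape[1]
    # unfold over the time axis: (B, n_mels, n_slices, win), one window every `hop` frames
    slices = mel.unfold(dimension=2, size=win, step=hop)
    slices = slices.permute(0, 2, 1, 3)  # (B, n_slices, n_mels, win)
    batch, n_slices = slices.shape[:2]
    # fold the slices into the batch so the backbone sees (N, 1, n_mels, win) inputs
    flat = slices.reshape(batch * n_slices, 1, n_mels, win)
    emb = backbone(flat).reshape(batch, n_slices, -1)  # (B, n_slices, D)
    return emb.mean(dim=1)  # timeline average -> (B, D)
```

With the real backbone D would be 1728; any callable mapping (N, 1, 96, 300) to (N, D) is enough to experiment with the slicing logic.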

Pretrained weights

The converted weights are hosted on the Hugging Face Hub at matteospanio/mule (CC BY-NC 4.0) and are downloaded automatically by from_pretrained():

MuleModel.from_pretrained()                          # default: hf_repo="matteospanio/mule"
MuleModel.from_pretrained(hf_repo="matteospanio/mule", revision="main")
MuleModel.from_pretrained(model_dir="artifacts")     # or load a local copy (skip the download)

The Hub repo ships model.safetensors and config.json, which from_pretrained() loads, plus a self-contained backbone.onnx for ONNX Runtime (opset 17, dynamic batch) that maps a batch of (N, 1, 96, 300) log-mel slices to (N, 1728) embeddings; the mel front-end and the clip-level averaging are not part of that graph. Set $MULE_TORCH_DIR to point from_pretrained() at a local directory without passing model_dir.
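A minimal sketch of feeding backbone.onnx with ONNX Runtime follows. It assumes the graph has exactly one input and one output, reads the input name from the session instead of guessing it, and expects log-mel slices you have already computed. `to_onnx_batch` and `embed_slices_onnx` are hypothetical helper names.

```python
import numpy as np


def to_onnx_batch(slices: np.ndarray) -> np.ndarray:
    """Coerce (N, 96, 300) or (N, 1, 96, 300) log-mel slices to contiguous float32 (N, 1, 96, 300)."""
    x = np.asarray(slices, dtype=np.float32)
    if x.ndim == 3:
        x = x[:, None]  # add the channel axis
    if x.shape[1:] != (1, 96, 300):
        raise ValueError(f"expected (N, 1, 96, 300) slices, got {x.shape}")
    return np.ascontiguousarray(x)


def embed_slices_onnx(onnx_path: str, slices: np.ndarray) -> np.ndarray:
    """Run the exported backbone on a batch of slices; returns (N, 1728) embeddings."""
    import onnxruntime as ort  # imported here so to_onnx_batch works without onnxruntime

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name  # assumes exactly one graph input
    (emb,) = sess.run(None, {input_name: to_onnx_batch(slices)})  # assumes exactly one output
    return emb
```

A local copy of the file can be fetched with `huggingface_hub.hf_hub_download(repo_id="matteospanio/mule", filename="backbone.onnx")`, which returns the cached path to pass as `onnx_path`.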

How the port works

| Stage | Original (TF) | This port (torch) |
|---|---|---|
| Mel front-end | librosa.feature.melspectrogram | MuleMelSpectrogram (fixed windowed-DFT conv1d plus the librosa filterbank stored as a buffer, since torchaudio's mel filterbank does not support norm=2.0) |
| Slicing | SliceExtractor (numpy) | slice_mel (torch) |
| Backbone | SfNfNetF0 Keras model | SfNfNetF0 nn.Module (WSConv2d, scaled GELU, squeeze-excite, NFNet blocks, fast→slow fusion) |
| Weights | model.keras (251 MB) | model.safetensors (converted) |
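The windowed-DFT conv1d idea in the mel front-end row can be sketched as follows: the STFT becomes a fixed convolution whose kernels are window-weighted cosines and sines, and a precomputed filterbank (for example the output of librosa.filters.mel) is applied as a matrix product. This is an illustration under assumptions, not MuleMelSpectrogram itself: `ConvMelSketch`, the n_fft=512 and hop=160 defaults, the power-2 spectrum, the periodic Hann window, and the absence of centre padding are all placeholders. The log10(10000·x + 1) compression is the one mentioned under Amplitude convention.

```python
import math

import torch
import torch.nn.functional as F


class ConvMelSketch(torch.nn.Module):
    """Hypothetical mel front-end: STFT as a fixed conv1d, then a stored filterbank buffer."""

    def __init__(self, mel_fb: torch.Tensor, n_fft: int = 512, hop: int = 160):
        super().__init__()
        n_bins = n_fft // 2 + 1
        if mel_fb.shape[1] != n_bins:
            raise ValueError("filterbank must be (n_mels, n_fft // 2 + 1)")
        window = torch.hann_window(n_fft, dtype=torch.float64)  # periodic Hann (assumed)
        n = torch.arange(n_fft, dtype=torch.float64)
        k = torch.arange(n_bins, dtype=torch.float64)[:, None]
        phase = 2 * math.pi * k * n / n_fft  # (n_bins, n_fft)
        # real and imaginary DFT kernels, pre-multiplied by the analysis window
        kernels = torch.cat([torch.cos(phase), -torch.sin(phase)]) * window
        self.register_buffer("kernels", kernels[:, None, :].float())  # (2 * n_bins, 1, n_fft)
        self.register_buffer("mel_fb", mel_fb.float())  # (n_mels, n_bins)
        self.hop = hop
        self.n_bins = n_bins

    def power_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        # (B, T) -> (B, n_bins, frames); no centre padding (librosa's stft defaults to center=True)
        out = F.conv1d(wav[:, None, :], self.kernels, stride=self.hop)
        re, im = out[:, : self.n_bins], out[:, self.n_bins :]
        return re**2 + im**2

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        mel = torch.matmul(self.mel_fb, self.power_spectrogram(wav))  # (B, n_mels, frames)
        return torch.log10(10000.0 * mel + 1.0)
```

Because the kernels are plain buffers, the whole front-end is an ordinary conv1d plus a matmul, which is what makes it straightforward to batch on the GPU and to export.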

Weight standardization is recomputed on every forward pass, which is faithful to the original and keeps the weights fine-tunable. The constants (β, α, and the scaled-activation gains) are baked into the architecture, so the learnable skip-init gains are the only per-block scalars stored in the weights. Stochastic depth is the identity at inference (the block output is simply shortcut + residual), so it is omitted.
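Recomputing weight standardization on the fly can look like the following NFNet-style sketch: each output channel's kernel is re-centred and rescaled to variance 1/fan_in on every forward pass, so gradients still reach the raw weights during fine-tuning. The per-channel `gain`, the eps clamp, and the names `WSConv2dSketch` and `nf_residual` are assumptions modelled on NFNet's scaled weight standardization, not mule_torch's actual WSConv2d. `nf_residual` shows where a learnable skip-init gain and the fixed α, β constants enter an identity-shortcut block, with stochastic depth left out as it is at inference.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2dSketch(nn.Conv2d):
    """Hypothetical Conv2d with scaled weight standardization recomputed every forward."""

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        # learnable per-output-channel gain (assumed parameterization)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight
        fan_in = w[0].numel()  # (in_channels / groups) * kh * kw
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = ((w - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
        # W_hat = gain * (W - mean) / sqrt(fan_in * var), clamped by eps for stability
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        return (w - mean) * scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


def nf_residual(h: torch.Tensor, branch, alpha: float, beta: float, skip_gain: torch.Tensor) -> torch.Tensor:
    """Identity-shortcut NF block at inference: h + alpha * skip_gain * branch(h / beta)."""
    return h + alpha * skip_gain * branch(h / beta)
```

With the skip-init gain at zero the block reduces to its shortcut, which is why that single scalar per block is worth storing while α and β can stay fixed in code.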

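Amplitude convention. The original AudioFile reader scales PCM16 samples by 1/2^16, so its waveforms span [-0.5, 0.5) rather than the conventional [-1, 1). If you feed conventionally scaled [-1, 1] audio, embeddings still track the original closely but are not bit-identical: the log10(10000·x + 1) mel compression is non-linear, so a change in input gain does not become a constant offset after compression. The verification below feeds the exact waveform the reference used, so the amplitude convention plays no part in the residual differences reported under Verified parity.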
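The two scalings differ by exactly a factor of two, and the compression does not turn that into a constant offset. A small numpy check makes this concrete, using squared samples as a stand-in for mel energies (a simplification: real mel bins sum many spectral powers, but the argument is the same). `original_scale`, `conventional_scale`, and `compress` are hypothetical helpers.

```python
import numpy as np


def original_scale(pcm16: np.ndarray) -> np.ndarray:
    """PCM16 -> float as described for the original AudioFile reader (divide by 2**16)."""
    return pcm16.astype(np.float32) / 2**16


def conventional_scale(pcm16: np.ndarray) -> np.ndarray:
    """PCM16 -> float in [-1, 1) (divide by 2**15)."""
    return pcm16.astype(np.float32) / 2**15


def compress(energy: np.ndarray) -> np.ndarray:
    """The log10(10000 * x + 1) compression applied to mel energies."""
    return np.log10(10000.0 * energy + 1.0)


pcm = np.array([1000, -20000, 32767], dtype=np.int16)
orig, conv = original_scale(pcm), conventional_scale(pcm)
# energies scale by 0.5**2 = 0.25, but the log shift is not constant because of the "+ 1"
shift = compress(conv**2) - compress(orig**2)
print(np.allclose(orig, 0.5 * conv), shift)
```

It follows that halving a waveform decoded as x / 2^15 should reproduce the reference input exactly, while feeding it unhalved gives the close-but-not-identical embeddings.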

Converting + verifying the weights

The conversion (TF → safetensors) and the parity check are standalone uv scripts with PEP 723 inline dependencies, so there is no virtualenv to set up and no TensorFlow in the package. Run them with uv run and uv builds the right ephemeral environment (Python ≤ 3.11 for TF).
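PEP 723 inline metadata is a `# /// script` comment block at the top of the file that uv reads to build the environment before running it. A hypothetical header for a TensorFlow-side script might look like this; the dependency list is illustrative only, and the actual scripts may declare different packages or bounds.

```python
# /// script
# requires-python = "<3.12"
# dependencies = [
#     "numpy",
#     "tensorflow",
# ]
# ///
"""Hypothetical uv script skeleton; `uv run` reads the block above before executing this file."""
import sys


def main() -> None:
    # under `uv run`, the interpreter here already satisfies requires-python
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}")


if __name__ == "__main__":
    main()
```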

# 0) get the 251 MB Keras weights + reference code
git clone https://github.com/PandoraMedia/music-audio-representations.git references/music-audio-representations
( cd references/music-audio-representations && git lfs pull )
REF=references/music-audio-representations

# 1a) EXTRACT: model.keras -> weights.npz  (TensorFlow)
uv run scripts/convert.py extract \
    --keras $REF/supporting_data/model/model.keras --references $REF --out artifacts/weights.npz

# 1b) ASSEMBLE: weights.npz -> model.safetensors + config.json  (torch)
uv run scripts/convert.py assemble --npz artifacts/weights.npz --out artifacts

# 2) Parity: genuine TF pipeline vs the torch port, end-to-end + ONNX
uv run scripts/verify.py reference --references $REF \
    --config $REF/supporting_data/configs/mule_embedding_timeline.yml \
    --wav tests/fixtures/fixture.wav --out artifacts/ref
uv run scripts/verify.py compare --ref artifacts/ref --weights artifacts --onnx
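The assemble step has to reorder tensors because Keras and PyTorch store kernels in different layouts: a Keras Conv2D kernel is (kh, kw, in_channels/groups, out_channels) while torch.nn.Conv2d expects (out_channels, in_channels/groups, kh, kw), and a Keras Dense kernel is (in, out) while nn.Linear stores (out, in). A minimal numpy sketch of that transposition follows; `keras_to_torch_layout` is a hypothetical helper, and scripts/convert.py may also rename keys or handle other layer types (DepthwiseConv2D kernels, for instance, are (kh, kw, in, multiplier) and need a different reshape).

```python
import numpy as np


def keras_to_torch_layout(kernel: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reorder a Keras kernel into the PyTorch layout."""
    if kernel.ndim == 4:
        # Conv2D: (kh, kw, in/groups, out) -> (out, in/groups, kh, kw)
        return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))
    if kernel.ndim == 2:
        # Dense: (in, out) -> (out, in)
        return np.ascontiguousarray(kernel.T)
    # biases, gains and other vectors or scalars keep their shape
    return np.ascontiguousarray(kernel)
```

A dict of arrays converted this way could then be written with `safetensors.numpy.save_file(tensors, "model.safetensors")`.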

Tests

uv pip install -e ".[dev]"
pytest -m "not requires_weights"                                  # runs anywhere
MULE_TORCH_WEIGHTS=artifacts MULE_TF_REF=artifacts/ref pytest      # incl. gated parity

Tests that need no weights (frontend exactness vs librosa, traced shapes, layer math, ONNX-backbone parity) run anywhere; the parity tests are skipped unless the two env vars above point at converted weights + reference dumps.
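One common way to get this gating with pytest is a conftest.py hook that attaches a skip marker to items tagged requires_weights whenever the environment variables are missing. This is a sketch of the pattern, not the repository's actual conftest; only the marker name and the two variable names come from the commands above.

```python
# conftest.py (hypothetical sketch)
import os

import pytest


def pytest_configure(config):
    # register the marker so `-m "not requires_weights"` does not warn about an unknown mark
    config.addinivalue_line("markers", "requires_weights: needs converted weights + TF reference dumps")


def pytest_collection_modifyitems(config, items):
    if os.environ.get("MULE_TORCH_WEIGHTS") and os.environ.get("MULE_TF_REF"):
        return  # both paths provided: run everything
    skip = pytest.mark.skip(reason="set MULE_TORCH_WEIGHTS and MULE_TF_REF to run parity tests")
    for item in items:
        if "requires_weights" in item.keywords:
            item.add_marker(skip)
```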

Verified parity

On an RTX 3070 against the genuine TF MULE pipeline:

  • Mel vs librosa: cosine 1.0000 (the small max-abs drift washes out after per-slice normalization).
  • Backbone on reference slices: cosine 1.0000000.
  • End-to-end clip embedding vs original MULE: cosine 0.9999999.
  • ONNX backbone vs torch: max-abs < 1e-6.
  • Parameter count: 62.35M (paper: ~62.4M).
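Metrics like these are straightforward to reproduce for your own comparisons. A minimal numpy sketch, assuming both outputs are flattened and compared in float64 (the exact reduction scripts/verify.py uses is not described here); `parity_metrics` is a hypothetical name.

```python
import numpy as np


def parity_metrics(reference: np.ndarray, candidate: np.ndarray) -> dict:
    """Cosine similarity and max absolute difference between two flattened outputs."""
    a = np.asarray(reference, dtype=np.float64).ravel()
    b = np.asarray(candidate, dtype=np.float64).ravel()
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    max_abs = float(np.max(np.abs(a - b)))
    return {"cosine": cosine, "max_abs": max_abs}
```

Reporting both is deliberate: cosine is blind to a pure gain error, which max-abs catches. For the parameter count, `sum(p.numel() for p in model.parameters())` on a loaded MuleModel is the usual way to get a figure comparable to the one above.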

Licensing

  • Code: GPL-3.0-only (mirrors the upstream mule module). See LICENSE.
  • Converted weights: CC BY-NC 4.0 (inherited from the upstream MULE weights — non-commercial). See LICENSE.weights.

Please cite McCallum et al. (2022) if you use this. See NOTICE for provenance.



Download files


Source Distribution

mule_torch-0.2.0.tar.gz (29.6 kB)


Built Distribution


mule_torch-0.2.0-py3-none-any.whl (23.4 kB)


File details

Details for the file mule_torch-0.2.0.tar.gz.

File metadata

  • Download URL: mule_torch-0.2.0.tar.gz
  • Size: 29.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.5

File hashes

Hashes for mule_torch-0.2.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 406e064e83c7304173959f44a0c255adb12c9528347d9f8fd63cd2ce3690828d |
| MD5 | 56fa2341273e8457cf25b6edc190730a |
| BLAKE2b-256 | 20214482c5bf3b106f3dd37c26f812293f25b0309d6c52b3afa63ce577efe7bb |


File details

Details for the file mule_torch-0.2.0-py3-none-any.whl.


File hashes

Hashes for mule_torch-0.2.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 906f2386e479b74525391cca3ba1f5699012c92a375feff5f6a44d04cb63a03f |
| MD5 | d689dfb77dba58c4773e665bbf875dd6 |
| BLAKE2b-256 | 8745b15480168bf32656007e6da925371050bcf911dc1a0d5640699c0df7526d |

