
Unofficial PyTorch port of MULE (SF-NFNet-F0), the SiriusXM/Pandora music-audio embedding model.


mule-torch


An unofficial PyTorch port of MULE (Musicset Unsupervised Large Embedding), the SF-NFNet-F0 music-audio representation model from SiriusXM/Pandora:

Supervised and Unsupervised Learning of Audio Representations for Music Understanding, M. C. McCallum, F. Korzeniowski, S. Oramas, F. Gouyon, A. F. Ehmann. ISMIR 2022. https://arxiv.org/abs/2210.03799

This is not a retraining. The port re-implements the SF-NFNet-F0 architecture in pure PyTorch and transfers the pretrained weights from the original Keras model (model.keras); its outputs have been checked against the original pipeline and match to within floating-point precision (see Verified parity). The result is a clean nn.Module that supports batching, runs natively on GPU, and exports to ONNX, none of which the original TensorFlow/SCOOCH/Analysis pipeline supports.

Disclaimer

This is an independent, community port from TensorFlow to PyTorch. It is not affiliated with, endorsed by, or maintained by SiriusXM, Pandora, or the original authors. The original model, weights, and configurations come from PandoraMedia/music-audio-representations. All credit for the model goes to the original authors — please cite their paper.

Install

pip install mule-torch
# or, from source:
pip install git+https://github.com/matteospanio/mule-torch

The installed package is a pure torch library (torch, numpy, safetensors, huggingface_hub). It does not pull in TensorFlow — the conversion/verification tooling lives in standalone uv scripts (see below).

Usage

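import torch
from mule_torch import MuleModel

model = MuleModel.from_pretrained().eval()     # downloads weights from the Hugging Face Hub
waveform = torch.rand(2, 16_000 * 30) * 2 - 1  # stand-in audio: (B, T) float, 16 kHz mono, in [-1, 1]
with torch.no_grad():
    emb = model(waveform)                      # -> (B, 1728), here (2, 1728)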

Input is a 16 kHz mono waveform with samples in [-1, 1]. The model computes a 96-band log-mel spectrogram, cuts it into 96 × 300 (mel bands × frames) windows with a hop of about 2 s, runs the SF-NFNet-F0 backbone on each window, and mean-pools the resulting 1728-d slice embeddings into one vector per clip. This reproduces the mule_embedding_timeline.yml configuration followed by an average over the embedding timeline.
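The slicing and pooling steps can be sketched in a few lines of torch. This is not the port's slice_mel: the function names are hypothetical, hop_frames=200 assumes a 10 ms mel hop (so 200 frames ≈ 2 s), trailing frames that do not fill a whole window are simply dropped, and a tiny stand-in module replaces the real backbone.

```python
import torch
import torch.nn as nn

N_MELS, SLICE_FRAMES = 96, 300


def slice_mel_sketch(mel: torch.Tensor, hop_frames: int = 200) -> torch.Tensor:
    """Cut a (B, 96, F) log-mel batch into (B, S, 1, 96, 300) windows (hypothetical)."""
    if mel.shape[-1] < SLICE_FRAMES:
        raise ValueError("clip shorter than one 300-frame slice")
    windows = mel.unfold(-1, SLICE_FRAMES, hop_frames)  # (B, 96, S, 300)
    return windows.permute(0, 2, 1, 3).unsqueeze(2)     # (B, S, 1, 96, 300)


def clip_embedding(mel: torch.Tensor, backbone, hop_frames: int = 200) -> torch.Tensor:
    """Run the backbone on every slice, then average over the timeline."""
    slices = slice_mel_sketch(mel, hop_frames)
    b, s = slices.shape[:2]
    per_slice = backbone(slices.reshape(b * s, 1, N_MELS, SLICE_FRAMES))  # (B*S, D)
    return per_slice.reshape(b, s, -1).mean(dim=1)                        # (B, D)


# Hypothetical stand-in for SF-NFNet-F0: (N, 1, 96, 300) -> (N, 1728).
stand_in_backbone = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(1, 1728))
mel = torch.randn(2, N_MELS, 1000)  # ~10 s of frames, assuming a 10 ms hop
print(clip_embedding(mel, stand_in_backbone).shape)  # torch.Size([2, 1728])
```

Batching all slices of all clips into a single backbone call is what lets the port run the whole pipeline as one GPU pass.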

Pretrained weights

The converted weights are hosted on the Hugging Face Hub at matteospanio/mule (CC BY-NC 4.0) and are downloaded automatically by from_pretrained():

MuleModel.from_pretrained()                          # default: hf_repo="matteospanio/mule"
MuleModel.from_pretrained(hf_repo="matteospanio/mule", revision="main")
MuleModel.from_pretrained(model_dir="artifacts")     # or load a local copy (skip the download)

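The Hub repo ships model.safetensors and config.json, which from_pretrained() uses, plus a self-contained backbone.onnx for ONNX Runtime. The ONNX graph takes log-mel slices rather than raw audio: it maps input of shape (N, 1, 96, 300) to embeddings of shape (N, 1728), uses opset 17, and has a dynamic batch dimension N. Set $MULE_TORCH_DIR to point from_pretrained() at a local directory without passing model_dir.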
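The loading options above resolve in a fixed order. A hypothetical helper showing the precedence implied by this section (an explicit model_dir first, then $MULE_TORCH_DIR, then the Hub repo) might look like this; resolve_weights_source is illustrative only and is not part of mule_torch.

```python
import os
from typing import Optional, Tuple

DEFAULT_HF_REPO = "matteospanio/mule"


def resolve_weights_source(model_dir: Optional[str] = None,
                           hf_repo: str = DEFAULT_HF_REPO) -> Tuple[str, str]:
    """Hypothetical: return ("local", path) or ("hub", repo_id).

    Precedence assumed from the README: explicit model_dir, then
    $MULE_TORCH_DIR, then a download from the Hugging Face Hub.
    """
    if model_dir is not None:
        return ("local", model_dir)
    env_dir = os.environ.get("MULE_TORCH_DIR")
    if env_dir:  # unset or empty falls through to the Hub
        return ("local", env_dir)
    return ("hub", hf_repo)


print(resolve_weights_source("artifacts"))  # ('local', 'artifacts')
```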

How the port works

| Stage | Original (TF) | This port (torch) |
|---|---|---|
| Mel front-end | librosa.feature.melspectrogram | MuleMelSpectrogram (fixed windowed-DFT conv1d + the librosa filterbank stored as a buffer, since torchaudio can't do norm=2.0) |
| Slicing | SliceExtractor (numpy) | slice_mel (torch) |
| Backbone | SfNfNetF0 Keras model | SfNfNetF0 nn.Module (WSConv2d, scaled GELU, squeeze-excite, NFNet blocks, fast→slow fusion) |
| Weights | model.keras (251 MB) | model.safetensors (converted) |
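The mel front-end row can be illustrated with a sketch: the STFT becomes a conv1d whose fixed kernels are Hann-windowed cosine and sine rows, and the mel filterbank is a registered buffer multiplied onto the power spectrum. This is not the port's MuleMelSpectrogram. The class name ConvMelSketch, the n_fft=400 / hop=160 defaults, reflect padding, and the power-2 spectrum are assumptions; any (n_mels, n_fft // 2 + 1) matrix, for instance one precomputed with librosa.filters.mel, can be passed as the filterbank.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMelSketch(nn.Module):
    """Hypothetical mel front-end: windowed DFT as a fixed conv1d + stored filterbank."""

    def __init__(self, mel_fb: torch.Tensor, n_fft: int = 400, hop: int = 160):
        super().__init__()
        self.n_fft, self.hop, self.n_bins = n_fft, hop, n_fft // 2 + 1
        if mel_fb.shape[-1] != self.n_bins:
            raise ValueError("filterbank width must be n_fft // 2 + 1")
        n = torch.arange(n_fft, dtype=torch.float64)
        k = torch.arange(self.n_bins, dtype=torch.float64).unsqueeze(1)
        window = torch.hann_window(n_fft, periodic=True, dtype=torch.float64)
        angle = 2 * math.pi * k * n / n_fft                     # (n_bins, n_fft)
        kernel = torch.cat([torch.cos(angle) * window,           # real part
                            -torch.sin(angle) * window])         # imaginary part
        # Buffers follow .to(device) and are saved in state_dict, but are not trained.
        self.register_buffer("dft_kernel", kernel.unsqueeze(1).float())  # (2*n_bins, 1, n_fft)
        self.register_buffer("mel_fb", mel_fb.float())

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Centered frames via reflect padding (an assumption), one DFT per hop.
        pad = self.n_fft // 2
        x = F.pad(wav.unsqueeze(1), (pad, pad), mode="reflect")
        spec = F.conv1d(x, self.dft_kernel, stride=self.hop)    # (B, 2*n_bins, frames)
        re, im = spec[:, : self.n_bins], spec[:, self.n_bins :]
        power = re ** 2 + im ** 2
        mel = torch.matmul(self.mel_fb, power)                  # (B, n_mels, frames)
        return torch.log10(10000.0 * mel + 1.0)                 # log10(10000·x + 1) compression
```

With an identity filterbank the output equals the log-compressed power of torch.stft (center=True, reflect padding, Hann window), which makes a convenient self-check of the conv1d formulation.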

Weight standardization is recomputed on the fly from the raw convolution weights, which stays faithful to the original and keeps the model fine-tunable. Constants (β, α, and the scaled-activation gains) are baked into the architecture, so the learnable skip-init gain is the only scalar saved per block. Stochastic depth is a no-op at inference (the block output is simply shortcut + residual), so it is dropped from the port.
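Both ideas, scaled weight standardization computed from the raw weights at every call and a residual branch gated by a learnable skip-init scalar, can be sketched as follows. The class names are hypothetical; eps=1e-4 and the per-output-channel gain follow the NF-Net formulation, alpha and beta are placeholder constants rather than MULE's per-block values, and plain nn.GELU stands in for the scaled GELU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2dSketch(nn.Conv2d):
    """Conv2d whose weight is standardized inside every forward pass (NF-Net style)."""

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight                                   # (out, in/groups, kh, kw)
        fan_in = w[0].numel()
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = ((w - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        return (w - mean) * scale                         # zero mean, variance 1/fan_in

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


class SkipInitResidualSketch(nn.Module):
    """out = x + alpha * skip_gain * branch(x / beta); skip_gain is the only learnable scalar.

    alpha and beta are fixed constants; stochastic depth is omitted because it
    reduces to this same expression at inference.
    """

    def __init__(self, channels: int, alpha: float = 0.2, beta: float = 1.0):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.branch = nn.Sequential(nn.GELU(), WSConv2dSketch(channels, channels, 3, padding=1))
        self.skip_gain = nn.Parameter(torch.zeros(()))    # SkipInit: block starts as identity

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.alpha * self.skip_gain * self.branch(x / self.beta)
```

Because standardization happens inside forward, gradients flow to the raw weight tensor, which is what keeps the converted model fine-tunable without storing a second, pre-standardized copy.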

Amplitude convention. The original AudioFile reader scales PCM16 by 1/2^16, so its waveforms have half the amplitude of the usual [-1, 1] convention (PCM16 / 2^15). If you feed conventional [-1, 1] audio, embeddings still track the original closely but are not bit-identical: the log10(10000·x + 1) mel compression is non-linear, so the level difference does not cancel out. The verification below feeds the exact waveform the reference used, so parity is exact.
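A small numpy sketch of the arithmetic, assuming the conventional scaling is PCM16 / 2^15: halving conventional audio reproduces the reference scaling exactly, and the log compression turns the resulting 4× power difference into a level-dependent offset rather than a constant one. The function names are hypothetical.

```python
import numpy as np


def pcm16_to_conventional(pcm: np.ndarray) -> np.ndarray:
    """Common [-1, 1) convention (assumed): divide PCM16 by 2**15."""
    return pcm.astype(np.float32) / 2**15


def pcm16_to_mule_reference(pcm: np.ndarray) -> np.ndarray:
    """Scaling described for the original AudioFile reader: divide by 2**16."""
    return pcm.astype(np.float32) / 2**16


def mule_compress(mel_power: np.ndarray) -> np.ndarray:
    return np.log10(10000.0 * mel_power + 1.0)


pcm = np.array([-32768, -1000, 0, 1000, 32767], dtype=np.int16)
# Scaling by 0.5 is exact in floating point, so this recovers the reference waveform.
print(np.array_equal(pcm16_to_conventional(pcm) * 0.5, pcm16_to_mule_reference(pcm)))  # True

# Twice the amplitude means 4x the mel power; after compression the gap depends on level.
p = np.array([1e-6, 1e-4, 1e-2, 1.0])
print(mule_compress(4 * p) - mule_compress(p))
```

The printed gaps grow from near 0 for very quiet bins toward log10(4) ≈ 0.602 for loud ones, which is why the two conventions cannot be reconciled by a single offset after the front-end.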

Converting + verifying the weights

The conversion (TF → safetensors) and the parity check are standalone uv scripts with PEP 723 inline dependencies — no virtualenv setup, no TensorFlow in the package. Just uv run them; uv builds the right ephemeral environment (Python ≤ 3.11 for TF).

# 0) get the 251 MB Keras weights + reference code
git clone https://github.com/PandoraMedia/music-audio-representations.git references/music-audio-representations
( cd references/music-audio-representations && git lfs pull )
REF=references/music-audio-representations

# 1a) EXTRACT: model.keras -> weights.npz  (TensorFlow)
uv run scripts/convert.py extract \
    --keras $REF/supporting_data/model/model.keras --references $REF --out artifacts/weights.npz

# 1b) ASSEMBLE: weights.npz -> model.safetensors + config.json  (torch)
uv run scripts/convert.py assemble --npz artifacts/weights.npz --out artifacts

# 2) Parity: genuine TF pipeline vs the torch port, end-to-end + ONNX
uv run scripts/verify.py reference --references $REF \
    --config $REF/supporting_data/configs/mule_embedding_timeline.yml \
    --wav tests/fixtures/fixture.wav --out artifacts/ref
uv run scripts/verify.py compare --ref artifacts/ref --weights artifacts --onnx
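Step 1b is largely a matter of renaming arrays and reordering axes, because Keras and PyTorch store kernels in different layouts. A minimal sketch of the layout part, assuming each array from weights.npz carries a hypothetical kind tag ("conv2d", "dense", or anything else); it is not the code in scripts/convert.py. Grouped convolutions keep Keras's (kh, kw, in/groups, out) shape and map the same way as ordinary ones.

```python
import numpy as np


def keras_to_torch(kind: str, arr: np.ndarray) -> np.ndarray:
    """Reorder one Keras weight array into PyTorch's layout.

    `kind` is a hypothetical tag; a real converter would infer it from layer names.
    """
    if kind == "conv2d":  # Keras (kh, kw, in/groups, out) -> torch (out, in/groups, kh, kw)
        return np.ascontiguousarray(arr.transpose(3, 2, 0, 1))
    if kind == "dense":   # Keras (in, out) -> torch nn.Linear (out, in)
        return np.ascontiguousarray(arr.T)
    return np.ascontiguousarray(arr)  # biases, gains, scalars keep their shape


# Sanity check: one output pixel computed in both layouts must agree.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 4, 5)).astype(np.float32)  # Keras layout
patch = rng.standard_normal((3, 3, 4)).astype(np.float32)      # (kh, kw, in)
out_keras = np.einsum("ijc,ijco->o", patch, kernel)
out_torch = np.einsum("cij,ocij->o", patch.transpose(2, 0, 1), keras_to_torch("conv2d", kernel))
print(np.allclose(out_keras, out_torch, atol=1e-5))  # True
```

The resulting name-to-array dict can then be written with, for example, safetensors.numpy.save_file(tensors, "model.safetensors").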

Tests

uv pip install -e ".[dev]"
pytest -m "not requires_weights"                                  # runs anywhere
MULE_TORCH_WEIGHTS=artifacts MULE_TF_REF=artifacts/ref pytest      # incl. gated parity

Tests that need no weights (frontend exactness vs librosa, traced shapes, layer math, ONNX-backbone parity) run anywhere; the parity tests are skipped unless the two env vars above point at converted weights + reference dumps.

Verified parity

On an RTX 3070 against the genuine TF MULE pipeline:

  • Mel vs librosa: cosine 1.0000 (the small max-abs drift washes out after per-slice normalization).
  • Backbone on reference slices: cosine 1.0000000.
  • End-to-end clip embedding vs original MULE: cosine 0.9999999.
  • ONNX backbone vs torch: max-abs < 1e-6.
  • Parameter count: 62.35M (paper: ~62.4M).
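The cosine and max-abs figures above can be computed for any pair of embeddings with the standard definitions below; the exact reductions used in scripts/verify.py are not shown here, so treat this as an assumption. Cosine similarity ignores overall scale, which is why a max-abs check is useful alongside it.

```python
import numpy as np


def cosine_similarity(a, b) -> float:
    """Cosine of the angle between two embeddings, computed in float64."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def max_abs_diff(a, b) -> float:
    """Largest element-wise absolute difference."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return float(np.max(np.abs(a - b)))


rng = np.random.default_rng(0)
reference = rng.standard_normal(1728)
port = reference + 1e-7 * rng.standard_normal(1728)  # simulated floating-point drift
print(f"cosine={cosine_similarity(reference, port):.7f}  max_abs={max_abs_diff(reference, port):.1e}")
```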

Licensing

  • Code: GPL-3.0-only (mirrors the upstream mule module). See LICENSE.
  • Converted weights: CC BY-NC 4.0 (inherited from the upstream MULE weights — non-commercial). See LICENSE.weights.

Please cite McCallum et al. (2022) if you use this. See NOTICE for provenance.

