logWMSE is an audio quality metric and loss function with support for digital silence targets. It is useful for training and evaluating audio source separation systems.

This repository contains the PyTorch implementation of an audio quality metric, logWMSE, originally proposed by Iver Jordal. In addition to the original metric, this implementation can also be used as a loss function for training audio separation and denoising models.

logWMSE is a custom metric and loss function for audio signals that calculates the logarithm (log) of a frequency-weighted (W) Mean Squared Error (MSE). It is designed to address several shortcomings of common audio metrics, most importantly the lack of support for digital silence targets.

Installation

pip install torch-log-wmse

Usage Example

import torch
from torch_log_wmse import LogWMSE

# Tensor shapes
audio_length = 1.0
sample_rate = 44100
audio_stems = 4 # 4 audio stems (e.g. vocals, drums, bass, other)
audio_channels = 2 # stereo
batch = 4 # batch size

# Instantiate logWMSE
# Set `return_as_loss=False` to return as a positive metric (Default: True)
# Set `bypass_filter=True` to bypass frequency weighting (Default: False)
log_wmse = LogWMSE(
    audio_length=audio_length,
    sample_rate=sample_rate,
    return_as_loss=True, # optional
    bypass_filter=False, # optional
)

# Generate random inputs (scale between -1 and 1)
audio_lengths_samples = int(audio_length * sample_rate)
unprocessed_audio = 2 * torch.rand(batch, audio_channels, audio_lengths_samples) - 1
processed_audio = 2 * torch.rand(batch, audio_channels, audio_stems, audio_lengths_samples) - 1
target_audio = torch.zeros(batch, audio_channels, audio_stems, audio_lengths_samples)

# Store the result under a new name so the LogWMSE instance is not overwritten
loss = log_wmse(unprocessed_audio, processed_audio, target_audio)
print(loss)  # Expected output: approx. -18.42

logWMSE accepts three torch tensors of the following shapes:

  • unprocessed_audio: [batch, audio_channels, samples]
  • processed_audio: [batch, audio_channels, audio_stems, samples]
  • target_audio: [batch, audio_channels, audio_stems, samples]

Each dimension being:

  • batch: Number of audio files in a batch (i.e. batch size).
  • audio_channels: Number of channels (i.e. 1 for mono and 2 for stereo).
  • audio_stems: Number of separate audio sources. For source separation, this could be multiple different instruments, vocals, etc. For denoising audio, this will be 1.
  • samples: Number of audio samples (e.g. 1 second of audio @ 44.1kHz is 44100 samples).
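
The shape rules above can be checked before calling the metric. The following is a minimal sketch, assuming shapes are passed as plain tuples (for example `tuple(tensor.shape)`); `check_log_wmse_shapes` is a hypothetical helper for illustration and is not part of torch-log-wmse.

```python
def check_log_wmse_shapes(unprocessed, processed, target):
    """Validate shape tuples against the layout documented above.

    Hypothetical helper, not part of torch-log-wmse.
    Returns the number of audio stems if the shapes are consistent.
    """
    if len(unprocessed) != 3:
        raise ValueError("unprocessed_audio must be [batch, audio_channels, samples]")
    if len(processed) != 4:
        raise ValueError(
            "processed_audio must be [batch, audio_channels, audio_stems, samples]"
        )
    if tuple(processed) != tuple(target):
        raise ValueError("processed_audio and target_audio must have identical shapes")

    batch, channels, samples = unprocessed
    p_batch, p_channels, stems, p_samples = processed
    if (batch, channels, samples) != (p_batch, p_channels, p_samples):
        raise ValueError("batch, audio_channels and samples must match across inputs")
    return stems


# Denoising case: a single stem, stereo, 1 second of audio at 44.1 kHz
stems = check_log_wmse_shapes((4, 2, 44100), (4, 2, 1, 44100), (4, 2, 1, 44100))
print(stems)  # 1
```

Note that for denoising the stem dimension is still present with size 1, so the processed and target tensors stay four-dimensional.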

Motivation

The goal of this metric is to account for several factors not present in current audio evaluation metrics, such as dealing with digital silence. Mean Squared Error (MSE) is well-defined for digital silence targets, but has its own set of drawbacks. Attempting to mitigate these issues, the following are some attributes of logWMSE:

  • Supports digital silence targets, which other audio metrics such as (SI-)SDR, SIR, SAR, ISR, VISQOL_audio, STOI, CDPAM, and VISQOL do not support.
  • Overcomes the small value range of MSE (typically between 1e-8 and 1e-3), making number formatting and sight-reading easier. It is scaled similarly to SI-SDR for consistency with current benchmark metrics (e.g. 3 is poor, 30 is very good).
  • Scale-invariant when all inputs are gained by the same amount, and aligned with the frequency sensitivity of human hearing.
  • Invariant to the tiny errors of MSE that are inaudible to humans.
  • Logarithmic, reflecting the logarithmic sensitivity of human hearing.
  • Tailored specifically for audio signals.
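
To illustrate the first two points, here is a simplified sketch in plain Python. It is not the library's implementation: it omits frequency weighting and input normalization, and the floor value `eps=1e-8` and both function names are assumptions made for this example. It shows that a ratio-based metric such as SDR is undefined for an all-zero target, while the logarithm of a floored MSE stays finite and spreads tiny MSE values over a readable range.

```python
import math


def sdr_db(target, estimate):
    """Plain signal-to-distortion ratio in dB (illustrative only)."""
    signal_energy = sum(t * t for t in target)
    error_energy = sum((t - e) ** 2 for t, e in zip(target, estimate))
    if signal_energy == 0.0 or error_energy == 0.0:
        # The ratio is zero or infinite, so its logarithm is undefined.
        return float("nan")
    return 10.0 * math.log10(signal_energy / error_energy)


def log_mse(target, estimate, eps=1e-8):
    """Natural log of the MSE plus a small floor (eps is an assumed value)."""
    mse = sum((t - e) ** 2 for t, e in zip(target, estimate)) / len(target)
    return math.log(mse + eps)


silence = [0.0] * 8               # digital silence target
quiet_leak = [1e-3, -1e-3] * 4    # estimate with a faint residual
loud_leak = [1e-1, -1e-1] * 4     # estimate with a loud residual

print(sdr_db(silence, quiet_leak))   # nan: SDR cannot score a silent target
print(log_mse(silence, quiet_leak))  # finite, lower (better) than the loud case
print(log_mse(silence, loud_leak))   # finite, higher (worse)
```

In this sketch a lower value means a smaller error, which matches the loss orientation; a metric orientation would flip the sign.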

Frequency Weighting

To weight errors in a way that better matches human hearing, the following frequency weighting is applied. This makes the metric pay less attention to errors at frequencies that humans are less sensitive to (e.g. 50 Hz) and give more weight to those that we are acutely tuned to (e.g. 3 kHz).

[Figure: frequency weighting curve]

This metric has been constructed with high-fidelity audio in mind (sample rates ≥ 44.1 kHz). In theory it can also work for lower sample rates, such as 16 kHz, because the metric internally resamples its inputs to 44.1 kHz for consistency across input sample rates.

Inputs

Unlike many audio quality metrics, logWMSE accepts 3 audio inputs rather than 2:

  • Unprocessed audio (e.g. raw, noisy audio)
  • Processed audio (e.g. denoised or separated audio)
  • Target audio (e.g. ground truth, clean audio)

Typically audio loss functions only use the processed audio and target audio to compare against one another. However, logWMSE requires the initial, unprocessed audio because it needs to be able to measure how well the processed audio was attenuated from the unprocessed version. This adds a factor that accounts for when the input contains silence (digital zero).

This also adds a factor of scale invariance in the sense that the processed audio needs to be scaled appropriately relative to both the unprocessed audio and ground truth. Conceptually, this means that if all 3 inputs are gained by the same arbitrary amount, the metric score will stay the same.
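
The scale-invariance idea can be sketched as follows. This is a simplified illustration, not the library's code: it assumes the processed and target signals are both scaled by the reciprocal of the unprocessed signal's RMS before the error is computed, it skips frequency weighting, and the names `rms`, `normalized_log_mse`, and the `eps` floor are assumptions made for this example.

```python
import math


def rms(signal):
    """Root mean square of a sequence of samples."""
    return math.sqrt(sum(x * x for x in signal) / len(signal))


def normalized_log_mse(unprocessed, processed, target, eps=1e-8):
    """Log-MSE after scaling by 1 / RMS(unprocessed) (assumed normalization)."""
    # Guard against division by zero when the unprocessed input is silent.
    scale = 1.0 / max(rms(unprocessed), eps)
    mse = sum(((p - t) * scale) ** 2 for p, t in zip(processed, target)) / len(target)
    return math.log(mse + eps)


unprocessed = [0.5, -0.4, 0.3, -0.2]
processed = [0.25, -0.2, 0.1, -0.1]
target = [0.3, -0.2, 0.15, -0.1]

gain = 0.1  # the same arbitrary gain applied to all three inputs
base = normalized_log_mse(unprocessed, processed, target)
gained = normalized_log_mse(
    [gain * x for x in unprocessed],
    [gain * x for x in processed],
    [gain * x for x in target],
)
print(math.isclose(base, gained, rel_tol=1e-9))  # True
```

Because the error is measured relative to the level of the unprocessed input, applying one gain to all three signals cancels out, up to floating-point rounding.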

Limitations

  • The metric isn't invariant to arbitrary scaling, polarity inversion, or offsets in the estimated audio relative to the target.
  • Although it incorporates frequency filtering inspired by human auditory sensitivity, it doesn't fully model human auditory perception. For instance, it doesn't consider auditory masking.
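
The first limitation follows from the metric being built on a squared error between estimate and target. As a minimal sketch using plain MSE (no frequency weighting or logarithm, and with made-up sample values), any change applied to the estimate alone shows up as error:

```python
def mse(estimate, target):
    """Mean squared error between two equal-length sample sequences."""
    return sum((e - t) ** 2 for e, t in zip(estimate, target)) / len(target)


target = [0.3, -0.2, 0.15, -0.1]

# Each variant changes only the estimate, never the target.
variants = {
    "identical": list(target),
    "half gain": [0.5 * t for t in target],
    "inverted polarity": [-t for t in target],
    "dc offset": [t + 0.05 for t in target],
}

errors = {name: mse(estimate, target) for name, estimate in variants.items()}
for name, err in errors.items():
    print(f"{name}: {err:.6f}")
```

Only the identical estimate produces zero error; a gain change, polarity flip, or constant offset in the estimate is penalized even when it might be hard to hear.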

Contributing

Contributions are welcome! Please open an issue or submit a pull request if you have any improvements or new features to suggest.

License

This project is licensed under the Apache License 2.0. See LICENSE for details.

Acknowledgments

Thanks to Whitebalance for backing this project.
