auraloss

A collection of audio-focused loss functions in PyTorch.

Setup

pip install auraloss

If you want to use MelSTFTLoss() or FIRFilter(), you will need to install the optional extras (librosa and scipy).

pip install auraloss[all]

Usage

import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

# tensors are shaped (batch, channels, samples)
pred = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(pred, target)

NEW: Perceptual weighting with mel scaled spectrograms.

bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    scale="mel",
    n_bins=128,
    sample_rate=sample_rate,
    perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)
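To see what the three resolutions above trade off, the following sketch computes the spectrogram shape each one would produce before any mel projection. It assumes a one-sided STFT with centered framing (the default behaviour of torch.stft); the helper name stft_shape is hypothetical and not part of auraloss.

```python
from typing import Tuple


def stft_shape(seq_len: int, fft_size: int, hop_size: int) -> Tuple[int, int]:
    """Return (frequency bins, frames) for a one-sided, centered STFT.

    Assumption: the signal is padded by fft_size // 2 on each side,
    which is what torch.stft does by default (center=True).
    """
    n_bins = fft_size // 2 + 1          # one-sided spectrum of a real signal
    n_frames = 1 + seq_len // hop_size  # centered framing
    return n_bins, n_frames


seq_len = 131072
for fft_size, hop_size in zip([1024, 2048, 8192], [256, 512, 2048]):
    n_bins, n_frames = stft_shape(seq_len, fft_size, hop_size)
    print(f"fft_size={fft_size:5d} hop_size={hop_size:5d} -> {n_bins} bins x {n_frames} frames")
```

Larger FFT sizes give finer frequency resolution but fewer frames, which is why several resolutions are combined.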

Citation

If you use this code in your work please consider citing us.

@inproceedings{steinmetz2020auraloss,
    title={auraloss: {A}udio focused loss functions in {PyTorch}},
    author={Steinmetz, Christian J. and Reiss, Joshua D.},
    booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
    year={2020}
}

Loss functions

We categorize the loss functions as either time-domain or frequency-domain approaches. Additionally, we include perceptual transforms.

| Loss function | Interface | Reference |
| --- | --- | --- |
| Time domain | | |
| Error-to-signal ratio (ESR) | auraloss.time.ESRLoss() | Wright & Välimäki, 2019 |
| DC error (DC) | auraloss.time.DCLoss() | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | auraloss.time.LogCoshLoss() | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | auraloss.time.SNRLoss() | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | auraloss.time.SISDRLoss() | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | auraloss.time.SDSDRLoss() | Le Roux et al., 2018 |
| Frequency domain | | |
| Aggregate STFT | auraloss.freq.STFTLoss() | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | auraloss.freq.MelSTFTLoss(sample_rate) | |
| Multi-resolution STFT | auraloss.freq.MultiResolutionSTFTLoss() | Yamamoto et al., 2019* |
| Random-resolution STFT | auraloss.freq.RandomResolutionSTFTLoss() | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | auraloss.freq.SumAndDifferenceSTFTLoss() | Steinmetz et al., 2020 |
| Perceptual transforms | | |
| Sum and difference signal transform | auraloss.perceptual.SumAndDifference() | |
| FIR pre-emphasis filters | auraloss.perceptual.FIRFilter() | Wright & Välimäki, 2019 |

* Wang et al., 2019 also propose a multi-resolution spectral loss (that Engel et al., 2020 follow), but they do not include both the log magnitude (L1 distance) and spectral convergence terms, introduced in Arik et al., 2018, and then extended for the multi-resolution case in Yamamoto et al., 2019.
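To make two of the time-domain entries in the table concrete, here is a minimal sketch of the error-to-signal ratio and the scale-invariant SDR, written from their usual textbook definitions. The function names esr and si_sdr and the eps stabilizer are assumptions made for illustration; this is not the auraloss implementation, which may differ in reduction, mean removal, and sign.

```python
import torch


def esr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Error-to-signal ratio: energy of the error divided by energy of the target."""
    error_energy = torch.sum((target - pred) ** 2)
    target_energy = torch.sum(target ** 2)
    return error_energy / (target_energy + eps)


def si_sdr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR in dB, computed over the last (time) dimension."""
    # Optimal scaling of the target that best explains the prediction.
    alpha = torch.sum(pred * target, dim=-1, keepdim=True) / (
        torch.sum(target ** 2, dim=-1, keepdim=True) + eps
    )
    scaled_target = alpha * target
    residual = pred - scaled_target
    ratio = torch.sum(scaled_target ** 2, dim=-1) / (torch.sum(residual ** 2, dim=-1) + eps)
    return 10 * torch.log10(ratio + eps)


torch.manual_seed(0)
target = torch.rand(8, 1, 4096) - 0.5
pred = target + 0.1 * torch.randn(8, 1, 4096)

print("ESR:", esr(pred, target).item())
print("SI-SDR (dB), mean over batch:", si_sdr(pred, target).mean().item())
```

ESR is lower for better predictions, while SI-SDR is higher for better predictions, so a training loss would typically minimize the negative SI-SDR.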

Examples

Currently we include an example that uses a set of the loss functions to train a TCN for modeling an analog dynamic range compressor. For details, please refer to examples/compressor. We provide pre-trained models, evaluation scripts to compute the metrics in the paper, as well as scripts to retrain models.

The STFTLoss class also supports more advanced configurations. For example, you can compute both linear and log scaled STFT errors as in Engel et al., 2020. In this case we do not include the spectral convergence term.

stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0, 
    w_lin_mag=1.0, 
    w_sc=0.0,
)
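As a rough illustration of what the three weights control, the sketch below combines a spectral convergence term, a log-magnitude L1 term, and a linear-magnitude L1 term using plain torch.stft. The helper names stft_magnitude and weighted_stft_loss, the Hann window, and the eps value are assumptions made for this example; the actual STFTLoss may differ in details such as windowing, reduction, and numerical safeguards.

```python
import torch
import torch.nn.functional as F


def stft_magnitude(x: torch.Tensor, fft_size: int = 1024, hop_size: int = 256) -> torch.Tensor:
    """Magnitude spectrogram of a (batch, channels, samples) or (batch, samples) tensor."""
    x = x.reshape(-1, x.shape[-1])  # fold channels into the batch dimension
    window = torch.hann_window(fft_size, device=x.device)
    spec = torch.stft(x, fft_size, hop_length=hop_size, window=window, return_complex=True)
    return spec.abs()


def weighted_stft_loss(pred, target, w_sc=0.0, w_log_mag=1.0, w_lin_mag=1.0, eps=1e-8):
    """Weighted sum of spectral convergence, log-magnitude and linear-magnitude terms."""
    pred_mag = stft_magnitude(pred)
    target_mag = stft_magnitude(target)
    # Spectral convergence: Frobenius norm of the error relative to the target.
    sc = torch.sqrt(torch.sum((target_mag - pred_mag) ** 2)) / (
        torch.sqrt(torch.sum(target_mag ** 2)) + eps
    )
    # L1 distance between log magnitudes and between linear magnitudes.
    log_mag = F.l1_loss(torch.log(pred_mag + eps), torch.log(target_mag + eps))
    lin_mag = F.l1_loss(pred_mag, target_mag)
    return w_sc * sc + w_log_mag * log_mag + w_lin_mag * lin_mag


torch.manual_seed(0)
target = torch.rand(4, 1, 8192)
pred = torch.rand(4, 1, 8192)
print(weighted_stft_loss(pred, target).item())
```

With w_sc=0.0 the spectral convergence term drops out, leaving only the linear and log magnitude errors, which mirrors the configuration described above.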

There is also a Mel-scaled STFT loss, which has some special requirements. This loss requires you to set the sample rate and to specify the correct device.

sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")

You can also easily build a multi-resolution Mel-scaled STFT loss with 64 bins. Make sure you pass the device on which the tensors you are comparing will reside.

loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel", 
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda"
)
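Hard-coding "cuda" fails on machines without a GPU. A common PyTorch pattern, sketched here with only torch, is to select the device once and create the tensors on it; the same device value would then be passed as the device argument in the constructor calls shown above. The tensor shapes are illustrative.

```python
import torch

# Fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the tensors to compare directly on the selected device.
target = torch.rand(8, 1, 44100, device=device)
pred = torch.rand(8, 1, 44100, device=device)

print("using device:", device)
```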

If you are computing a loss on stereo audio, you may want to consider the sum and difference (mid/side) loss. Below is an example of using this loss function with perceptual weighting and mel scaling for further perceptual relevance.

target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)

loss = loss_fn(pred, target)
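The sum and difference (mid/side) idea itself is simple and can be sketched in a few lines: the two stereo channels are added to form the sum signal and subtracted to form the difference signal, and an STFT loss is then computed on each. The function name sum_and_difference is hypothetical, and this sketch applies no scaling factor, which is an assumption; conventions such as 1/2 or 1/sqrt(2) also exist.

```python
import torch


def sum_and_difference(x: torch.Tensor):
    """Split a stereo tensor (batch, 2, samples) into sum and difference signals.

    Each returned tensor has shape (batch, 1, samples).
    """
    if x.ndim != 3 or x.shape[1] != 2:
        raise ValueError("expected a tensor of shape (batch, 2, samples)")
    left, right = x[:, 0:1, :], x[:, 1:2, :]
    return left + right, left - right


stereo = torch.rand(8, 2, 44100)
sum_sig, diff_sig = sum_and_difference(stereo)
print(sum_sig.shape, diff_sig.shape)
```

Comparing the sum and difference signals rather than left and right penalizes errors in the stereo image more directly, since content panned to the center appears only in the sum signal.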

Development

Run tests locally with pytest.

python -m pytest
