Collection of audio-focused loss functions in PyTorch.
Setup
```
pip install auraloss
```
If you want to use `MelSTFTLoss()` or `FIRFilter()`, you will need to specify the extra install (librosa and scipy).
```
pip install auraloss[all]
```
Usage
```python
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

input = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(input, target)
```
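To make the multi-resolution idea concrete, here is a minimal from-scratch sketch in plain PyTorch. It is not auraloss's implementation: it assumes a Hann window, combines the spectral convergence and log-magnitude terms (Arik et al., 2018) with equal weight, and simply averages over resolutions. The helper names `stft_mag`, `single_res_loss`, and `multi_res_loss` are hypothetical, and the resolution lists are borrowed from the perceptual-weighting example in this README.

```python
import torch


def stft_mag(x: torch.Tensor, n_fft: int, hop: int, win: int, eps: float = 1e-8) -> torch.Tensor:
    """Magnitude STFT of a (batch, channels, samples) tensor -> (batch * channels, bins, frames)."""
    x = x.reshape(-1, x.shape[-1])  # fold channels into the batch dimension
    window = torch.hann_window(win, device=x.device)
    spec = torch.stft(x, n_fft, hop_length=hop, win_length=win,
                      window=window, return_complex=True)
    return spec.abs().clamp(min=eps)  # keep log() finite


def single_res_loss(pred: torch.Tensor, target: torch.Tensor,
                    n_fft: int, hop: int, win: int) -> torch.Tensor:
    p = stft_mag(pred, n_fft, hop, win)
    t = stft_mag(target, n_fft, hop, win)
    # Spectral convergence: norm of the magnitude error relative to the target's norm
    sc = torch.linalg.norm(t - p) / torch.linalg.norm(t)
    # Log magnitude: mean L1 distance between log magnitudes
    log_mag = torch.mean(torch.abs(torch.log(t) - torch.log(p)))
    return sc + log_mag


def multi_res_loss(pred, target,
                   fft_sizes=(1024, 2048, 8192),
                   hop_sizes=(256, 512, 2048),
                   win_lengths=(1024, 2048, 8192)):
    """Average the single-resolution loss over several (FFT size, hop, window) settings."""
    losses = [single_res_loss(pred, target, n, h, w)
              for n, h, w in zip(fft_sizes, hop_sizes, win_lengths)]
    return torch.stack(losses).mean()
```

Calling `multi_res_loss(pred, target)` on two `(8, 1, 44100)` tensors returns a scalar that can be backpropagated like any other PyTorch loss. Short windows resolve transients while long windows resolve closely spaced partials, which is why several resolutions are combined.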
NEW: Perceptual weighting with mel-scaled spectrograms.
```python
import torch
import auraloss

bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    scale="mel",
    n_bins=128,
    sample_rate=sample_rate,
    perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)
```
Citation
If you use this code in your work, please consider citing us.
```bibtex
@inproceedings{steinmetz2020auraloss,
    title={auraloss: {A}udio focused loss functions in {PyTorch}},
    author={Steinmetz, Christian J. and Reiss, Joshua D.},
    booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
    year={2020}
}
```
Loss functions
We categorize the loss functions as either time-domain or frequency-domain approaches. Additionally, we include perceptual transforms.
| Loss function | Interface | Reference |
|---|---|---|
| Time domain | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| Frequency domain | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| Perceptual transforms | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |
* Wang et al., 2019 also propose a multi-resolution spectral loss (that Engel et al., 2020 follow), but they do not include both the log magnitude (L1 distance) and spectral convergence terms, introduced in Arik et al., 2018, and then extended for the multi-resolution case in Yamamoto et al., 2019.
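Several of the time-domain entries above reduce to short formulas. The sketch below writes three of them out from their usual definitions: ESR as error energy relative to target energy, the DC error as the squared mean error relative to target power, and negative SI-SDR following Le Roux et al. It also includes a first-order FIR pre-emphasis filter of the kind Wright & Välimäki apply before the ESR. This is an illustration, not auraloss's code: the function names, the coefficient `0.85`, the `eps` handling, and the batch averaging are assumptions, and the library may differ in reduction and numerical details.

```python
import torch


def pre_emphasis(x: torch.Tensor, coef: float = 0.85) -> torch.Tensor:
    """First-order FIR high-pass y[n] = x[n] - coef * x[n-1] along the last axis (coef is assumed)."""
    return torch.cat([x[..., :1], x[..., 1:] - coef * x[..., :-1]], dim=-1)


def esr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Error-to-signal ratio: error energy over target energy, per item, then averaged."""
    num = ((target - pred) ** 2).sum(dim=-1)
    den = (target ** 2).sum(dim=-1) + eps
    return (num / den).mean()


def dc_error(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Squared mean (DC) offset of the error, relative to the target's mean power."""
    num = (target - pred).mean(dim=-1) ** 2
    den = (target ** 2).mean(dim=-1) + eps
    return (num / den).mean()


def neg_si_sdr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative SI-SDR in dB, so that lower is better and it can be minimized."""
    pred = pred - pred.mean(dim=-1, keepdim=True)  # zero-mean both signals first
    target = target - target.mean(dim=-1, keepdim=True)
    # Optimal scaling of the target onto the prediction removes any gain mismatch
    alpha = (pred * target).sum(-1, keepdim=True) / ((target ** 2).sum(-1, keepdim=True) + eps)
    e_target = alpha * target
    e_res = e_target - pred
    ratio = (e_target ** 2).sum(-1) / ((e_res ** 2).sum(-1) + eps)
    return -(10 * torch.log10(ratio + eps)).mean()
```

A pre-emphasized ESR is then `esr(pre_emphasis(pred), pre_emphasis(target))`. Because of the `alpha` scaling, `neg_si_sdr` gives (almost) the same value for `2 * pred` as for `pred`, which is what "scale-invariant" refers to.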
Examples
Currently we include an example using a set of the loss functions to train a TCN for modeling an analog dynamic range compressor.
For details, please refer to `examples/compressor`.
We provide pre-trained models, evaluation scripts to compute the metrics in the paper, as well as scripts to retrain models.
There are some more advanced things you can do based upon the `STFTLoss` class.
For example, you can compute both linear and log-scaled STFT errors as in Engel et al., 2020.
In this case, we do not include the spectral convergence term.
```python
import auraloss

stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0,
    w_lin_mag=1.0,
    w_sc=0.0,
)
```
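The three weights decide which terms enter the loss. As a rough illustration only, here is a from-scratch weighted sum, assuming a Hann window, L1 distances for both magnitude terms, and a plain weighted sum; auraloss's exact normalization may differ. The function name `weighted_stft_loss` and its default FFT settings are hypothetical. With `w_sc=0.0` the spectral convergence term drops out, matching the configuration above.

```python
import torch


def weighted_stft_loss(pred, target, n_fft=2048, hop=512, win=2048,
                       w_sc=1.0, w_log_mag=1.0, w_lin_mag=0.0, eps=1e-8):
    """Weighted sum of spectral convergence, log-magnitude and linear-magnitude terms."""
    window = torch.hann_window(win, device=pred.device)

    def mag(x):
        x = x.reshape(-1, x.shape[-1])  # (batch * channels, samples)
        spec = torch.stft(x, n_fft, hop_length=hop, win_length=win,
                          window=window, return_complex=True)
        return spec.abs().clamp(min=eps)

    p, t = mag(pred), mag(target)
    sc = torch.linalg.norm(t - p) / torch.linalg.norm(t)          # spectral convergence
    log_mag = torch.mean(torch.abs(torch.log(t) - torch.log(p)))  # L1 on log magnitudes
    lin_mag = torch.mean(torch.abs(t - p))                        # L1 on linear magnitudes
    return w_sc * sc + w_log_mag * log_mag + w_lin_mag * lin_mag
```

The log term emphasizes low-energy regions of the spectrum, while the linear term is dominated by the loudest bins, so combining them balances the two.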
There is also a Mel-scaled STFT loss, which has some special requirements: you must set the sample rate and specify the correct device.
```python
import auraloss

sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
```
You can also easily build a multi-resolution Mel-scaled STFT loss with 64 bins. Make sure you pass the device on which the tensors you are comparing will reside.
```python
import auraloss

sample_rate = 44100
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel",
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda",
)
```
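Mel scaling maps the linear-frequency STFT bins onto `n_bins` triangular bands spaced evenly on the mel scale before the distance is computed. The mel loss in auraloss needs the librosa extra; the sketch below instead builds a simple HTK-style filterbank directly in PyTorch, using mel = 2595 * log10(1 + f / 700), so its band edges will not exactly match librosa's default (Slaney) filters. The function names and default parameters are hypothetical.

```python
import torch


def hz_to_mel(f: torch.Tensor) -> torch.Tensor:
    return 2595.0 * torch.log10(1.0 + f / 700.0)


def mel_to_hz(m: torch.Tensor) -> torch.Tensor:
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_filterbank(sample_rate: int, n_fft: int, n_bins: int) -> torch.Tensor:
    """Triangular HTK-style mel filterbank of shape (n_bins, n_fft // 2 + 1)."""
    fft_freqs = torch.linspace(0.0, sample_rate / 2, n_fft // 2 + 1)
    mel_max = hz_to_mel(torch.tensor(sample_rate / 2)).item()
    # n_bins + 2 equally spaced mel points give every band a left edge, centre and right edge
    hz_points = mel_to_hz(torch.linspace(0.0, mel_max, n_bins + 2))
    left, centre, right = hz_points[:-2, None], hz_points[1:-1, None], hz_points[2:, None]
    rising = (fft_freqs - left) / (centre - left)
    falling = (right - fft_freqs) / (right - centre)
    return torch.clamp(torch.minimum(rising, falling), min=0.0)


def mel_log_mag_loss(pred, target, sample_rate=44100, n_fft=2048, hop=512,
                     n_bins=64, eps=1e-8):
    """Mean L1 distance between log mel-scaled magnitude spectrograms."""
    fb = mel_filterbank(sample_rate, n_fft, n_bins).to(pred.device)
    window = torch.hann_window(n_fft, device=pred.device)

    def mel_mag(x):
        x = x.reshape(-1, x.shape[-1])
        spec = torch.stft(x, n_fft, hop_length=hop, window=window, return_complex=True)
        return (fb @ spec.abs()).clamp(min=eps)  # (batch, n_bins, frames)

    return torch.mean(torch.abs(torch.log(mel_mag(target)) - torch.log(mel_mag(pred))))
```

Because the filters widen with frequency, high-frequency detail is pooled into fewer bands, which roughly follows how pitch resolution is perceived.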
If you are computing a loss on stereo audio, you may want to consider the sum and difference (mid/side) loss. Below we show an example of using this loss function with perceptual weighting and mel scaling for further perceptual relevance.
```python
import torch
import auraloss

target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)

loss = loss_fn(pred, target)
```
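The sum and difference transform itself is just a pair of channel combinations: the sum signal L + R and the difference signal L - R. Here is a minimal sketch, assuming channel 0 is left and channel 1 is right; some mid/side conventions also scale by 1/2, which this sketch omits. The loss then applies a spectral distance to each signal separately. The plain log-magnitude distance, the averaging of the two terms, and all helper names are illustrative assumptions, not auraloss's exact formulation.

```python
import torch


def sum_and_difference(x: torch.Tensor):
    """Split a (batch, 2, samples) stereo tensor into sum (L + R) and difference (L - R) signals."""
    if x.shape[1] != 2:
        raise ValueError("expected stereo input with 2 channels")
    left, right = x[:, 0:1, :], x[:, 1:2, :]  # keep a singleton channel dimension
    return left + right, left - right


def log_mag_l1(pred, target, n_fft=1024, hop=256, eps=1e-8):
    """Mean L1 distance between log STFT magnitudes (single resolution)."""
    window = torch.hann_window(n_fft, device=pred.device)

    def mag(x):
        spec = torch.stft(x.reshape(-1, x.shape[-1]), n_fft, hop_length=hop,
                          window=window, return_complex=True)
        return spec.abs().clamp(min=eps)

    return torch.mean(torch.abs(torch.log(mag(target)) - torch.log(mag(pred))))


def sum_and_difference_loss(pred, target):
    """Average a spectral loss computed separately on the sum and difference signals."""
    pred_sum, pred_diff = sum_and_difference(pred)
    target_sum, target_diff = sum_and_difference(target)
    return 0.5 * (log_mag_l1(pred_sum, target_sum) + log_mag_l1(pred_diff, target_diff))
```

Comparing the difference signals directly penalizes errors in stereo width and panning, which a per-channel loss on left and right would only capture indirectly.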
Development
Run tests locally with pytest.
```
python -m pytest
```