A PyTorch implementation of Spectral Optimal Transport (SOT) losses for audio.
Spectral Optimal Transport Losses for PyTorch
This repository contains an implementation of Spectral Optimal Transport (SOT) loss functions for PyTorch, with a pip-installable package, sot-loss. SOT loss functions are differentiable spectral losses that compare the spectra of two audio signals using optimal transport principles. They can be used to train neural networks for audio processing tasks, particularly those involving differentiable digital signal processing (DDSP), and more generally as a metric for comparing audio signals.
[Figure: "SOT does this" vs. "Multi-Scale Spectral loss and others do this"]
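To make the idea concrete, here is a minimal sketch of a 1D Wasserstein-p distance between two magnitude spectra, computed from their quantile functions. It assumes both spectra share the same sorted bin positions and normalizes each one to sum to 1. The function name `spectral_wasserstein` is hypothetical, and this is an illustration of the principle, not the sot-loss implementation.

```python
import torch


def spectral_wasserstein(a: torch.Tensor, b: torch.Tensor,
                         positions: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Exact 1D Wasserstein-p distance between two magnitude spectra.

    a, b: non-negative tensors of shape (..., n_bins) with matching leading dims.
    positions: sorted bin positions of shape (n_bins,), shared by a and b.
    Returns a tensor of shape (...).
    """
    # Normalize each spectrum to sum to 1 so it is a probability distribution.
    a = a / a.sum(dim=-1, keepdim=True)
    b = b / b.sum(dim=-1, keepdim=True)
    # CDF value reached at each bin.
    cdf_a = torch.cumsum(a, dim=-1)
    cdf_b = torch.cumsum(b, dim=-1)
    # Between consecutive merged CDF levels, both quantile functions are constant.
    levels, _ = torch.sort(torch.cat([cdf_a, cdf_b], dim=-1), dim=-1)
    widths = torch.diff(levels, dim=-1, prepend=torch.zeros_like(levels[..., :1]))
    # Quantile function: index of the first bin whose CDF reaches each level
    # (clamped because rounding can leave the last CDF value slightly below 1).
    last = positions.shape[0] - 1
    idx_a = torch.searchsorted(cdf_a, levels).clamp(max=last)
    idx_b = torch.searchsorted(cdf_b, levels).clamp(max=last)
    # Integrate |F_a^{-1}(t) - F_b^{-1}(t)|^p over t in [0, 1].
    cost = (positions[idx_a] - positions[idx_b]).abs() ** p
    return (widths * cost).sum(dim=-1) ** (1.0 / p)


positions = torch.arange(5, dtype=torch.float32)
a = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])  # all energy in bin 0
b = torch.tensor([0.0, 0.0, 0.0, 1.0, 0.0])  # all energy in bin 3
print(spectral_wasserstein(a, b, positions))  # tensor(3.)
```

Because the cost grows with how far mass has to move along the frequency axis, two single-bin peaks three bins apart are at distance 3 here. A bin-wise L1 or L2 difference between the same two spectra would stay the same for any shift that keeps the peaks from overlapping.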
Installation
You can install the sot-loss package using pip:
```bash
pip install sot-loss
```
Usage
The primary components of this package are the Wasserstein1DLoss and MultiResolutionSOTLoss classes, which can be used as PyTorch loss functions. Here is a basic example of how to use the Wasserstein1DLoss:
```python
import torch
from sot import Wasserstein1DLoss

# Create some dummy audio signals
x = torch.randn(4, 16000)
y = torch.randn(4, 16000)

# Initialize the SOT loss
sot_loss = Wasserstein1DLoss(transform='stft',
                             fft_size=2048,
                             hop_length=512,
                             sample_rate=16000,
                             window='flattop',
                             square_magnitude=True)

# Compute the loss
loss = sot_loss(x, y)
print(loss)
```
Using your own audio-to-2D mapping (set transform='identity' and pass the bin positions explicitly):
```python
# custom_transform and get_custom_positions are placeholders for your own functions
x_spec = custom_transform(x)  # batch, channels, time
y_spec = custom_transform(y)
x_positions = get_custom_positions(x_spec)  # channels
y_positions = get_custom_positions(y_spec)

sot_loss = Wasserstein1DLoss(transform='identity',
                             # other non-transform parameters can go here
                             balanced=True,
                             normalize=True)

loss = sot_loss(x_spec, y_spec, x_positions=x_positions, y_positions=y_positions)
print(loss)
```
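As an illustration, the placeholders `custom_transform` and `get_custom_positions` could be implemented with `torch.stft` as below. Both functions are hypothetical: the spectrogram follows the shape comments in the example above (batch, channels, time), reading channels as frequency bins, and the positions are bin center frequencies in Hz. Check which layout and position scaling your version of the loss expects before relying on this.

```python
import torch


def custom_transform(x: torch.Tensor, n_fft: int = 1024,
                     hop_length: int = 256) -> torch.Tensor:
    """Hypothetical audio -> 2D mapping: STFT magnitude, (batch, bins, frames)."""
    window = torch.hann_window(n_fft, device=x.device)
    spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)
    return spec.abs()


def get_custom_positions(spec: torch.Tensor,
                         sample_rate: int = 16000) -> torch.Tensor:
    """Hypothetical bin positions: center frequency in Hz of each STFT bin."""
    # A one-sided FFT of even length n_fft has n_fft // 2 + 1 bins.
    n_fft = 2 * (spec.shape[-2] - 1)
    return torch.fft.rfftfreq(n_fft, d=1.0 / sample_rate)


x = torch.randn(4, 16000)
x_spec = custom_transform(x)
x_positions = get_custom_positions(x_spec)
print(x_spec.shape, x_positions.shape)  # torch.Size([4, 513, 63]) torch.Size([513])
```

Recovering `n_fft` from the number of bins assumes an even FFT size; pass the positions in directly if you use odd sizes.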
Advanced Usage
The Wasserstein1DLoss and MultiResolutionSOTLoss classes offer a range of parameters to customize the spectral representation and the loss calculation.
Spectral Transform Parameters
These parameters are available in both Wasserstein1DLoss and MultiResolutionSOTLoss.
Transform parameters (if using built-in transforms):
| Argument | Type | Default | Description |
|---|---|---|---|
| transform | str | 'stft' | The spectral transform to use. One of 'stft', 'mel', 'cqt', or 'identity'. |
| fft_size, hop_length, win_length | int | 1024, 256, None | Standard STFT parameters. |
| window | str | 'flattop' | The window function to use for STFT and CQT. |
| n_mels | int | 128 | Number of Mel bins for the Mel spectrogram. |
| n_bins, bins_per_octave, fmin, fmax, sample_rate | int, int, float, float, int | 84, 36, 32.7, None, 22050 | CQT parameters. |
| gamma | int | 0 | VQT parameter that reduces kernel lengths at low frequencies; 0 gives the traditional CQT (see this paper). |
| bin_position_scaling | str | 'normalized' | Defines how the ground distance for the Wasserstein calculation is measured by setting how the bin positions of each transform are computed. One of 'normalized', 'normalized_linear', or 'absolute'. |
Loss parameters (these apply even when using custom transforms):
| Argument | Type | Default | Description |
|---|---|---|---|
| square_magnitude | bool | False | If True, computes the loss on the squared magnitude of the spectrum (power). |
| dim | int | -1 | The dimension along which to compute the Wasserstein distance: -1 for frequency, -2 for time. |
| normalize | bool | True | If True, normalizes the spectral magnitudes to sum to 1, treating them as probability distributions. |
| balanced | bool | True | If True and normalize is True, both spectra are normalized to sum to 1 independently. If False and normalize is True, the second spectrum is scaled relative to the first. |
| p | int | 2 | The order of the Wasserstein distance. |
| quantile_lowpass | bool | False | If True, applies a frequency cutoff by zeroing out distances for quantiles above 1.0. This is useful when balanced is False. |
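The normalize and balanced options can be pictured with the hypothetical helper below. The balanced=True branch follows the table directly; for balanced=False we assume that "scaled relative to the first" means both spectra are divided by the first spectrum's total, which is an interpretation and not confirmed by the package.

```python
import torch


def normalize_spectra(x_spec: torch.Tensor, y_spec: torch.Tensor,
                      balanced: bool = True, dim: int = -1,
                      eps: float = 1e-12):
    """Hypothetical sketch of the `normalize`/`balanced` options.

    balanced=True: each spectrum is divided by its own total along `dim`.
    balanced=False (assumed reading): both are divided by x_spec's total,
    so y_spec keeps its energy relative to x_spec.
    """
    x_total = x_spec.sum(dim=dim, keepdim=True).clamp_min(eps)
    if balanced:
        y_total = y_spec.sum(dim=dim, keepdim=True).clamp_min(eps)
        return x_spec / x_total, y_spec / y_total
    return x_spec / x_total, y_spec / x_total


x = torch.tensor([[1.0, 1.0, 2.0]])
y = torch.tensor([[2.0, 2.0, 4.0]])
print(normalize_spectra(x, y)[1].sum(-1))                  # tensor([1.])
print(normalize_spectra(x, y, balanced=False)[1].sum(-1))  # tensor([2.])
```

Under this reading, a second spectrum with more energy than the first has a CDF that climbs above 1.0, which would fit the table's remark that quantile_lowpass is useful when balanced is False.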
The MultiResolutionSOTLoss combines multiple Wasserstein1DLoss instances, each with a different set of STFT parameters.
| Argument | Type | Default | Description |
|---|---|---|---|
| fft_sizes | list | [1024, 2048, 512] | A list of FFT sizes to use for each resolution. |
| hop_lengths | list | [256, 512, 128] | A list of hop lengths to use for each resolution. |
| win_lengths | list | None | A list of window lengths to use for each resolution. If None, defaults to fft_sizes. |
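A rough sketch of how several resolutions might be combined, using plain `torch.stft`. It is not the package's code and rests on several assumptions: a Hann window instead of the 'flattop' default, p = 1 (computed from CDF differences) instead of p = 2, bins evenly spaced on [0, 1], and a simple average over frames and resolutions. The helper names are hypothetical.

```python
import torch


def stft_mag(s, n_fft, hop, win):
    """Magnitude STFT arranged as (batch, frames, bins)."""
    window = torch.hann_window(win, device=s.device)
    spec = torch.stft(s, n_fft=n_fft, hop_length=hop, win_length=win,
                      window=window, return_complex=True)
    return spec.abs().transpose(-1, -2)


def w1_last_dim(a, b, eps=1e-12):
    """W1 along the last axis, with bins evenly spaced on [0, 1] (assumption)."""
    a = a / a.sum(dim=-1, keepdim=True).clamp_min(eps)
    b = b / b.sum(dim=-1, keepdim=True).clamp_min(eps)
    # For p = 1, W1 is the area between the two CDFs.
    cdf_gap = torch.cumsum(a - b, dim=-1)[..., :-1]
    return cdf_gap.abs().sum(dim=-1) / (a.shape[-1] - 1)


def multi_resolution_w1(x, y, fft_sizes=(1024, 2048, 512),
                        hop_lengths=(256, 512, 128), win_lengths=None):
    if win_lengths is None:  # mirror the documented default
        win_lengths = fft_sizes
    losses = [
        w1_last_dim(stft_mag(x, n, h, w), stft_mag(y, n, h, w)).mean()
        for n, h, w in zip(fft_sizes, hop_lengths, win_lengths)
    ]
    # Averaging (rather than summing) the resolutions is an assumption.
    return torch.stack(losses).mean()
```

With this sketch, a 440 Hz sine comes out closer to an 880 Hz sine than to a 4 kHz one, as expected from a transport-based distance, and identical inputs give exactly zero.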
About the Paper
This is also the official repository for the paper "Unsupervised Harmonic Parameter Estimation Using Differentiable DSP and Spectral Optimal Transport" by Bernardo Torres, Geoffroy Peeters, and Gaël Richard. Check out the poster here.
To reproduce the results from the paper, please check out the paper branch.
Citation
If you find our work useful or use it in your research, you can cite it using:
```bibtex
@inproceedings{torres2024unsupervised,
  title={Unsupervised harmonic parameter estimation using differentiable DSP and spectral optimal transport},
  author={Torres, Bernardo and Peeters, Geoffroy and Richard, Ga{\"e}l},
  booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={1176--1180},
  year={2024},
  organization={IEEE}
}
```
Download files
File details
Details for the file sot_loss-0.1.4.tar.gz.
File metadata
- Download URL: sot_loss-0.1.4.tar.gz
- Upload date:
- Size: 20.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d5a36ee67f88a91ba3651c2e1fc991ca79f8f657645cf8ee66f53381fa2e042e |
| MD5 | 9bf926189e757108983a8a808a4ab7e1 |
| BLAKE2b-256 | 3a6a37a147132090872d8a5f73172207242ebf065b4007f6f27019ea73d2c416 |
File details
Details for the file sot_loss-0.1.4-py3-none-any.whl.
File metadata
- Download URL: sot_loss-0.1.4-py3-none-any.whl
- Upload date:
- Size: 19.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c4886874f001599c8ef95318795740c71a3bc1584fdb049733a02f560394c345 |
| MD5 | 65002011f60d3050ff477494bda343a3 |
| BLAKE2b-256 | 4ef20adbff4d8dcc6b790c23cafeeebcf2b9e71a0017d460d882f833467de75b |