A PyTorch implementation of Spectral Optimal Transport (SOT) losses for audio.

Spectral Optimal Transport Losses for PyTorch

This repository contains an implementation of Spectral Optimal Transport (SOT) loss functions for PyTorch, with a pip-installable package sot-loss. SOT loss functions are differentiable spectral losses that compare the spectra of two audio signals using optimal transport principles. They can be used to train neural networks for audio processing tasks, particularly those involving differentiable digital signal processing (DDSP), and more generally as a metric for comparing audio signals.

[Figure: SOT does this]

[Figure: Multi-Scale Spectral loss and others do this]
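To make the contrast above concrete, the sketch below compares a 1-Wasserstein distance with a bin-wise L1 distance on toy single-peak spectra. It is an illustration in plain PyTorch, not the package's implementation; the helper names `wasserstein1_1d` and `peak` and the 101-bin normalized frequency axis are assumptions made for this example.

```python
import torch

def wasserstein1_1d(a, b, positions):
    """1-Wasserstein distance between two nonnegative spectra on a shared, sorted axis."""
    a = a / a.sum()  # treat each spectrum as a probability distribution
    b = b / b.sum()
    cdf_diff = torch.cumsum(a - b, dim=-1)       # difference of the two CDFs
    widths = positions[1:] - positions[:-1]      # spacing between neighbouring bins
    return (cdf_diff[:-1].abs() * widths).sum()  # integral of |CDF_a - CDF_b|

def peak(index, n_bins=101):
    """Toy spectrum with all of its energy in a single bin (hypothetical helper)."""
    spectrum = torch.zeros(n_bins)
    spectrum[index] = 1.0
    return spectrum

positions = torch.linspace(0.0, 1.0, 101)  # normalized frequency axis, spacing 0.01
target = peak(50)

for index in (52, 60, 80):
    estimate = peak(index)
    w1 = wasserstein1_1d(estimate, target, positions)
    l1 = (estimate - target).abs().sum()
    print(f"peak at bin {index}: W1 = {w1.item():.2f}, L1 = {l1.item():.2f}")
```

In this toy setting the transport distance grows with the size of the frequency shift, while the bin-wise distance is the same for every non-overlapping pair of peaks and therefore says nothing about how far apart they are.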

Installation

You can install the sot-loss package using pip:

pip install sot-loss

Usage

The primary components of this package are the Wasserstein1DLoss and MultiResolutionSOTLoss classes, which can be used as PyTorch loss functions. Here is a basic example of how to use the Wasserstein1DLoss:

import torch
from sot import Wasserstein1DLoss

# Create some dummy audio signals
x = torch.randn(4, 16000)
y = torch.randn(4, 16000)

# Initialize the SOT loss
sot_loss = Wasserstein1DLoss(transform='stft', 
                             fft_size=2048,
                             hop_length=512, 
                             sample_rate=16000, 
                             window='flattop', 
                             square_magnitude=True)

# Compute the loss
loss = sot_loss(x, y)
print(loss)
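Because the losses are described as differentiable, the returned value can drive gradient-based optimization. The sketch below shows the idea without calling sot-loss: `spectral_w1` is a hypothetical stand-in (a 1-Wasserstein distance between normalized magnitude spectra with a Hann window), and the 300 Hz and 500 Hz tones are arbitrary choices for the example.

```python
import math
import torch

def spectral_w1(x, y):
    """Stand-in loss: W1 between normalized magnitude spectra (not the sot-loss code)."""
    window = torch.hann_window(x.shape[-1])
    x_mag = torch.fft.rfft(x * window).abs()
    y_mag = torch.fft.rfft(y * window).abs()
    x_mag = x_mag / x_mag.sum()
    y_mag = y_mag / y_mag.sum()
    cdf_diff = torch.cumsum(x_mag - y_mag, dim=-1)
    # Bins are evenly spaced, so the mean approximates the integral of |CDF_x - CDF_y|.
    return cdf_diff.abs().mean()

sample_rate = 16000
t = torch.arange(2048) / sample_rate
target = torch.sin(2 * math.pi * 500.0 * t)      # 500 Hz reference tone

freq = torch.tensor(300.0, requires_grad=True)   # frequency estimate, starts too low
estimate = torch.sin(2 * math.pi * freq * t)

loss = spectral_w1(estimate, target)
loss.backward()
print(loss.item(), freq.grad.item())
```

The gradient with respect to `freq` is negative here, meaning a gradient descent step would raise the estimate toward the target frequency even though the two spectral peaks do not overlap.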

To use your own mapping from audio to a 2D representation, set transform='identity' and pass the bin positions explicitly:

x_spec = custom_transform(x) # batch, channels, time
y_spec = custom_transform(y) 
x_positions = get_custom_positions(x_spec) # channels
y_positions = get_custom_positions(y_spec)

sot_loss = Wasserstein1DLoss(transform='identity',
                             # other non-transform parameters can go here
                             balanced=True,
                             normalize=True,
                             )
loss = sot_loss(x_spec, y_spec, x_positions=x_positions, y_positions=y_positions)
print(loss)
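The snippet above leaves custom_transform and get_custom_positions undefined. One possible pair of helpers is sketched below, purely as an assumption: a magnitude STFT as the transform and evenly spaced coordinates in [0, 1] as the positions. The FFT size, hop length, and Hann window are illustrative choices, not values required by the package.

```python
import torch

def custom_transform(audio, fft_size=512, hop_length=128):
    """Hypothetical transform: magnitude STFT with shape (batch, channels, time)."""
    window = torch.hann_window(fft_size)
    spec = torch.stft(audio, fft_size, hop_length=hop_length,
                      window=window, return_complex=True)
    return spec.abs()

def get_custom_positions(spec):
    """Hypothetical positions: one coordinate per channel, evenly spaced in [0, 1]."""
    return torch.linspace(0.0, 1.0, spec.shape[1])

x = torch.randn(4, 16000)
x_spec = custom_transform(x)
x_positions = get_custom_positions(x_spec)
print(x_spec.shape, x_positions.shape)
```

Any transform works as long as the positions tensor has one entry per channel of the representation, matching the shape comments in the example above.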

Advanced Usage

The Wasserstein1DLoss and MultiResolutionSOTLoss classes offer a range of parameters to customize the spectral representation and the loss calculation.

Spectral Transform Parameters

These parameters are available in both Wasserstein1DLoss and MultiResolutionSOTLoss.

Transform parameters (if using built-in transforms):

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| transform | str | 'stft' | The spectral transform to use. One of 'stft', 'mel', 'cqt', or 'identity'. |
| fft_size, hop_length, win_length | int | 1024, 256, None | Your typical STFT parameters. |
| window | str | 'flattop' | The window function to use for STFT and CQT. |
| n_mels | int | 128 | Number of Mel bins for the Mel spectrogram. |
| n_bins, bins_per_octave, fmin, fmax, sample_rate | int, int, float, float | 84, 36, 32.7, None, 22050 | CQT parameters. |
| gamma | int | 0 | VQT parameter which reduces kernel lengths for low frequencies. 0 for traditional CQT (see this paper). |
| bin_position_scaling | str | 'normalized' | Defines how the ground distance for the Wasserstein calculation is measured. Affects how the bin positions for the transforms are calculated. One of 'normalized', 'normalized_linear', or 'absolute'. |
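The bin positions that serve as the ground distance come from the frequency axis of the chosen transform. As a rough illustration, the sketch below computes the standard bin center frequencies for an STFT (linear spacing) and a CQT (geometric spacing) using the defaults listed above. The function names are hypothetical, and how the package maps these frequencies to 'normalized', 'normalized_linear', or 'absolute' positions is not covered here.

```python
import torch

def stft_bin_frequencies(fft_size=1024, sample_rate=22050):
    """Center frequency in Hz of each one-sided STFT bin (linear spacing)."""
    return torch.arange(fft_size // 2 + 1) * sample_rate / fft_size

def cqt_bin_frequencies(n_bins=84, bins_per_octave=36, fmin=32.7):
    """Center frequency in Hz of each CQT bin (geometric spacing)."""
    return fmin * 2.0 ** (torch.arange(n_bins) / bins_per_octave)

stft_freqs = stft_bin_frequencies()
cqt_freqs = cqt_bin_frequencies()
print(stft_freqs[:3], stft_freqs[-1])  # evenly spaced up to the Nyquist frequency
print(cqt_freqs[0], cqt_freqs[36])     # one octave (36 bins) doubles the frequency
```

With linear spacing, a fixed shift in Hz costs the same anywhere on the axis; with geometric spacing, a fixed musical interval does. This is one reason the choice of transform and position scaling changes what the loss considers "close".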

Loss parameters (applies even if using custom transforms):

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| square_magnitude | bool | False | If True, computes the loss on the squared magnitude of the spectrum (power). |
| dim | int | -1 | The dimension along which to compute the Wasserstein distance. -1 for frequency, -2 for time. |
| normalize | bool | True | If True, normalizes the spectral magnitudes to sum to 1, treating them as probability distributions. |
| balanced | bool | True | If True and normalize is True, both spectra are normalized to sum to 1 independently. If False and normalize is True, the second spectrum is scaled relative to the first. |
| p | int | 2 | The order of the Wasserstein distance. |
| quantile_lowpass | bool | False | If True, applies a frequency cutoff by zeroing out distances for quantiles above 1.0. This is useful when balanced is False. |
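For one-dimensional distributions, the order-p Wasserstein distance has a closed form in terms of quantile functions (inverse CDFs): it is the L^p norm of their difference over the interval [0, 1]. The sketch below implements that formula for two normalized (balanced) discrete distributions that may live on different position grids. It is a minimal illustration, not the package's code, and the function name and argument names are assumptions.

```python
import torch

def wasserstein_1d(u_weights, v_weights, u_positions, v_positions, p=2):
    """Order-p Wasserstein distance between two discrete 1D distributions.

    Positions must be sorted in increasing order; weights are normalized here.
    """
    u_cdf = torch.cumsum(u_weights / u_weights.sum(), dim=0)
    v_cdf = torch.cumsum(v_weights / v_weights.sum(), dim=0)

    # Quantile levels at which either quantile function can jump.
    levels, _ = torch.sort(torch.cat([u_cdf, v_cdf]))
    deltas = torch.diff(levels, prepend=levels.new_zeros(1))

    # Quantile function: position of the first bin whose CDF reaches the level.
    u_idx = torch.searchsorted(u_cdf, levels).clamp(max=u_positions.numel() - 1)
    v_idx = torch.searchsorted(v_cdf, levels).clamp(max=v_positions.numel() - 1)

    cost = (u_positions[u_idx] - v_positions[v_idx]).abs() ** p
    return (cost * deltas).sum() ** (1.0 / p)

# Two equal masses at 0 and 1 versus a single mass at 0.5.
u_weights, u_positions = torch.tensor([0.5, 0.5]), torch.tensor([0.0, 1.0])
v_weights, v_positions = torch.tensor([1.0]), torch.tensor([0.5])
print(wasserstein_1d(u_weights, v_weights, u_positions, v_positions, p=2))
```

In the example every unit of mass travels a distance of 0.5, so the result is 0.5 for any p. Larger values of p penalize long transport distances more heavily than short ones.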

The MultiResolutionSOTLoss combines multiple Wasserstein1DLoss instances, each with a different set of STFT parameters.

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| fft_sizes | list | [1024, 2048, 512] | A list of FFT sizes to use for each resolution. |
| hop_lengths | list | [256, 512, 128] | A list of hop lengths to use for each resolution. |
| win_lengths | list | None | A list of window lengths to use for each resolution. If None, defaults to fft_sizes. |
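The multi-resolution idea can be sketched in plain PyTorch as follows: compute a per-frame transport distance on several STFTs with different FFT sizes and hop lengths, then combine the results. This is not the package's implementation. The helper names are hypothetical, a Hann window and a 1-Wasserstein distance are used for simplicity, and averaging across resolutions is an assumption (the package may combine them differently).

```python
import math
import torch

def stft_w1(x, y, fft_size, hop_length, eps=1e-8):
    """W1 along frequency between normalized STFT magnitudes, averaged over frames."""
    window = torch.hann_window(fft_size)
    x_mag = torch.stft(x, fft_size, hop_length=hop_length,
                       window=window, return_complex=True).abs()
    y_mag = torch.stft(y, fft_size, hop_length=hop_length,
                       window=window, return_complex=True).abs()
    # Shapes are (batch, freq, frames); normalize each frame over frequency.
    x_mag = x_mag / (x_mag.sum(dim=1, keepdim=True) + eps)
    y_mag = y_mag / (y_mag.sum(dim=1, keepdim=True) + eps)
    cdf_diff = torch.cumsum(x_mag - y_mag, dim=1).abs()
    # Uniform bin spacing on a frequency axis normalized to [0, 1].
    return cdf_diff.sum(dim=1).mean() / (x_mag.shape[1] - 1)

def multi_resolution_w1(x, y, fft_sizes=(1024, 2048, 512), hop_lengths=(256, 512, 128)):
    """Average the single-resolution losses over all (fft_size, hop_length) pairs."""
    losses = [stft_w1(x, y, n, h) for n, h in zip(fft_sizes, hop_lengths)]
    return torch.stack(losses).mean()

t = torch.arange(16000) / 16000
x = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # (batch=1, samples)
y = torch.sin(2 * math.pi * 880.0 * t).unsqueeze(0)
print(multi_resolution_w1(x, y))
```

Short windows give better time resolution and long windows give better frequency resolution, so combining several resolutions avoids committing to a single trade-off.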

About the Paper

This is also the official repository for the paper "Unsupervised Harmonic Parameter Estimation Using Differentiable DSP and Spectral Optimal Transport" by Bernardo Torres, Geoffroy Peeters, and Gaël Richard. Check out the poster here.

To reproduce the results from the paper, please check out the paper branch.

Citation

If you find our work useful or use it in your research, you can cite it using:

@inproceedings{torres2024unsupervised,
  title={Unsupervised harmonic parameter estimation using differentiable DSP and spectral optimal transport},
  author={Torres, Bernardo and Peeters, Geoffroy and Richard, Ga{\"e}l},
  booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={1176--1180},
  year={2024},
  organization={IEEE}
}
