Torch implementation of various DTW-inspired loss functions

Project description

DTW Inspired Loss Functions

Python package with the implementation of various DTW-inspired loss functions. Each implementation is compatible with PyTorch and can be used for training.
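As background (this is a sketch of the classic algorithm these losses build on, not part of the package API): DTW aligns two sequences with a dynamic program over pairwise costs. A minimal pure-Python version of the hard-min recurrence, using a squared-difference cost:

```python
import math

def dtw_distance(x, y):
    """Classic DTW between two 1-D sequences with squared-difference cost."""
    n, m = len(x), len(y)
    # D[i][j] = cost of the best alignment of x[:i] and y[:j]
    D = [[math.inf] * (m + 1) for _ in range(n + 1)]
    D[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = (x[i - 1] - y[j - 1]) ** 2
            D[i][j] = cost + min(D[i - 1][j],      # step in x only
                                 D[i][j - 1],      # step in y only
                                 D[i - 1][j - 1])  # step in both
    return D[n][m]

print(dtw_distance([0.0, 1.0, 2.0], [0.0, 1.0, 2.0]))  # identical sequences -> 0.0
```

The hard `min` above is not differentiable everywhere, which is what motivates the smoothed variants implemented in this package.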

Loss Functions Currently Implemented

  • Block-DTW
  • SoftDTW (with two selectable implementations, 'mag' and 'ron')
  • OTW

Documentation

You can read the package documentation at the link

Note that, for now, the documentation is generated automatically from the docstrings in the code through Sphinx autodoc. It will be improved in the future with more examples and explanations.

Installation

pip installation

The easiest way to use this package is to install it via pip:

pip install dtw-loss-function

Build from source

Alternatively, you can clone the repository and build it locally with hatchling:

pip install hatchling
git clone https://github.com/jesus-333/dtw_loss_functions.git
cd dtw_loss_functions
hatchling build && pip install .

Usage Examples

Each loss function is implemented as a module inside the package.

Block-DTW Example

from dtw_loss_functions import block_dtw
import torch

block_size = 25
use_cuda = torch.cuda.is_available()

# Create the Block-DTW loss module
block_dtw_loss = block_dtw.block_dtw(block_size, sdtw_config = {'use_cuda' : use_cuda})

# Random input (x) and reconstruction (x_r) with shape (batch, time samples, channels)
batch_size = 5
time_samples = 300
channels = 1

device = 'cuda' if use_cuda else 'cpu'
x   = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)

# Compute the loss value
output_block_dtw = block_dtw_loss(x, x_r)
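The `block_size` argument suggests the time axis is split into fixed-size blocks before the SoftDTW distance is applied. The sketch below illustrates only that hypothetical blocking step in plain Python (the package's actual algorithm may combine per-block distances differently; see the citation page for the underlying works):

```python
def split_into_blocks(sequence, block_size):
    """Split a sequence into consecutive blocks of at most block_size samples.
    Hypothetical illustration only; not the package's internal code."""
    return [sequence[i:i + block_size] for i in range(0, len(sequence), block_size)]

# e.g. a 300-sample series with block_size = 25 yields 12 blocks of 25 samples
blocks = split_into_blocks(list(range(300)), 25)
```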

SoftDTW Example

This package offers several implementations of the SoftDTW algorithm.

Mehran Maghoumi's (mag) implementation

import torch
from dtw_loss_functions import soft_dtw

use_cuda = torch.cuda.is_available()

# Create the SoftDTW loss using the 'mag' implementation
sdtw_loss = soft_dtw.soft_dtw(implementation = 'mag', sdtw_config = {'use_cuda' : use_cuda, 'gamma' : 0.1})

batch_size = 5
time_samples = 300
channels = 1

device = 'cuda' if use_cuda else 'cpu'
x   = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)

output_sdtw = sdtw_loss(x, x_r)

Ron Shapira Weber's (ron) implementation

import torch
from dtw_loss_functions import soft_dtw

# Create the SoftDTW loss using the 'ron' implementation, with squared Euclidean distance
sdtw_loss = soft_dtw.soft_dtw(implementation = 'ron', sdtw_config = {'gamma' : 0.1, 'dist' : 'sqeuclidean'})

batch_size = 5
time_samples = 300
channels = 1

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x   = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)

output_sdtw = sdtw_loss(x, x_r)
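Both variants smooth the same underlying recurrence. For context (a sketch of the math, not the package's internal code): SoftDTW replaces the hard `min` of DTW with a soft-min controlled by `gamma`, which makes the loss differentiable; as `gamma` shrinks, the value approaches classic DTW:

```python
import math

def softmin(a, b, c, gamma):
    """Smooth minimum: -gamma * log(sum(exp(-v / gamma)))."""
    vals = [a, b, c]
    lo = min(vals)
    # Subtract the minimum before exponentiating for numerical stability
    s = sum(math.exp(-(v - lo) / gamma) for v in vals)
    return lo - gamma * math.log(s)

def soft_dtw_value(x, y, gamma=0.1):
    """Soft-DTW between two 1-D sequences (illustrative sketch only)."""
    n, m = len(x), len(y)
    R = [[math.inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = (x[i - 1] - y[j - 1]) ** 2   # squared Euclidean, as in the 'ron' example
            R[i][j] = cost + softmin(R[i - 1][j], R[i][j - 1], R[i - 1][j - 1], gamma)
    return R[n][m]
```

The soft-min is always less than or equal to the hard min, so the soft-DTW value lower-bounds the classic DTW distance and converges to it as `gamma` approaches zero.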

Citation

If you use this package, refer to the citation page for information on the works to cite.

Benchmark

The benchmark page offers comparisons between the different loss functions included in this package and shows how to use the benchmark module.

Important notes about Repository Name

The original name of this repository was block-sdtw because it was initially created to develop only the code related to the Block-DTW loss function. As with many academic projects, the code was initially "not very well organized", so I decided to refactor it and turn it into a Python package for easier use. Since my implementation of Block-DTW relies on Mehran Maghoumi's implementation of SoftDTW, I decided to also include his SoftDTW in the package. While testing, I also took the opportunity to implement the OTW loss function.

At this point, with 3 different loss functions in the code base, the name block-sdtw seemed a bit misleading to me, so I decided to change it from block-sdtw to dtw-loss-function.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

dtw_loss_functions-1.0.5.tar.gz (8.5 MB)

Uploaded: Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

dtw_loss_functions-1.0.5-py2.py3-none-any.whl (40.7 kB)

Uploaded: Python 2, Python 3

File details

Details for the file dtw_loss_functions-1.0.5.tar.gz.

File metadata

  • Download URL: dtw_loss_functions-1.0.5.tar.gz
  • Upload date:
  • Size: 8.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for dtw_loss_functions-1.0.5.tar.gz
Algorithm Hash digest
SHA256 9843989e3f746619ef8b93649645bd293b227a7ad2c810326a1e1e591db3ca8f
MD5 1bc431a0f5039eaf1cd1d29af2df92f1
BLAKE2b-256 2dc2dd66f3b06ee2b62e88bdfcf6345d3948e92187eae80bf58f31e09faaef7c

See more details on using hashes here.

Provenance

The following attestation bundles were made for dtw_loss_functions-1.0.5.tar.gz:

Publisher: python-publish.yml on jesus-333/dtw_loss_functions

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file dtw_loss_functions-1.0.5-py2.py3-none-any.whl.

File metadata

File hashes

Hashes for dtw_loss_functions-1.0.5-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 282f984d7f46bc843a93f664a171e0abac4f5c676206cd87c94e3fd6ded8d017
MD5 f34fd9636ee4af638866edb17404ee30
BLAKE2b-256 cf988373c8b37f30424ddeea249583ec07c23f28a8047045f525246d0888ee52

See more details on using hashes here.

Provenance

The following attestation bundles were made for dtw_loss_functions-1.0.5-py2.py3-none-any.whl:

Publisher: python-publish.yml on jesus-333/dtw_loss_functions

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
