PyTorch implementation of various DTW-inspired loss functions

Project description

DTW Inspired Loss Functions

Python package implementing various loss functions inspired by DTW (Dynamic Time Warping). Each implementation is compatible with PyTorch and can be used to train models.
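DTW compares two sequences by warping their time axes so that similar samples are aligned before the costs are summed. As a reference point only, the sketch below computes classic (hard) DTW between two 1-D sequences with a squared-difference local cost. It is not code from this package, and the name `dtw_distance` is hypothetical.

```python
import math
from typing import Sequence


def dtw_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """Classic (hard) DTW between two 1-D sequences, squared-difference cost."""
    n, m = len(x), len(y)
    # r[i][j] = cost of the best alignment of x[:i] with y[:j]
    r = [[math.inf] * (m + 1) for _ in range(n + 1)]
    r[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = (x[i - 1] - y[j - 1]) ** 2
            # Extend the cheapest of the three admissible predecessor paths
            r[i][j] = cost + min(r[i - 1][j], r[i][j - 1], r[i - 1][j - 1])
    return r[n][m]


# The repeated 1.0 in the second sequence is absorbed by warping
print(dtw_distance([0.0, 1.0, 2.0], [0.0, 1.0, 1.0, 2.0]))  # 0.0
```

The hard `min` makes this quantity non-smooth in its inputs, which is why trainable variants such as SoftDTW replace it with a smoothed minimum.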

Loss Functions Currently Implemented

  • Block-DTW
  • SoftDTW
  • OTW

Documentation

You can read the package documentation at the link.

For now, the documentation is generated automatically from the docstrings in the code via Sphinx autodoc. It will be improved in the future with more examples and explanations.

Important Notes about Repository Name

The original name of this repository was block-sdtw, because it was initially created only to develop the code for the Block-DTW loss function. As with many academic projects, the code was initially "not very well organized", so I decided to refactor it into a Python package for easier use. Since my implementation of Block-DTW relies on Mehran Maghoumi's implementation of SoftDTW, I also included SoftDTW in the package. While testing, I also had the opportunity to implement the OTW loss function.

At this point, with three different loss functions in the code base, the name block-sdtw seemed a bit misleading to me, so I decided to rename the package from block-sdtw to dtw-loss-functions.

Installation

pip installation

The easiest way to use this package is to install it via pip:

pip install dtw-loss-functions

Build from source

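Alternatively, you can clone the repository and build it locally with hatchling:

# Install the build backend used by the project
pip install hatchling
# Clone the repository and move into it
git clone https://github.com/jesus-333/dtw_loss_functions.git
cd dtw_loss_functions
# Build the distributions, then install the package from the local source tree
hatchling build && pip install .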

Usage Examples

Each loss function is implemented as a class inside the package.

Block-DTW Example

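import torch
from dtw_loss_functions import block_dtw

# Size of the blocks used by Block-DTW
block_size = 25

# Run on the GPU if one is available
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

# The example tensors have shape (batch_size, time_samples, channels)
batch_size = 5
time_samples = 300
channels = 1

# Two random batches to compare (e.g. an input and its reconstruction)
x   = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)

# Create the Block-DTW loss object, then apply it to the two batches
block_dtw_loss = block_dtw.block_dtw(block_size, use_cuda)

output_block_dtw = block_dtw_loss(x, x_r)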
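The meaning of `block_size` is not spelled out on this page. One plausible reading, which is purely an assumption and not taken from the package source, is that the time axis is split into consecutive, non-overlapping blocks of `block_size` samples that are then compared separately. Under that assumption, the sketch below shows how a 300-sample signal with `block_size = 25` would be partitioned; `split_into_blocks` is a hypothetical helper.

```python
from typing import List, Sequence


def split_into_blocks(signal: Sequence[float], block_size: int) -> List[List[float]]:
    """Split a 1-D signal into consecutive, non-overlapping blocks.

    ASSUMPTION: illustrates one possible meaning of `block_size`; this is
    not taken from the dtw_loss_functions source code.
    """
    if block_size <= 0:
        raise ValueError("block_size must be a positive integer")
    # The last block is shorter when len(signal) is not a multiple of block_size
    return [list(signal[start:start + block_size])
            for start in range(0, len(signal), block_size)]


# Same sizes as the Block-DTW example: 300 time samples, blocks of 25
blocks = split_into_blocks([float(t) for t in range(300)], 25)
print(len(blocks), len(blocks[0]))  # 12 25
```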

SoftDTW Example

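import torch
from dtw_loss_functions import soft_dtw

# Run on the GPU if one is available
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

# The example tensors have shape (batch_size, time_samples, channels)
batch_size = 5
time_samples = 300
channels = 1

# Two random batches to compare (e.g. an input and its reconstruction)
x   = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)

# Create the SoftDTW loss object, then apply it to the two batches
sdtw_loss = soft_dtw.soft_dtw(use_cuda = use_cuda)

output_sdtw = sdtw_loss(x, x_r)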
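SoftDTW replaces the hard minimum in the DTW recurrence with a smoothed minimum, softmin_γ(a, b, c) = -γ log(e^(-a/γ) + e^(-b/γ) + e^(-c/γ)), which makes the loss differentiable for γ > 0. The pure-Python sketch below evaluates only the forward value of that recurrence for two 1-D sequences with a squared-difference cost. It is not Mehran Maghoumi's implementation used by this package; `soft_min`, `soft_dtw_value`, and the default `gamma = 1.0` are assumptions made for illustration.

```python
import math
from typing import Iterable, Sequence


def soft_min(values: Iterable[float], gamma: float) -> float:
    """Smoothed minimum: -gamma * log(sum(exp(-v / gamma))), computed stably."""
    # exp(-inf / gamma) == 0, so infinite entries can simply be dropped
    finite = [v for v in values if v != math.inf]
    if not finite:
        return math.inf
    lowest = min(finite)
    # Factor out exp(-lowest / gamma) before summing to avoid underflow
    return lowest - gamma * math.log(sum(math.exp(-(v - lowest) / gamma) for v in finite))


def soft_dtw_value(x: Sequence[float], y: Sequence[float], gamma: float = 1.0) -> float:
    """Forward value of soft-DTW between two 1-D sequences (squared-difference cost)."""
    if gamma <= 0:
        raise ValueError("gamma must be positive")
    n, m = len(x), len(y)
    r = [[math.inf] * (m + 1) for _ in range(n + 1)]
    r[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = (x[i - 1] - y[j - 1]) ** 2
            # Same recurrence as hard DTW, with min replaced by soft_min
            r[i][j] = cost + soft_min((r[i - 1][j], r[i][j - 1], r[i - 1][j - 1]), gamma)
    return r[n][m]


# As gamma shrinks, soft-DTW approaches hard DTW (2.0 for this pair)
print(soft_dtw_value([0.0, 0.0], [1.0, 1.0], gamma=1e-3))
print(soft_dtw_value([0.0, 0.0], [1.0, 1.0], gamma=1.0))
```

Because the smoothed minimum never exceeds the hard one, the soft-DTW value is at most the hard DTW value and can even be negative for identical sequences.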

Citation

If you use this package, please refer to the citation file for all the information about which works to cite.


Download files

Source Distribution

dtw_loss_functions-1.0.4.tar.gz (8.4 MB)

Built Distribution

dtw_loss_functions-1.0.4-py2.py3-none-any.whl (35.0 kB; Python 2, Python 3)

File details

Details for the file dtw_loss_functions-1.0.4.tar.gz.

File metadata

  • Download URL: dtw_loss_functions-1.0.4.tar.gz
  • Size: 8.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for dtw_loss_functions-1.0.4.tar.gz
Algorithm    Hash digest
SHA256       6c0ef871cd05eb31b5a6fdf00dc69c75a317dfc2022e2b968169bc746275ea6a
MD5          63a2d07f7c637985adf3fd2403c2e888
BLAKE2b-256  283e98b29fa62b04d0aef1c3d780db8e4f9aaec8ede339d19d86797661a4c457
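To check that a downloaded file matches the published digest, you can hash it locally. The sketch below uses only Python's standard `hashlib` and `pathlib`; the helper names `sha256_of_file` and `matches_expected` are hypothetical, and the expected value is the SHA256 digest listed above for dtw_loss_functions-1.0.4.tar.gz.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for the 1.0.4 source distribution
EXPECTED_SDIST_SHA256 = "6c0ef871cd05eb31b5a6fdf00dc69c75a317dfc2022e2b968169bc746275ea6a"


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        # iter() with a sentinel keeps reading until f.read returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path: str, expected: str = EXPECTED_SDIST_SHA256) -> bool:
    """True if the file's SHA256 digest equals the expected one (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()
```

For example, `matches_expected("dtw_loss_functions-1.0.4.tar.gz")` should return `True` for an intact download. From the command line, `pip hash <file>` prints the SHA256 digest of a file in the same hexadecimal form.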

Provenance

The following attestation bundles were made for dtw_loss_functions-1.0.4.tar.gz:

Publisher: python-publish.yml on jesus-333/dtw_loss_functions

File details

Details for the file dtw_loss_functions-1.0.4-py2.py3-none-any.whl.

File hashes

Hashes for dtw_loss_functions-1.0.4-py2.py3-none-any.whl
Algorithm    Hash digest
SHA256       8fa70c301a45185b0a3eed4ee4bca7530b0031aaf292b9149d4b026b5cbe42c3
MD5          efe04e38513b77365af2a5d7c3ec3b4f
BLAKE2b-256  e5da2b888f59c33a04ee23b1a1895c5af5e415c30f8b00ac80e111a3a06f256f

Provenance

The following attestation bundles were made for dtw_loss_functions-1.0.4-py2.py3-none-any.whl:

Publisher: python-publish.yml on jesus-333/dtw_loss_functions
