PyTorch implementation of various DTW-inspired loss functions
DTW Inspired Loss Functions
A Python package that implements various DTW-inspired loss functions. Every implementation is compatible with PyTorch and can be used as a loss during training.
Loss Functions Currently Implemented
- SDTW. CUDA implementation of the SoftDTW algorithm. This package offers several implementations of this algorithm; currently, the following are included:
  - pytorch-softdtw-cuda by Mehran Maghoumi
  - pysdtw by Antoine Loriette
  - sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber)
- BlockDTW. An alternative version of SDTW that uses block-wise computation to improve performance. See the paper "BlockDTW: Efficient and Scalable Similarity Search Algorithm for Healthcare-Focused Time-Series" for more details.
- OTW. An implementation of the Optimal Transport Warping function. See the paper "OTW: Optimal Transport Warping for Time Series" for more details.
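All the SoftDTW implementations listed above compute the standard soft-DTW recursion (Cuturi and Blondel, 2017): R[i, j] = c[i, j] + softmin_γ(R[i-1, j-1], R[i-1, j], R[i, j-1]), where c[i, j] is the cost between time step i of one series and time step j of the other, and softmin_γ(a) = -γ log Σ_k exp(-a_k / γ). The sketch below is a naive, pure-PyTorch reference written only to illustrate that recursion; `soft_dtw_reference` is a hypothetical name, not part of the package. It assumes a squared Euclidean cost and inputs shaped (batch, time, channels) as in the usage examples, and it loops over every cell in Python, so it is only suitable for checking tiny inputs.

```python
import torch


def soft_dtw_reference(x: torch.Tensor, y: torch.Tensor, gamma: float = 0.1) -> torch.Tensor:
    """Naive soft-DTW between x of shape (B, N, D) and y of shape (B, M, D); returns shape (B,).

    Illustrative sketch only (hypothetical helper, not the package's implementation).
    """
    # Pairwise squared Euclidean costs between time steps, shape (B, N, M)
    cost = ((x.unsqueeze(2) - y.unsqueeze(1)) ** 2).sum(dim=-1)
    batch, n, m = cost.shape
    inf = torch.full((batch,), float("inf"), dtype=cost.dtype, device=cost.device)

    # R[i][j] holds the soft-DTW value of aligning x[:, :i] with y[:, :j];
    # the first row and column are +inf except for R[0][0] = 0
    R = [[inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = torch.zeros(batch, dtype=cost.dtype, device=cost.device)

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            prev = torch.stack([R[i - 1][j - 1], R[i - 1][j], R[i][j - 1]], dim=-1)
            # Soft-min of the three predecessors: -gamma * log(sum(exp(-prev / gamma)))
            soft_min = -gamma * torch.logsumexp(-prev / gamma, dim=-1)
            R[i][j] = cost[:, i - 1, j - 1] + soft_min
    return R[n][m]


torch.manual_seed(0)
x = torch.randn(2, 5, 1)
y = torch.randn(2, 6, 1, requires_grad=True)
loss = soft_dtw_reference(x, y, gamma=0.1).sum()
# Every operation is differentiable, so gradients flow back to the inputs
loss.backward()
print(loss.item(), y.grad.shape)
```

Because the recursion is built from differentiable tensor operations, `.backward()` yields gradients with respect to both series. For realistic lengths, such as the 300 samples used in the examples, the package's own implementations should be preferred.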
Documentation
You can read the package documentation at the link
Note that, for now, the documentation is generated automatically from the docstrings in the code via Sphinx autodoc. It will be improved in the future with more examples and explanations.
Installation
pip installation
The easiest way to use this package is to install it via pip:
pip install dtw-loss-functions
Build from source
Alternatively, you can clone the repository and build it locally with hatchling:
pip install hatchling
git clone https://github.com/jesus-333/dtw_loss_functions.git
cd dtw_loss_functions
hatchling build && pip install .
Usage Examples
Each loss function is implemented as a module inside the package.
Block-DTW Example
from dtw_loss_functions import block_dtw
import torch
block_size = 25
use_cuda = torch.cuda.is_available()
block_dtw_loss = block_dtw.block_dtw(block_size, sdtw_config = {'use_cuda' : use_cuda})
batch_size = 5
time_samples = 300
channels = 1
device = 'cuda' if use_cuda else 'cpu'
x = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)
output_block_dtw = block_dtw_loss(x, x_r)
SoftDTW Example
This package offers several implementations of the SoftDTW algorithm.
Mehran Maghoumi's (mag) implementation
import torch
from dtw_loss_functions import soft_dtw
use_cuda = torch.cuda.is_available()
sdtw_loss = soft_dtw.soft_dtw(implementation = 'mag', sdtw_config = {'use_cuda' : use_cuda, 'gamma' : 0.1})
batch_size = 5
time_samples = 300
channels = 1
device = 'cuda' if use_cuda else 'cpu'
x = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)
output_sdtw = sdtw_loss(x, x_r)
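In standard soft-DTW, the `gamma` entry passed in `sdtw_config` is the smoothing parameter of the soft-min used in the recursion; assuming the package follows that standard definition, the short sketch below (a standalone helper named `soft_min` for illustration, not the package's code) shows its effect on three candidate alignment costs. The smoothed value always lies at or below the hard minimum and converges to it as `gamma` shrinks.

```python
import torch


def soft_min(values: torch.Tensor, gamma: float) -> torch.Tensor:
    # Smoothed minimum over the last dimension: -gamma * log(sum(exp(-values / gamma)))
    return -gamma * torch.logsumexp(-values / gamma, dim=-1)


costs = torch.tensor([1.0, 1.5, 3.0])
for gamma in (1.0, 0.1, 0.01):
    # Larger gamma pulls the result further below min(costs); smaller gamma approaches it
    print(f"gamma={gamma}: soft-min={soft_min(costs, gamma).item():.4f}, "
          f"hard min={costs.min().item():.4f}")
```

The gradient of this soft-min with respect to its inputs is softmax(-values / gamma), so a small `gamma` concentrates the gradient on the cheapest alignment, while a larger `gamma` spreads it across several alignments.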
Ron Shapira Weber's (ron) implementation
import torch
from dtw_loss_functions import soft_dtw
sdtw_loss = soft_dtw.soft_dtw(implementation = 'ron', sdtw_config = {'gamma' : 0.1, 'dist' : 'sqeuclidean'})
batch_size = 5
time_samples = 300
channels = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(batch_size, time_samples, channels).to(device)
x_r = torch.randn(batch_size, time_samples, channels).to(device)
output_sdtw = sdtw_loss(x, x_r)
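Since every loss in the package is meant to be usable for training, the sketch below shows a generic PyTorch training step into which any of the loss objects from the examples could be plugged. To stay runnable without the package installed, it uses a per-sample MSE stand-in with the same `loss(x, x_r)` call pattern; `training_step` and `per_sample_mse` are hypothetical names. It also assumes, without confirmation from this page, that the package's losses return a differentiable tensor that is either a scalar or one value per batch element, which is why the result is reduced with `.mean()`.

```python
import torch
from torch import nn


def training_step(model, optimizer, loss_fn, x):
    """Run one optimisation step comparing x with its reconstruction model(x)."""
    optimizer.zero_grad()
    x_r = model(x)
    # .mean() turns a per-sample loss vector into a scalar and leaves a scalar unchanged
    loss = loss_fn(x, x_r).mean()
    loss.backward()
    optimizer.step()
    return loss.item()


def per_sample_mse(a, b):
    # Stand-in loss returning one value per batch element; swap in e.g. sdtw_loss
    return ((a - b) ** 2).mean(dim=(1, 2))


torch.manual_seed(0)
# Toy "reconstruction" model acting on the channel dimension of (batch, time, channels) inputs
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(5, 300, 1)

first = training_step(model, optimizer, per_sample_mse, x)
for _ in range(50):
    last = training_step(model, optimizer, per_sample_mse, x)
print(f"loss: {first:.4f} -> {last:.4f}")
```

When swapping in a CUDA-backed loss such as `sdtw_loss`, move the model and the inputs to the same `device` used in the examples.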
Citation
If you use this package, please refer to the citation page for information on which works to cite.
Benchmark
The benchmark page offers a comparison between the different loss functions included in this package and shows how to use the benchmark module.
Important Note About the Repository Name
The original name of this repository was block-sdtw because it was initially created to develop only the code for the Block-DTW loss function. As with many academic projects, the code was initially "not very well organized", so I decided to refactor it and turn it into a Python package for easier use.
Since my implementation of Block-DTW relies on Mehran Maghoumi's SoftDTW implementation, I decided to include his SoftDTW in the package as well.
While running some tests, I also had the opportunity to implement the OTW loss function.
At this point, with three different loss functions in the code base, the name block-sdtw seemed a bit misleading to me, so I decided to rename the package from block-sdtw to dtw-loss-functions.
File details
Details for the file dtw_loss_functions-1.0.5.tar.gz.
File metadata
- Download URL: dtw_loss_functions-1.0.5.tar.gz
- Upload date:
- Size: 8.5 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9843989e3f746619ef8b93649645bd293b227a7ad2c810326a1e1e591db3ca8f |
| MD5 | 1bc431a0f5039eaf1cd1d29af2df92f1 |
| BLAKE2b-256 | 2dc2dd66f3b06ee2b62e88bdfcf6345d3948e92187eae80bf58f31e09faaef7c |
Provenance
The following attestation bundles were made for dtw_loss_functions-1.0.5.tar.gz:
- Publisher: python-publish.yml on jesus-333/dtw_loss_functions
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: dtw_loss_functions-1.0.5.tar.gz
- Subject digest: 9843989e3f746619ef8b93649645bd293b227a7ad2c810326a1e1e591db3ca8f
- Sigstore transparency entry: 1096915685
- Permalink: jesus-333/dtw_loss_functions@38c77649d81e4c4993f27a1d9b3e3b668d6c4d18
- Branch / Tag: refs/heads/main
- Owner: https://github.com/jesus-333
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@38c77649d81e4c4993f27a1d9b3e3b668d6c4d18
- Trigger Event: workflow_dispatch
File details
Details for the file dtw_loss_functions-1.0.5-py2.py3-none-any.whl.
File metadata
- Download URL: dtw_loss_functions-1.0.5-py2.py3-none-any.whl
- Upload date:
- Size: 40.7 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 282f984d7f46bc843a93f664a171e0abac4f5c676206cd87c94e3fd6ded8d017 |
| MD5 | f34fd9636ee4af638866edb17404ee30 |
| BLAKE2b-256 | cf988373c8b37f30424ddeea249583ec07c23f28a8047045f525246d0888ee52 |
Provenance
The following attestation bundles were made for dtw_loss_functions-1.0.5-py2.py3-none-any.whl:
- Publisher: python-publish.yml on jesus-333/dtw_loss_functions
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: dtw_loss_functions-1.0.5-py2.py3-none-any.whl
- Subject digest: 282f984d7f46bc843a93f664a171e0abac4f5c676206cd87c94e3fd6ded8d017
- Sigstore transparency entry: 1096915689
- Permalink: jesus-333/dtw_loss_functions@38c77649d81e4c4993f27a1d9b3e3b668d6c4d18
- Branch / Tag: refs/heads/main
- Owner: https://github.com/jesus-333
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@38c77649d81e4c4993f27a1d9b3e3b668d6c4d18
- Trigger Event: workflow_dispatch