Torch implementation of Soft-DTW, with support for CUDA devices.

Project description

pysdtw

Torch implementation of the Soft-DTW algorithm, supporting both CPU and CUDA hardware.

Note: This repository started as a fork of this project.
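For intuition about what the package computes, here is a minimal pure-Python sketch of the Soft-DTW recursion (Cuturi and Blondel, 2017) for a single pair of sequences. Each cell adds its local cost to a smoothed minimum ("soft-min", with temperature gamma) of its three predecessors. This is an independent illustration, not pysdtw's code, which works on batched tensors; the function names are hypothetical, and the cost is assumed to be the squared Euclidean distance, as with `pysdtw.distance.pairwise_l2_squared`.

```python
import math


def softmin(values, gamma):
    """Smoothed minimum -gamma * log(sum(exp(-v / gamma))); gamma=0 gives min()."""
    m = min(values)
    if gamma == 0:
        return m
    # Shift by the minimum so the exponentials cannot overflow (inf entries give 0).
    s = sum(math.exp(-(v - m) / gamma) for v in values)
    return m - gamma * math.log(s)


def l2_squared_cost(x, y):
    """Hypothetical helper: cost matrix C[i][j] = ||x_i - y_j||^2."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]


def soft_dtw(cost, gamma):
    """Soft-DTW value of an n x m cost matrix; gamma=0 reduces to classic DTW."""
    n, m = len(cost), len(cost[0])
    inf = float("inf")
    # R[i][j] holds the soft-DTW value of the prefixes x[:i] and y[:j];
    # row 0 and column 0 form an infinite border, except R[0][0] = 0.
    R = [[inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i][j] = cost[i - 1][j - 1] + softmin(
                (R[i - 1][j - 1], R[i - 1][j], R[i][j - 1]), gamma)
    return R[n][m]


x = [[0.0], [1.0], [2.0]]
y = [[0.0], [2.0]]
C = l2_squared_cost(x, y)
for gamma in (1.0, 0.1, 0.01, 0.0):
    print(gamma, soft_dtw(C, gamma))
```

Because the soft-min never exceeds the hard minimum, the soft-DTW value is never larger than the DTW value, and it converges to DTW as gamma shrinks to 0.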

Installation

This package is available on PyPI and depends on PyTorch and Numba.

Install with:

pip install pysdtw
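To confirm that the package and the dependencies listed above can be imported in the current environment, a small standard-library sketch follows; the helper name `missing_modules` is hypothetical.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that the import system cannot find."""
    # find_spec locates a top-level module without executing it; None means absent.
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(["pysdtw", "torch", "numba"])
if missing:
    print("not importable:", ", ".join(missing))
else:
    print("pysdtw, torch and numba are all importable")
```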

Usage

import torch
import pysdtw

# inputs have shape (batch, sequence length, features); the two sequence lengths may differ
X = torch.rand((10, 5, 7), requires_grad=True)
Y = torch.rand((10, 9, 7))

# optionally choose a pairwise distance function
fun = pysdtw.distance.pairwise_l2_squared

# create the SoftDTW distance function
sdtw = pysdtw.SoftDTW(gamma=1.0, dist_func=fun, use_cuda=False)

# soft-DTW discrepancy, approaches DTW as gamma -> 0
res = sdtw(X, Y)

# define a scalar loss whose gradient can be backpropagated
loss = res.sum()
loss.backward()

# X.grad now contains the gradient of the loss with respect to X
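As a hedged illustration of what `backward()` propagates into `X.grad`, here is a pure-Python sketch of the Soft-DTW backward recursion (Cuturi and Blondel, 2017) for one pair of sequences with squared Euclidean cost. It first computes the gradient of the value with respect to the cost matrix (the "expected alignment" matrix), then applies the chain rule to obtain the gradient with respect to x. It is an independent reference rather than pysdtw's implementation, and all names are hypothetical. Assuming each entry of `res` depends only on the matching pair `X[b]`, `Y[b]`, the slice `X.grad[b]` after `res.sum().backward()` would correspond to this per-pair gradient.

```python
import math


def softmin(values, gamma):
    # Numerically stable -gamma * log(sum(exp(-v / gamma))); requires gamma > 0.
    m = min(values)
    return m - gamma * math.log(sum(math.exp(-(v - m) / gamma) for v in values))


def l2_squared_cost(x, y):
    # C[i][j] = ||x_i - y_j||^2
    return [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]


def soft_dtw_with_grad(cost, gamma):
    """Return (soft-DTW value, d value / d cost) for one n x m cost matrix."""
    if gamma <= 0:
        raise ValueError("this sketch needs gamma > 0")
    n, m = len(cost), len(cost[0])
    inf = float("inf")
    R = [[inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i][j] = cost[i - 1][j - 1] + softmin(
                (R[i - 1][j - 1], R[i - 1][j], R[i][j - 1]), gamma)

    # Backward pass: E[i][j] = d R[n][m] / d R[i][j]. Cell (i, j) feeds the soft-min
    # of the cells below, right and diagonal; each passes back a softmax-weighted share.
    E = [[0.0] * (m + 1) for _ in range(n + 1)]
    E[n][m] = 1.0
    for i in range(n, 0, -1):
        for j in range(m, 0, -1):
            if (i, j) == (n, m):
                continue
            for ci, cj in ((i + 1, j), (i, j + 1), (i + 1, j + 1)):
                if ci <= n and cj <= m:
                    # softmax weight of R[i][j] inside the soft-min of cell (ci, cj)
                    w = math.exp((R[ci][cj] - cost[ci - 1][cj - 1] - R[i][j]) / gamma)
                    E[i][j] += E[ci][cj] * w
    # cost[i][j] enters only R[i+1][j+1], additively, so its gradient is E[i+1][j+1].
    grad = [row[1:] for row in E[1:]]
    return R[n][m], grad


def grad_wrt_x(x, y, grad_cost):
    # Chain rule through C[i][j] = sum_d (x[i][d] - y[j][d])^2
    return [[sum(grad_cost[i][j] * 2.0 * (x[i][d] - y[j][d]) for j in range(len(y)))
             for d in range(len(x[i]))]
            for i in range(len(x))]


x = [[0.0], [1.0], [2.0]]
y = [[0.0], [2.0]]
value, G = soft_dtw_with_grad(l2_squared_cost(x, y), gamma=1.0)
print("soft-DTW:", value)
print("expected alignment (d value / d cost):", G)
print("gradient w.r.t. x:", grad_wrt_x(x, y, G))
```

Every alignment path starts at the first cell and ends at the last, so both corners of the expected alignment matrix equal 1. Comparing the result with central finite differences is a cheap way to check a hand-written backward pass like this one.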


Download files

Source Distribution

pysdtw-0.0.5.tar.gz (6.1 kB)

Built Distribution

pysdtw-0.0.5-py3-none-any.whl (7.1 kB)

File details

Details for the file pysdtw-0.0.5.tar.gz.

File metadata

  • Download URL: pysdtw-0.0.5.tar.gz
  • Upload date:
  • Size: 6.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for pysdtw-0.0.5.tar.gz
Algorithm    Hash digest
SHA256       a1e9a25c250f0da3a886dc7d779e9f3ab8b8cbbdba607427a9fef146b25ae4e8
MD5          731dee534fd5930b196f6c90140bc52c
BLAKE2b-256  517ce201ffa0144916bd2c26a7e5a2d2aeb6f1ecadd89502bdded9d6deb830c1

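The digests listed for each file let you check a download before installing it. A minimal standard-library sketch, assuming the source distribution was saved under its original file name; the helper names `sha256_of` and `matches` are hypothetical.

```python
import hashlib

# SHA256 digest listed for pysdtw-0.0.5.tar.gz
EXPECTED_SHA256 = "a1e9a25c250f0da3a886dc7d779e9f3ab8b8cbbdba607427a9fef146b25ae4e8"


def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks so memory use stays bounded."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 equals the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, `matches("pysdtw-0.0.5.tar.gz")` returns True only for an intact copy of the tarball. pip can also enforce digests itself through `--hash=sha256:<digest>` entries in a requirements file; note that this hash-checking mode then requires every requirement, dependencies included, to be pinned with hashes.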

File details

Details for the file pysdtw-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: pysdtw-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 7.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for pysdtw-0.0.5-py3-none-any.whl
Algorithm    Hash digest
SHA256       1a46d46fb00408a300a2dbe006ea957d3efef20502cc98d24c5b22e3ee867653
MD5          4f1ef667eb9e3b2894ff93316d0c236b
BLAKE2b-256  ac0062c234127146cc4695fe2e28fb827c59b4f2d417deb09b9a0c21eb99781f
