Seqdist

Probability distributions over sequences in pytorch and cupy.

Install

pip install seqdist

How to use

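Compare against PyTorch's built-in implementation of the standard CTC loss:

from seqdist import ctc  # assumes the ctc module is importable from the seqdist package

# Parameter meanings inferred from their names: input lengths between T_min and T_max,
# batch size N, C classes, target lengths between L_min and L_max
sample_inputs = logits, targets, input_lengths, target_lengths = ctc.generate_sample_inputs(T_min=450, T_max=500, N=128, C=20, L_min=80, L_max=100)
print('pytorch loss: {:.4f}'.format(ctc.loss_pytorch(*sample_inputs)))
print('seqdist loss: {:.4f}'.format(ctc.loss_cupy(*sample_inputs)))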
pytorch loss: 12.8080
seqdist loss: 12.8080
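
For readers unfamiliar with what both functions compute, the standard CTC loss is the negative log of the total probability of all alignments that collapse to the target. The following is a minimal pure-Python sketch of the textbook forward recursion, checked against brute-force enumeration on a tiny input. It is not seqdist's or PyTorch's implementation; the function names, the blank index of 0, and the absence of any batch reduction are assumptions made for illustration.

```python
import math
from itertools import product

NEG_INF = float("-inf")

def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) for a handful of values."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_neg_log_likelihood(log_probs, target, blank=0):
    """log_probs: T rows of C normalised log-probabilities; target: labels without blanks."""
    # Extended target with blanks interleaved: [blank, y1, blank, y2, ..., blank]
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S = len(ext)
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]
    for t in range(1, len(log_probs)):
        new = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]                      # stay in the same state
            if s >= 1:
                terms.append(alpha[s - 1])          # advance by one state
            # skipping over a blank is only allowed between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new[s] = logsumexp(*terms) + log_probs[t][ext[s]]
        alpha = new
    # valid alignments end on the last label or on the trailing blank
    total = logsumexp(alpha[-1], alpha[-2]) if S > 1 else alpha[-1]
    return -total

def collapse(path, blank=0):
    """Merge repeated symbols, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out

def brute_force(log_probs, target, blank=0):
    """Same quantity by enumerating every path; only feasible for tiny inputs."""
    C = len(log_probs[0])
    scores = [sum(log_probs[t][c] for t, c in enumerate(path))
              for path in product(range(C), repeat=len(log_probs))
              if collapse(path, blank) == list(target)]
    return -logsumexp(*scores)

# Tiny made-up example: T=4 time steps, C=3 classes (class 0 is the blank)
logits = [[0.1, 1.2, -0.3], [0.5, -0.2, 0.9], [1.0, 0.0, 0.3], [-0.4, 0.8, 0.2]]
log_probs = [[x - logsumexp(*row) for x in row] for row in logits]
loss = ctc_neg_log_likelihood(log_probs, [1, 2])
check = brute_force(log_probs, [1, 2])
print('forward: {:.4f}  brute force: {:.4f}'.format(loss, check))
```

The recursion costs O(T·S) per sequence instead of enumerating C^T paths, which is why it is the form GPU implementations parallelise.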

Speed comparison

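PyTorch:

# report and benchmark_fwd_bwd are timing helpers that this README uses without showing their import
report(benchmark_fwd_bwd(ctc.loss_pytorch, *sample_inputs))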
fwd: 4.79ms (4.17-5.33ms)
bwd: 9.69ms (8.33-10.88ms)
tot: 14.47ms (12.67-16.20ms)

Seqdist:

report(benchmark_fwd_bwd(ctc.loss_cupy, *sample_inputs))
fwd: 7.22ms (6.78-7.85ms)
bwd: 6.21ms (5.82-8.57ms)
tot: 13.43ms (12.63-16.41ms)
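
The README does not show how the timings above are collected or which statistics are printed. As a rough illustration of the general approach, here is a minimal sketch of a timing helper that produces lines in the same "value (low-high)" shape. The names time_calls and summarise are hypothetical, and reporting the mean with a min-max range is an assumption, not a description of seqdist's report.

```python
import statistics
import time

def time_calls(fn, *args, nloops=20, warmup=2):
    """Return per-call wall-clock times in milliseconds."""
    for _ in range(warmup):
        fn(*args)  # warm-up runs are discarded (caches, lazy initialisation)
    times = []
    for _ in range(nloops):
        start = time.perf_counter()
        fn(*args)
        # NOTE: for GPU work the device must be synchronised before reading
        # the clock, otherwise only the kernel launch is timed.
        times.append((time.perf_counter() - start) * 1e3)
    return times

def summarise(label, times_ms):
    """Format timings as 'label: mean (min-max)' in milliseconds."""
    return '{}: {:.2f}ms ({:.2f}-{:.2f}ms)'.format(
        label, statistics.mean(times_ms), min(times_ms), max(times_ms))

# Usage on a CPU-only stand-in workload
times = time_calls(sorted, list(range(10000)))
print(summarise('sort', times))
```

Timing the forward and backward passes separately, as the output above does, shows where two implementations differ even when their totals are close.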

Alignments

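import matplotlib.pyplot as plt
import numpy as np

# to_np is a helper used here without its import shown; it presumably converts a tensor to a NumPy array
betas = [0.1, 1.0, 10.]
alignments = {'beta={:.1f}'.format(beta): to_np(ctc.soft_alignments(*sample_inputs, beta=beta)) for beta in betas}
alignments['viterbi'] = to_np(ctc.viterbi_alignments(*sample_inputs))
# One panel per setting: three soft alignments plus the Viterbi alignment
fig, axs = plt.subplots(2, 2, figsize=(15, 8))
for ax, (title, data) in zip(np.array(axs).flatten(), alignments.items()):
    ax.imshow(data[:, 0].T, vmax=0.05)  # index 0 along the second axis, presumably the first batch item
    ax.set_title(title)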

[Figure: 2×2 grid of alignment heatmaps with panels titled beta=0.1, beta=1.0, beta=10.0 and viterbi]
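
The README does not define beta. One common reading, assumed here rather than taken from seqdist's source, is an inverse temperature on path scores: beta=1 gives the usual CTC posterior over alignments, and large beta concentrates the mass on the single best (Viterbi) path. The sketch below illustrates that reading by brute-force enumeration on a tiny made-up input; all function names are hypothetical.

```python
import math
from itertools import product

def collapse(path, blank=0):
    """Merge repeated symbols, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out

def valid_paths(log_probs, target, blank=0):
    """All (path, score) pairs whose collapsed labels equal the target."""
    C = len(log_probs[0])
    return [(path, sum(log_probs[t][c] for t, c in enumerate(path)))
            for path in product(range(C), repeat=len(log_probs))
            if collapse(path, blank) == list(target)]

def soft_alignment(log_probs, target, beta=1.0, blank=0):
    """occ[t][c]: probability of emitting c at time t when each valid path
    is weighted by exp(beta * score). Assumed meaning of beta, see lead-in."""
    paths = valid_paths(log_probs, target, blank)
    best = max(score for _, score in paths)
    weights = [math.exp(beta * (score - best)) for _, score in paths]
    Z = sum(weights)
    T, C = len(log_probs), len(log_probs[0])
    occ = [[0.0] * C for _ in range(T)]
    for (path, _), w in zip(paths, weights):
        for t, c in enumerate(path):
            occ[t][c] += w / Z
    return occ

def viterbi_alignment(log_probs, target, blank=0):
    """One-hot occupancy of the single highest-scoring valid path."""
    path, _ = max(valid_paths(log_probs, target, blank), key=lambda ps: ps[1])
    C = len(log_probs[0])
    return [[1.0 if c == label else 0.0 for c in range(C)] for label in path]

# Tiny made-up example: T=4 time steps, C=3 classes (class 0 is the blank)
logits = [[0.1, 1.2, -0.3], [0.5, -0.2, 0.9], [1.0, 0.0, 0.3], [-0.4, 0.8, 0.2]]
log_probs = [[x - math.log(sum(math.exp(y) for y in row)) for x in row] for row in logits]

for beta in (0.1, 1.0, 10.0):
    occ = soft_alignment(log_probs, [1, 2], beta=beta)
    print('beta={:.1f}'.format(beta), [[round(p, 2) for p in row] for row in occ])
print('viterbi ', viterbi_alignment(log_probs, [1, 2]))
```

Under this reading, small beta spreads probability across many alignments while large beta sharpens the picture toward the Viterbi panel.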
