
Probability distributions over sequences in pytorch and cupy

Project description

Seqdist


Install

pip install seqdist

How to use

Comparison against the built-in PyTorch implementation of the standard CTC loss:

# Sample CTC batch; judging by the argument names: input lengths T in [450, 500],
# batch size N=128, C=20 output classes, target lengths L in [80, 100]
sample_inputs = logits, targets, input_lengths, target_lengths = ctc.generate_sample_inputs(
    T_min=450, T_max=500, N=128, C=20, L_min=80, L_max=100)
print('pytorch loss: {:.4f}'.format(ctc.loss_pytorch(*sample_inputs)))
print('seqdist loss: {:.4f}'.format(ctc.loss_cupy(*sample_inputs)))
pytorch loss: 12.8080
seqdist loss: 12.8080
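For reference, the quantity both implementations compute is the CTC negative log-likelihood, obtained with the forward (alpha) recursion over the blank-interleaved target. The following is a minimal pure-Python sketch for a single sequence, assuming blank index 0 and per-frame log-probabilities as input. The function names are hypothetical and this is not seqdist's CuPy kernel. A brute-force sum over all alignments is included as a check. How `ctc.loss_pytorch` and `ctc.loss_cupy` reduce over the batch is not stated on this page.

```python
import math
import random
from itertools import product

NEG_INF = float("-inf")

def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) over scalar arguments."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_nll(log_probs, target, blank=0):
    """Hypothetical helper: CTC negative log-likelihood of one sequence.

    log_probs: T rows, each a list of C per-frame log-probabilities.
    target: label indices without blanks.
    """
    # Extended sequence: blank, y1, blank, y2, ..., yL, blank (length 2L+1).
    ext = [blank]
    for y in target:
        ext += [y, blank]
    S = len(ext)
    # alpha[s]: log-probability of all partial paths ending in extended state s.
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]
    for t in range(1, len(log_probs)):
        new = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]                  # stay in the same state
            if s >= 1:
                terms.append(alpha[s - 1])      # advance one state
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])      # skip the blank between distinct labels
            new[s] = logsumexp(*terms) + log_probs[t][ext[s]]
        alpha = new
    # Valid paths end in the last label or the trailing blank.
    end = alpha[-1] if S == 1 else logsumexp(alpha[-1], alpha[-2])
    return -end

def collapse(path, blank=0):
    """Merge repeated symbols, then drop blanks."""
    out, prev = [], None
    for c in path:
        if c != prev and c != blank:
            out.append(c)
        prev = c
    return out

def brute_force_nll(log_probs, target, blank=0):
    """Sum over every length-T path that collapses to target (tiny inputs only)."""
    total = NEG_INF
    for path in product(range(len(log_probs[0])), repeat=len(log_probs)):
        if collapse(path, blank) == list(target):
            total = logsumexp(total, sum(lp[c] for lp, c in zip(log_probs, path)))
    return -total

# Usage: a random 5-frame, 3-class example (class 0 is the blank).
random.seed(0)
rows = []
for _ in range(5):
    w = [random.random() + 0.1 for _ in range(3)]
    rows.append([math.log(v / sum(w)) for v in w])
print(ctc_nll(rows, [1, 2]), brute_force_nll(rows, [1, 2]))
```

The recursion is linear in T times the extended length 2L+1, whereas the brute-force sum grows as C^T, which is why practical implementations (PyTorch's or seqdist's) use the dynamic-programming form.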

Speed comparison

PyTorch:

report(benchmark_fwd_bwd(ctc.loss_pytorch, *sample_inputs))
fwd: 4.79ms (4.17-5.33ms)
bwd: 9.69ms (8.33-10.88ms)
tot: 14.47ms (12.67-16.20ms)

Seqdist:

report(benchmark_fwd_bwd(ctc.loss_cupy, *sample_inputs))
fwd: 7.22ms (6.78-7.85ms)
bwd: 6.21ms (5.82-8.57ms)
tot: 13.43ms (12.63-16.41ms)
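The page does not show how `benchmark_fwd_bwd` and `report` work or which statistic is printed. A hypothetical stand-in that times the forward and backward passes separately on each run, and reports median (min-max) in milliseconds, might look like the sketch below. For GPU code, pass a `sync` function such as `torch.cuda.synchronize`: CUDA kernels launch asynchronously, so without synchronizing the clock would stop before the work finishes.

```python
import statistics
import time

def time_fwd_bwd(fwd, bwd, n_warmup=2, n_runs=10, sync=lambda: None):
    """Hypothetical helper: time fwd() and bwd(fwd_output) per run, in ms."""
    for _ in range(n_warmup):
        bwd(fwd())
    times = {"fwd": [], "bwd": [], "tot": []}
    for _ in range(n_runs):
        sync()  # make sure no earlier work is still pending
        t0 = time.perf_counter()
        out = fwd()
        sync()
        t1 = time.perf_counter()
        bwd(out)
        sync()
        t2 = time.perf_counter()
        times["fwd"].append((t1 - t0) * 1e3)
        times["bwd"].append((t2 - t1) * 1e3)
        times["tot"].append((t2 - t0) * 1e3)  # total is measured per run
    return times

def summarize(name, times):
    """Format as 'name: median (min-max)' in milliseconds."""
    return "{}: {:.2f}ms ({:.2f}-{:.2f}ms)".format(
        name, statistics.median(times), min(times), max(times))

# CPU-only toy usage: the "forward" builds a list, the "backward" sums it.
results = time_fwd_bwd(lambda: [i * i for i in range(100_000)], sum)
for name, ts in results.items():
    print(summarize(name, ts))
```

With PyTorch, `fwd` could be `lambda: loss_fn(*inputs)` and `bwd` could be `lambda loss: loss.backward()`, where `loss_fn` and `inputs` are placeholders. Because totals are summed per run, the min and max of `tot` need not equal the sums of the `fwd` and `bwd` extremes.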

Alignments

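import numpy as np
import matplotlib.pyplot as plt

betas = [0.1, 1.0, 10.]
# Soft alignments at three values of beta plus the hard Viterbi alignment, all for the same sample inputs
alignments = {'beta={:.1f}'.format(beta): to_np(ctc.soft_alignments(*sample_inputs, beta=beta)) for beta in betas}
alignments['viterbi'] = to_np(ctc.viterbi_alignments(*sample_inputs))
fig, axs = plt.subplots(2, 2, figsize=(15, 8))
for (ax, (title, data)) in zip(np.array(axs).flatten(), alignments.items()):
    # Plot index 0 along axis 1 (presumably the first sequence in the batch), transposed;
    # vmax=0.05 saturates the color scale so small values remain visible
    ax.imshow(data[:, 0].T, vmax=0.05)
    ax.set_title(title)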

[Figure: alignment heatmaps in a 2 x 2 grid, panels titled beta=0.1, beta=1.0, beta=10.0 and viterbi]
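The page does not define how `beta` enters `ctc.soft_alignments`. The sketch below assumes one common reading: `beta` is an inverse temperature multiplying the per-frame log-scores, and a soft alignment is the posterior probability of each extended CTC state (the blank-interleaved target) at each frame, computed by forward-backward. Under that assumption, as `beta` grows the posterior concentrates on the single best path, which the max-product recursion in `viterbi_alignment` recovers. Both function names are hypothetical; this is pure Python for one sequence with blank index 0, not seqdist's implementation.

```python
import math

NEG_INF = float("-inf")

def _lse(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def _extend(target, blank):
    """Blank-interleaved target: blank, y1, blank, ..., yL, blank."""
    ext = [blank]
    for y in target:
        ext += [y, blank]
    return ext

def _preds(ext, s, blank):
    """States that may precede state s in the next frame (including s itself)."""
    p = [s]
    if s >= 1:
        p.append(s - 1)
    if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
        p.append(s - 2)
    return p

def soft_alignment(log_probs, target, beta=1.0, blank=0):
    """T x S posterior state occupancy with log-scores scaled by beta.

    Assumes the target is reachable in len(log_probs) frames.
    """
    ext = _extend(target, blank)
    T, S = len(log_probs), len(ext)
    emit = [[beta * log_probs[t][c] for c in ext] for t in range(T)]
    # Forward: alpha[t][s] = log-sum of scores of partial paths in state s at frame t.
    alpha = [[NEG_INF] * S for _ in range(T)]
    alpha[0][0] = emit[0][0]
    if S > 1:
        alpha[0][1] = emit[0][1]
    for t in range(1, T):
        for s in range(S):
            alpha[t][s] = _lse([alpha[t - 1][p] for p in _preds(ext, s, blank)]) + emit[t][s]
    # Backward: back[t][s] = log-sum of scores of completions after frame t from state s.
    back = [[NEG_INF] * S for _ in range(T)]
    back[T - 1][S - 1] = 0.0
    if S > 1:
        back[T - 1][S - 2] = 0.0
    for t in range(T - 2, -1, -1):
        for s in range(S):
            succ = [n for n in (s, s + 1, s + 2) if n < S and s in _preds(ext, n, blank)]
            back[t][s] = _lse([emit[t + 1][n] + back[t + 1][n] for n in succ])
    log_z = _lse([alpha[T - 1][s] + back[T - 1][s] for s in range(S)])
    return [[math.exp(alpha[t][s] + back[t][s] - log_z) for s in range(S)] for t in range(T)]

def viterbi_alignment(log_probs, target, blank=0):
    """Most probable extended-state path, one state index per frame."""
    ext = _extend(target, blank)
    T, S = len(log_probs), len(ext)
    score = [NEG_INF] * S
    score[0] = log_probs[0][ext[0]]
    if S > 1:
        score[1] = log_probs[0][ext[1]]
    bp = []  # bp[t - 1][s]: best predecessor of state s at frame t
    for t in range(1, T):
        new, ptr = [NEG_INF] * S, [0] * S
        for s in range(S):
            p = max(_preds(ext, s, blank), key=lambda q: score[q])
            new[s], ptr[s] = score[p] + log_probs[t][ext[s]], p
        score = new
        bp.append(ptr)
    # End in the final label or the trailing blank, then trace back.
    s = S - 1 if S == 1 or score[S - 1] >= score[S - 2] else S - 2
    path = [s]
    for ptr in reversed(bp):
        s = ptr[s]
        path.append(s)
    return path[::-1]

# Usage: frames strongly favor labels 1, 1, blank, 2, 2 (probability 0.8 each).
rows = [[math.log(0.8) if c == p else math.log(0.1) for c in range(3)] for p in [1, 1, 0, 2, 2]]
print(viterbi_alignment(rows, [1, 2]))
print([max(range(len(r)), key=r.__getitem__) for r in soft_alignment(rows, [1, 2], beta=20.0)])
```

Each row of the soft alignment sums to 1 because every path occupies exactly one extended state per frame; small `beta` spreads that mass across many states, while large `beta` approaches the one-hot Viterbi path.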

Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

ont_seqdist_cuda101-0.0.4-py3-none-any.whl (21.7 kB)

Uploaded Python 3

File details

Details for the file ont_seqdist_cuda101-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: ont_seqdist_cuda101-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 21.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.9

File hashes

Hashes for ont_seqdist_cuda101-0.0.4-py3-none-any.whl
Algorithm    Hash digest
SHA256       16b266daeacb2bdf7238b753c98a3f62616dd2eafe3527046dbcfff3b029a94c
MD5          7391e9c86035f05692c555f605d7c6b7
BLAKE2b-256  d206e69b7c9a649e324a3e34a501f8ba1910067b52e2dcdd6e0a6f04c50ab3f9

