Seqdist

Probability distributions over sequences in PyTorch and CuPy.

Install

pip install seqdist

How to use

Comparison against the built-in PyTorch implementation of the standard CTC loss:

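from seqdist import ctc  # assumed import path; the original snippet does not show it

# Random batch: N sequences, C classes, T_min..T_max input frames, L_min..L_max target labels
sample_inputs = logits, targets, input_lengths, target_lengths = ctc.generate_sample_inputs(T_min=450, T_max=500, N=128, C=20, L_min=80, L_max=100)
print('pytorch loss: {:.4f}'.format(ctc.loss_pytorch(*sample_inputs)))
print('seqdist loss: {:.4f}'.format(ctc.loss_cupy(*sample_inputs)))

Output:

pytorch loss: 12.8080
seqdist loss: 12.8080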
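
For readers who want to see what the "standard CTC loss" sums over, here is a minimal pure-Python sketch of the CTC forward recursion for a single sequence. It is not seqdist's or PyTorch's implementation. The names `ctc_neg_log_likelihood` and `logsumexp` are hypothetical, the blank label is assumed to be index 0, and no batching or reduction across sequences is applied, so its values are not directly comparable to the numbers printed above.

```python
import math

NEG_INF = float("-inf")

def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) over the arguments."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_neg_log_likelihood(log_probs, target, blank=0):
    """Reference CTC loss for one sequence (hypothetical helper).

    log_probs: list of T rows, each a list of C log-probabilities.
    target: list of label indices without blanks.
    """
    # Interleave blanks: [a, b] -> [_, a, _, b, _]
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S = len(ext)

    # alpha[s] = log-probability of all partial alignments ending in state s
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, len(log_probs)):
        new = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]
            if s >= 1:
                terms.append(alpha[s - 1])
            # Skipping over a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new[s] = logsumexp(*terms) + log_probs[t][ext[s]]
        alpha = new

    # Valid alignments end in the last label or in the trailing blank
    total = logsumexp(alpha[-1], alpha[-2]) if S > 1 else alpha[-1]
    return -total

# Two frames, two classes (blank and label 1), uniform probabilities.
# Three alignments collapse to [1]: "_1", "1_", "11", so p = 3 * 0.25.
uniform = [[math.log(0.5), math.log(0.5)]] * 2
print(ctc_neg_log_likelihood(uniform, [1]))  # equals -log(0.75)
```

The recursion costs O(T * S) per sequence, which is why batched GPU implementations such as the two compared above matter at the sizes used in the sample inputs.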

Speed comparison

PyTorch:

report(benchmark_fwd_bwd(ctc.loss_pytorch, *sample_inputs))

fwd: 4.79ms (4.17-5.33ms)
bwd: 9.69ms (8.33-10.88ms)
tot: 14.47ms (12.67-16.20ms)

Seqdist:

report(benchmark_fwd_bwd(ctc.loss_cupy, *sample_inputs))

fwd: 7.22ms (6.78-7.85ms)
bwd: 6.21ms (5.82-8.57ms)
tot: 13.43ms (12.63-16.41ms)
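
The definitions of `benchmark_fwd_bwd` and `report` are not shown here. As a rough illustration of how timings in this "value (min-max)" style can be produced, the sketch below uses two hypothetical helpers, `time_calls` and `format_timing`. It assumes the central value is a mean, which the text above does not state, and it measures wall-clock time on the CPU only. Timing GPU code would also require synchronizing the device (for example with `torch.cuda.synchronize()`) before reading the clock.

```python
import time

def time_calls(fn, *args, warmup=5, nloops=20):
    """Return wall-clock durations in milliseconds for nloops calls of fn."""
    # Warm-up calls are run but not timed (caches, lazy initialization, etc.)
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(nloops):
        start = time.perf_counter()
        fn(*args)
        times.append((time.perf_counter() - start) * 1000.0)
    return times

def format_timing(label, times_ms):
    """Format timings as 'label: mean (min-max)' in milliseconds."""
    mean = sum(times_ms) / len(times_ms)
    return '{}: {:.2f}ms ({:.2f}-{:.2f}ms)'.format(
        label, mean, min(times_ms), max(times_ms))

# Example: time a small pure-Python workload
print(format_timing('sum', time_calls(sum, range(100000))))
```

Reporting the range alongside the central value, as the figures above do, makes it easier to see when two implementations overlap within run-to-run noise.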

Alignments

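import numpy as np
import matplotlib.pyplot as plt
# `ctc` and `to_np` come from seqdist; their imports are not shown in the original.

betas = [0.1, 1.0, 10.]
alignments = {'beta={:.1f}'.format(beta): to_np(ctc.soft_alignments(*sample_inputs, beta=beta)) for beta in betas}
alignments['viterbi'] = to_np(ctc.viterbi_alignments(*sample_inputs))
fig, axs = plt.subplots(2, 2, figsize=(15, 8))
# One panel per setting; data[:, 0] is presumably the first sequence in the batch.
for ax, (title, data) in zip(np.array(axs).flatten(), alignments.items()):
    ax.imshow(data[:, 0].T, vmax=0.05)
    ax.set_title(title)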

[Figure: 2x2 grid of alignment plots with panels titled beta=0.1, beta=1.0, beta=10.0 and viterbi]
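
To make the "viterbi" panel more concrete, the following is a minimal pure-Python sketch of a best-path CTC alignment for one sequence. It is an illustration, not the code behind `ctc.viterbi_alignments`. The function name `ctc_viterbi_path` is hypothetical, the blank label is assumed to be index 0, and the output is a plain list of per-frame labels, not the array format plotted above. The meaning of `beta` in `soft_alignments` is not defined in this text and is not covered by the sketch.

```python
NEG_INF = float("-inf")

def ctc_viterbi_path(log_probs, target, blank=0):
    """Most likely CTC alignment of target to the T frames in log_probs.

    log_probs: list of T rows, each a list of C log-probabilities.
    Returns a list of T label indices (blanks included).
    """
    # Interleave blanks: [a, b] -> [_, a, _, b, _]
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S, T = len(ext), len(log_probs)

    score = [NEG_INF] * S
    score[0] = log_probs[0][ext[0]]
    if S > 1:
        score[1] = log_probs[0][ext[1]]
    back = []  # back[t - 1][s] = best predecessor state of s at frame t

    for t in range(1, T):
        new, ptr = [NEG_INF] * S, [0] * S
        for s in range(S):
            candidates = [s]
            if s >= 1:
                candidates.append(s - 1)
            # Skipping over a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(s - 2)
            best = max(candidates, key=lambda k: score[k])
            ptr[s] = best
            new[s] = score[best] + log_probs[t][ext[s]]
        score = new
        back.append(ptr)

    # Start from the better of the two valid final states and trace back
    s = S - 1
    if S > 1 and score[S - 2] > score[S - 1]:
        s = S - 2
    path = [ext[s]]
    for ptr in reversed(back):
        s = ptr[s]
        path.append(ext[s])
    return path[::-1]

# Three frames favouring blank, label 1, blank in turn
frames = [[-0.1, -2.3], [-2.3, -0.1], [-0.1, -2.3]]
print(ctc_viterbi_path(frames, [1]))  # [0, 1, 0]
```

The recursion has the same shape as the CTC forward pass, with the sum over predecessors replaced by a maximum and back-pointers kept so the single best alignment can be recovered.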
