Probability distributions over sequences in PyTorch and CuPy

Seqdist

Install

pip install seqdist

How to use

Compare seqdist's CTC loss against PyTorch's built-in implementation of the standard CTC loss:

# Assumed import: the original snippet takes `ctc` as already in scope.
from seqdist import ctc

sample_inputs = logits, targets, input_lengths, target_lengths = ctc.generate_sample_inputs(
    T_min=450, T_max=500, N=128, C=20, L_min=80, L_max=100
)
print('pytorch loss: {:.4f}'.format(ctc.loss_pytorch(*sample_inputs)))
print('seqdist loss: {:.4f}'.format(ctc.loss_cupy(*sample_inputs)))
pytorch loss: 12.8080
seqdist loss: 12.8080
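
For readers who want to see what quantity both implementations are computing, here is a minimal pure-Python sketch of the textbook CTC forward recursion for a single sequence. It is not seqdist's or PyTorch's code: the function names `ctc_neg_log_likelihood` and `brute_force_nll`, the toy probabilities, and the choice of blank index 0 are all assumptions made for illustration, and no batching or length normalisation is applied.

```python
import math
from itertools import product

NEG_INF = float("-inf")

def logsumexp(*xs):
    # Numerically stable log(sum(exp(x))) that tolerates -inf entries.
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_neg_log_likelihood(log_probs, target, blank=0):
    """log_probs: T x C per-step log-probabilities; target: labels without blanks."""
    # Interleave blanks around the labels: blank, y1, blank, y2, ..., blank.
    ext = [blank]
    for y in target:
        ext += [y, blank]
    S = len(ext)
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][blank]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]
    for t in range(1, len(log_probs)):
        new = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]                      # stay in the same state
            if s >= 1:
                terms.append(alpha[s - 1])          # advance by one state
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])          # skip the blank between distinct labels
            new[s] = logsumexp(*terms) + log_probs[t][ext[s]]
        alpha = new
    total = logsumexp(alpha[-1], alpha[-2]) if S > 1 else alpha[-1]
    return -total

def collapse(path, blank=0):
    # Merge repeated symbols, then drop blanks.
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out

def brute_force_nll(log_probs, target, blank=0):
    # Sum the probability of every path that collapses to the target (tiny inputs only).
    total = 0.0
    for path in product(range(len(log_probs[0])), repeat=len(log_probs)):
        if collapse(path, blank) == list(target):
            total += math.exp(sum(lp[c] for lp, c in zip(log_probs, path)))
    return -math.log(total)

# Toy example: T=4 time steps, C=3 classes (index 0 is the blank).
probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7], [0.5, 0.25, 0.25]]
log_probs = [[math.log(p) for p in row] for row in probs]
nll = ctc_neg_log_likelihood(log_probs, [1, 2])
print("forward: {:.6f}  brute force: {:.6f}".format(nll, brute_force_nll(log_probs, [1, 2])))
```

The brute-force sum is exponential in the sequence length and is only there to check the recursion on a toy input.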

Speed comparison

PyTorch:

report(benchmark_fwd_bwd(ctc.loss_pytorch, *sample_inputs))
fwd: 4.79ms (4.17-5.33ms)
bwd: 9.69ms (8.33-10.88ms)
tot: 14.47ms (12.67-16.20ms)

Seqdist:

report(benchmark_fwd_bwd(ctc.loss_cupy, *sample_inputs))
fwd: 7.22ms (6.78-7.85ms)
bwd: 6.21ms (5.82-8.57ms)
tot: 13.43ms (12.63-16.41ms)
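
The timings above come from the package's own `benchmark_fwd_bwd` and `report` helpers, whose internals are not shown here. As a rough illustration of how numbers in this shape can be produced, the following sketch times a callable repeatedly and prints a central value with a range. The names `benchmark` and `format_report` are hypothetical, and reading the reported figures as a median with a min-max range is an assumption, not something the page states.

```python
import statistics
import time

def benchmark(fn, *args, warmup=3, repeats=20):
    """Time fn(*args) and return summary statistics in milliseconds."""
    # Warm-up runs keep one-off costs (caching, lazy initialisation) out of the timings.
    for _ in range(warmup):
        fn(*args)
    times_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        # Caveat: GPU kernels launch asynchronously, so a real GPU benchmark
        # must synchronise the device before reading the clock.
        times_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "median": statistics.median(times_ms),
        "min": min(times_ms),
        "max": max(times_ms),
    }

def format_report(name, stats):
    # Mirrors the layout of the lines above: "fwd: 4.79ms (4.17-5.33ms)".
    return "{}: {:.2f}ms ({:.2f}-{:.2f}ms)".format(
        name, stats["median"], stats["min"], stats["max"]
    )

stats = benchmark(sorted, list(range(10000, 0, -1)))
print(format_report("sort", stats))
```

With totals of 14.47ms and 13.43ms and overlapping ranges, the two implementations are close on this input, so a harness like this is best run several times before drawing conclusions.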

Alignments

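import numpy as np
import matplotlib.pyplot as plt

# `ctc`, `to_np` and `sample_inputs` are assumed to be in scope from the earlier snippets.
betas = [0.1, 1.0, 10.]
alignments = {'beta={:.1f}'.format(beta): to_np(ctc.soft_alignments(*sample_inputs, beta=beta)) for beta in betas}
alignments['viterbi'] = to_np(ctc.viterbi_alignments(*sample_inputs))
# One panel per alignment type, showing the first sequence in the batch.
fig, axs = plt.subplots(2, 2, figsize=(15, 8))
for (ax, (title, data)) in zip(np.array(axs).flatten(), alignments.items()):
    ax.imshow(data[:, 0].T, vmax=0.05)
    ax.set_title(title)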

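[Figure: 2x2 grid of alignment plots for the first sequence in the batch, with panels titled beta=0.1, beta=1.0, beta=10.0 and viterbi]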
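
To make the "viterbi" panel concrete, here is a minimal pure-Python sketch of the textbook CTC Viterbi alignment for one sequence: the single most probable frame-by-frame path that collapses to the target. This is not how `ctc.viterbi_alignments` is implemented; the function name `ctc_viterbi_alignment`, the toy probabilities, and the blank index 0 are assumptions for illustration.

```python
import math

NEG_INF = float("-inf")

def ctc_viterbi_alignment(log_probs, target, blank=0):
    """Return the most probable per-frame label path that collapses to target."""
    # Interleave blanks around the labels: blank, y1, blank, y2, ..., blank.
    ext = [blank]
    for y in target:
        ext += [y, blank]
    S, T = len(ext), len(log_probs)
    score = [NEG_INF] * S
    score[0] = log_probs[0][blank]
    if S > 1:
        score[1] = log_probs[0][ext[1]]
    back = []  # back[t-1][s] = best predecessor of state s at time t
    for t in range(1, T):
        new, ptr = [NEG_INF] * S, [0] * S
        for s in range(S):
            cands = [s]                              # stay
            if s >= 1:
                cands.append(s - 1)                  # advance by one state
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(s - 2)                  # skip a blank between distinct labels
            best = max(cands, key=lambda k: score[k])
            ptr[s] = best
            new[s] = score[best] + log_probs[t][ext[s]]
        score = new
        back.append(ptr)
    # A valid path ends in the final blank or the final label.
    s = S - 1
    if S > 1 and score[S - 2] > score[S - 1]:
        s = S - 2
    states = [s]
    for ptr in reversed(back):
        s = ptr[s]
        states.append(s)
    states.reverse()
    return [ext[s] for s in states]

# Toy example: T=4 time steps, C=3 classes (index 0 is the blank).
probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7], [0.5, 0.25, 0.25]]
log_probs = [[math.log(p) for p in row] for row in probs]
path = ctc_viterbi_alignment(log_probs, [1, 2])
print(path)  # [0, 1, 2, 0]
```

The recursion is the forward recursion with the sum over predecessors replaced by a max, plus back-pointers to recover the path.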


Built Distribution

ont_seqdist_cuda113-0.0.4-py3-none-any.whl (21.7 kB)
