
Seqdist

Probability distributions over sequences in PyTorch and CuPy.

Install

pip install seqdist

How to use

Comparison against the built-in PyTorch implementation of the standard CTC loss:

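from seqdist import ctc  # assumed import path for seqdist's CTC module

# N=128 sequences, input lengths in [T_min, T_max], C=20 output classes, target lengths in [L_min, L_max]
sample_inputs = logits, targets, input_lengths, target_lengths = ctc.generate_sample_inputs(
    T_min=450, T_max=500, N=128, C=20, L_min=80, L_max=100)
print('pytorch loss: {:.4f}'.format(ctc.loss_pytorch(*sample_inputs)))
print('seqdist loss: {:.4f}'.format(ctc.loss_cupy(*sample_inputs)))
pytorch loss: 12.8080
seqdist loss: 12.8080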
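Both losses above are built on the per-sequence CTC negative log-likelihood, which the forward (alpha) recursion computes over the target extended with blanks. The following is a minimal pure-Python sketch for a single sequence, not seqdist's implementation: ctc_nll and ctc_nll_brute_force are hypothetical names, the blank is assumed to be index 0, and the inputs are assumed to be per-frame log-probabilities (for example after a log-softmax). It applies no batch reduction; torch.nn.functional.ctc_loss with its default reduction='mean', for instance, divides each loss by its target length and then averages over the batch. The brute-force version enumerates every path and is only feasible for tiny inputs, but it makes the definition explicit.

```python
import itertools
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))); returns -inf if every value is -inf."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def ctc_nll(log_probs, target, blank=0):
    """CTC negative log-likelihood of one target via the forward (alpha) recursion.

    log_probs: T rows of C per-frame log-probabilities; target: label indices != blank.
    """
    # Extended state sequence: blank, l1, blank, l2, ..., lL, blank (2L + 1 states).
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S = len(ext)
    # A path may start in the leading blank or in the first label.
    alpha = [-math.inf] * S
    alpha[:min(2, S)] = [log_probs[0][c] for c in ext[:2]]
    for row in log_probs[1:]:
        new = []
        for s in range(S):
            terms = [alpha[s]]  # stay in the same state
            if s >= 1:
                terms.append(alpha[s - 1])  # advance one state
            # Skip the blank in between, allowed only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new.append(logsumexp(terms) + row[ext[s]])
        alpha = new
    # A valid path ends in the final label or the trailing blank.
    return -logsumexp(alpha[max(0, S - 2):])


def collapse(path, blank=0):
    """CTC collapse: merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for c in path:
        if c != prev and c != blank:
            out.append(c)
        prev = c
    return out


def ctc_nll_brute_force(log_probs, target, blank=0):
    """Same quantity by enumerating all C**T paths (tiny inputs only)."""
    T, C = len(log_probs), len(log_probs[0])
    scores = [
        sum(log_probs[t][c] for t, c in enumerate(path))
        for path in itertools.product(range(C), repeat=T)
        if collapse(path, blank) == list(target)
    ]
    return -logsumexp(scores)


# Usage: T=4 frames, C=3 symbols (0 = blank), target [1, 2].
probs = [[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.3, 0.2, 0.5], [0.6, 0.1, 0.3]]
log_probs = [[math.log(p) for p in row] for row in probs]
print(ctc_nll(log_probs, [1, 2]), ctc_nll_brute_force(log_probs, [1, 2]))
```

The skip rule is what makes repeated labels need a blank between them: for the target [1, 1], the recursion never jumps straight from the first 1 to the second, so every valid path contains at least one blank frame between them.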

Speed comparison

PyTorch:

report(benchmark_fwd_bwd(ctc.loss_pytorch, *sample_inputs))
fwd: 4.79ms (4.17-5.33ms)
bwd: 9.69ms (8.33-10.88ms)
tot: 14.47ms (12.67-16.20ms)

Seqdist:

report(benchmark_fwd_bwd(ctc.loss_cupy, *sample_inputs))
fwd: 7.22ms (6.78-7.85ms)
bwd: 6.21ms (5.82-8.57ms)
tot: 13.43ms (12.63-16.41ms)
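The reports above give forward, backward and total times. The sketch below shows one way such numbers could be collected; it is not seqdist's benchmark_fwd_bwd or report, and time_fwd_bwd and summarize are hypothetical helpers. It assumes each figure is a mean followed by a (min-max) range over repeated runs, and that the total is the per-iteration sum of forward and backward time. For CUDA code, fwd and bwd would also need to call torch.cuda.synchronize() before returning, because kernels launch asynchronously and the host clock would otherwise stop before the GPU work finishes.

```python
import statistics
import time


def time_fwd_bwd(fwd, bwd, n_warmup=3, n_repeats=20):
    """Time fwd() and then bwd(fwd_output) in every iteration; durations in ms."""
    for _ in range(n_warmup):
        bwd(fwd())  # untimed warm-up runs (caches, lazy compilation, allocators)
    times = {"fwd": [], "bwd": [], "tot": []}
    for _ in range(n_repeats):
        t0 = time.perf_counter()
        out = fwd()
        t1 = time.perf_counter()
        bwd(out)
        t2 = time.perf_counter()
        times["fwd"].append((t1 - t0) * 1e3)
        times["bwd"].append((t2 - t1) * 1e3)
        times["tot"].append((t2 - t0) * 1e3)
    return times


def summarize(label, times_ms):
    """Format as 'label: mean (min-max)' in milliseconds."""
    return "{}: {:.2f}ms ({:.2f}-{:.2f}ms)".format(
        label, statistics.mean(times_ms), min(times_ms), max(times_ms))


# Usage with toy stand-ins: "forward" builds a list, "backward" reduces it.
timings = time_fwd_bwd(lambda: [i * i for i in range(10_000)], sum)
for label in ("fwd", "bwd", "tot"):
    print(summarize(label, timings[label]))
```

With PyTorch one might pass fwd=lambda: ctc.loss_cupy(*sample_inputs) and bwd=lambda loss: loss.backward(), assuming the logits require gradients. Timing both phases inside the same iteration keeps each backward paired with the forward that built its graph.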

Alignments

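import matplotlib.pyplot as plt
import numpy as np
from seqdist.utils import to_np  # assumed location of the tensor-to-NumPy helper

betas = [0.1, 1.0, 10.]
alignments = {'beta={:.1f}'.format(beta): to_np(ctc.soft_alignments(*sample_inputs, beta=beta)) for beta in betas}
alignments['viterbi'] = to_np(ctc.viterbi_alignments(*sample_inputs))
fig, axs = plt.subplots(2, 2, figsize=(15, 8))
for ax, (title, data) in zip(np.array(axs).flatten(), alignments.items()):
    # Index 0 along the batch axis (assuming time-major output); transpose puts time on the x-axis
    ax.imshow(data[:, 0].T, vmax=0.05)
    ax.set_title(title)
plt.show()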

[Figure: alignment panels titled beta=0.1, beta=1.0, beta=10.0 and viterbi]
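The plotted panels compare soft alignments at three values of beta with the hard Viterbi alignment. The sketch below illustrates both ideas for a single sequence in pure Python; it is not seqdist's soft_alignments or viterbi_alignments, and all helper names are hypothetical. It assumes beta acts as an inverse temperature that multiplies the per-frame log-probabilities before a forward-backward pass, which may differ from seqdist's exact definition. soft_alignment returns, for each frame, the posterior probability of each state of the blank-extended target; viterbi_alignment returns the single best state path by max-product dynamic programming with backpointers. As beta grows the posterior concentrates on the Viterbi path (when that path is unique), and as beta shrinks toward 0 it spreads over all valid alignments.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))); returns -inf if every value is -inf."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def extend(target, blank=0):
    """CTC state sequence: blank, l1, blank, l2, ..., lL, blank."""
    ext = [blank]
    for label in target:
        ext += [label, blank]
    return ext


def predecessors(ext, s, blank=0):
    """States allowed immediately before state s: stay, step, or skip a blank."""
    prev = [s]
    if s >= 1:
        prev.append(s - 1)
    if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
        prev.append(s - 2)
    return prev


def soft_alignment(log_probs, target, beta=1.0, blank=0):
    """T x S posterior of occupying each CTC state, with log-probs scaled by beta."""
    ext = extend(target, blank)
    S, T = len(ext), len(log_probs)
    e = [[beta * log_probs[t][ext[s]] for s in range(S)] for t in range(T)]
    ninf = -math.inf
    # Forward pass: alpha[t][s] scores all prefixes ending in state s at frame t.
    alpha = [[ninf] * S for _ in range(T)]
    alpha[0][:min(2, S)] = e[0][:min(2, S)]
    for t in range(1, T):
        for s in range(S):
            alpha[t][s] = logsumexp([alpha[t - 1][p] for p in predecessors(ext, s, blank)]) + e[t][s]
    # Backward pass: bwd[t][s] scores all suffixes after frame t, starting from state s.
    bwd = [[ninf] * S for _ in range(T)]
    bwd[T - 1][max(0, S - 2):] = [0.0] * min(2, S)
    for t in range(T - 2, -1, -1):
        for s in range(S):
            succ = [n for n in range(s, min(s + 3, S)) if s in predecessors(ext, n, blank)]
            bwd[t][s] = logsumexp([bwd[t + 1][n] + e[t + 1][n] for n in succ])
    log_z = logsumexp(alpha[T - 1][max(0, S - 2):])
    return [[math.exp(alpha[t][s] + bwd[t][s] - log_z) for s in range(S)] for t in range(T)]


def viterbi_alignment(log_probs, target, blank=0):
    """Best-scoring CTC state path (one state index per frame), via max-product."""
    ext = extend(target, blank)
    S, T = len(ext), len(log_probs)
    score = [-math.inf] * S
    score[:min(2, S)] = [log_probs[0][c] for c in ext[:2]]
    back = []  # back[t - 1][s]: best predecessor of state s at frame t
    for t in range(1, T):
        ptr = [max(predecessors(ext, s, blank), key=lambda p: score[p]) for s in range(S)]
        score = [score[ptr[s]] + log_probs[t][ext[s]] for s in range(S)]
        back.append(ptr)
    # A valid path ends on the final label or the final blank.
    s = max(range(max(0, S - 2), S), key=lambda i: score[i])
    path = [s]
    for ptr in reversed(back):
        s = ptr[s]
        path.append(s)
    return path[::-1]


# Usage: T=5 frames, C=3 symbols (0 = blank), target [1, 2].
probs = [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.8, 0.1, 0.1], [0.8, 0.1, 0.1]]
log_probs = [[math.log(p) for p in row] for row in probs]
hard = viterbi_alignment(log_probs, [1, 2])
sharp = soft_alignment(log_probs, [1, 2], beta=20.0)
print(hard, [max(range(len(row)), key=row.__getitem__) for row in sharp])
```

Each row of the soft alignment sums to 1, so a heat map of it shows where each frame's probability mass sits among the target states.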



Download files


Source Distributions

No source distribution files available for this release.

Built Distribution

ont_seqdist_cuda102-0.0.4-py3-none-any.whl (21.7 kB, Python 3)

File details

Details for the file ont_seqdist_cuda102-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: ont_seqdist_cuda102-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 21.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.9

File hashes

Hashes for ont_seqdist_cuda102-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 9797d772e09a08bffc2664de356bd47025c169eb43fced5f88539358219180a1
MD5 717df57bcf36724c848872ce9b8617db
BLAKE2b-256 c5047d7592652dd69926a0d6d1b04a40f63989b609664861206a161392138a6e

