
TensorFlow implementations of seq2seq machine learning model loss functions


tf-seq2seq-losses

A TensorFlow implementation of the Connectionist Temporal Classification (CTC) loss.

Installation

Tested with Python 3.7.

$ pip install tf-seq2seq-losses

Why

1. Faster

The official CTC loss implementation, tf.nn.ctc_loss, is dramatically slow. The proposed implementation is approximately 15 times faster, as follows from this benchmark:

Name              | forward time (s) | gradient calculation time (s)
------------------|------------------|------------------------------
tf.nn.ctc_loss    | 13.2 ± 0.02      | 10.4 ± 3
classic_ctc_loss  | 0.138 ± 0.006    | 0.28 ± 0.01
simple_ctc_loss   | 0.0531 ± 0.003   | 0.119 ± 0.004

(Tested on a single GPU: GeForce GTX 970, driver version 460.91.03, CUDA version 11.2.) See benchmark.py for the experimental setup. To run this benchmark, use

$ pytest -o log_cli=true --log-level=INFO tests/benchmark.py

from the project root directory. Here classic_ctc_loss is the standard version of the CTC loss, corresponding to decoding with repeated-token collapse, like

a_bb_ccc_c -> abcc

(equivalent to tensorflow.nn.ctc_loss). The loss function simple_ctc_loss is a simplified version corresponding to a straightforward decoding rule, like

a_bb_ccc_c -> abbcccc

(simple blank removal). Both decoding rules are illustrated by the sketch below.
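
For intuition, here are hypothetical helper functions showing the two rules in plain Python (these helpers are not part of the package, just an illustration):

def classic_decode(tokens, blank="_"):
    # Collapse repeated tokens, then drop blanks: a_bb_ccc_c -> abcc
    output = []
    previous = None
    for token in tokens:
        if token != previous and token != blank:
            output.append(token)
        previous = token
    return "".join(output)

def simple_decode(tokens, blank="_"):
    # Drop blanks only: a_bb_ccc_c -> abbcccc
    return "".join(token for token in tokens if token != blank)

assert classic_decode("a_bb_ccc_c") == "abcc"
assert simple_decode("a_bb_ccc_c") == "abbcccc"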

2. Numerically stable

  1. The proposed implementation is more numerically stable. For example, it calculates reasonable output for logits of order 1e+10 and even for -tf.inf (see the sketch after this list).
  2. If the logit length is too short to predict the label output, the probability of the expected prediction is zero. The loss output is therefore -tf.inf for this sample, rather than a finite value such as the 702. returned by tf.nn.ctc_loss.
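
A minimal sketch of the first claim, assuming the API shown in the Usage section below (the exact output value depends on the random logits and is not guaranteed):

import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss

# Logits of order 1e+10: per the claim above, classic_ctc_loss is expected
# to return a meaningful loss instead of overflowing to nan or inf.
loss = classic_ctc_loss(
    labels=tf.constant([[1, 2]], dtype=tf.int32),
    logits=1e10 * tf.random.normal(shape=[1, 4, 3]),
    label_length=tf.constant([2], dtype=tf.int32),
    logit_length=tf.constant([4], dtype=tf.int32),
    blank_index=0,
)
print(loss)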

3. No C++ compilation

This is a pure Python/TensorFlow implementation; no C++/CUDA code has to be built or compiled.

Usage

The interface is identical to tensorflow.nn.ctc_loss with logits_time_major=False.

import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss

batch_size = 1
num_tokens = 3  # = 2 tokens + blank
logit_length = 5
loss = classic_ctc_loss(
    labels=tf.constant([[1, 2, 2, 1]], dtype=tf.int32),
    logits=tf.zeros(shape=[batch_size, logit_length, num_tokens], dtype=tf.float32),
    label_length=tf.constant([4], dtype=tf.int32),
    logit_length=tf.constant([logit_length], dtype=tf.int32),
    blank_index=0,
)
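
The loss is differentiable with respect to the logits (the benchmark above times exactly this gradient computation), so it can be dropped into a training step. A minimal sketch:

import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss

logits = tf.zeros(shape=[1, 5, 3], dtype=tf.float32)
with tf.GradientTape() as tape:
    tape.watch(logits)  # watch explicitly, since logits is not a tf.Variable
    loss = classic_ctc_loss(
        labels=tf.constant([[1, 2, 2, 1]], dtype=tf.int32),
        logits=logits,
        label_length=tf.constant([4], dtype=tf.int32),
        logit_length=tf.constant([5], dtype=tf.int32),
        blank_index=0,
    )
grad = tape.gradient(loss, logits)  # same shape as logits: [1, 5, 3]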

Under the hood

The implementation is based on TensorFlow operations such as tf.while_loop and tf.TensorArray. The bottleneck is the iteration over the logit length required to calculate α and β (see, for example, the original paper). The expected GPU gradient calculation time is linear in the logit length.
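
For reference, here is a NumPy sketch of the textbook α (forward-variable) recursion in the log domain; the package's actual implementation runs this time loop with tf.while_loop and tf.TensorArray and differs in details:

import numpy as np

def ctc_log_alpha(log_probs, labels, blank=0):
    # log_probs: [T, num_tokens] per-frame log-probabilities.
    # labels: list of label token indices (no blanks).
    ext = [blank]
    for token in labels:
        ext += [token, blank]  # extended label: blank-separated, e.g. _a_b_
    T, S = log_probs.shape[0], len(ext)
    alpha = np.full((T, S), -np.inf)
    alpha[0, 0] = log_probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = log_probs[0, ext[1]]
    for t in range(1, T):  # this O(T) loop over the logit length is the bottleneck
        for s in range(S):
            terms = [alpha[t - 1, s]]
            if s > 0:
                terms.append(alpha[t - 1, s - 1])
            # skipping a blank is allowed only between two different tokens
            if s > 1 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[t - 1, s - 2])
            alpha[t, s] = np.logaddexp.reduce(terms) + log_probs[t, ext[s]]
    return alpha  # log P(labels) = logaddexp(alpha[-1, -1], alpha[-1, -2])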

Known problems:

  1. Warning:

AutoGraph could not transform <function classic_ctc_loss at ...> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output.

This warning, observed for TensorFlow version 2.4.1, has no effect on performance. It is caused by Union in the type annotations.

Future plans

  1. Add decoding (inference).
  2. Add RNN-T (RNN transducer) loss.
  3. Add mWER (minimum word error rate) loss.

