Project description

tf-seq2seq-losses

Tensorflow implementations for Connectionist Temporal Classification (CTC) loss functions that are fast and support second-order derivatives.

Installation

$ pip install tf-seq2seq-losses

Why Use This Package?

1. Faster Performance

The official CTC loss implementation, tf.nn.ctc_loss, is significantly slower. The implementations in this package are roughly 30 to 100 times faster, as shown by the benchmark results:

Name                Forward Time (ms)    Gradient Calculation Time (ms)
tf.nn.ctc_loss      13.2 ± 0.02          10.4 ± 3
classic_ctc_loss    0.138 ± 0.006        0.28 ± 0.01
simple_ctc_loss     0.0531 ± 0.003       0.119 ± 0.004

Tested on a single GPU: GeForce GTX 970, Driver Version: 460.91.03, CUDA Version: 11.2. For the experimental setup, see benchmark.py. To reproduce this benchmark, run the following command from the project root directory (install pytest and pandas if needed):

$ pytest -o log_cli=true --log-level=INFO tests/benchmark.py

Here, classic_ctc_loss is the standard version of CTC loss with token collapsing, e.g., a_bb_ccc_c -> abcc. The simple_ctc_loss is a simplified version that removes blanks trivially, e.g., a_bb_ccc_c -> abbcccc.
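
For illustration, the two reduction rules can be sketched as plain Python helpers (the blank is written as "_"; these functions are only illustrative and are not part of the package API):

def reduce_classic(alignment):  # rule behind classic_ctc_loss: merge repeats, then drop blanks
    output, previous = [], None
    for token in alignment:
        if token != "_" and token != previous:
            output.append(token)
        previous = token
    return "".join(output)

def reduce_simple(alignment):  # rule behind simple_ctc_loss: drop blanks only
    return "".join(token for token in alignment if token != "_")

assert reduce_classic("a_bb_ccc_c") == "abcc"
assert reduce_simple("a_bb_ccc_c") == "abbcccc"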

2. Supports Second-Order Derivatives

This implementation supports second-order derivatives without relying on TensorFlow's automatic differentiation. Instead, it uses a custom approach similar to the one described here, with a complexity of $O(l^4)$, where $l$ is the sequence length; the gradient has complexity $O(l^2)$.

Example usage:

import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss

batch_size = 2
num_tokens = 3
max_logit_length = 5
# Padded target labels and their true lengths:
labels = tf.constant([[1, 2, 2, 1], [1, 2, 1, 0]], dtype=tf.int32)
label_length = tf.constant([4, 3], dtype=tf.int32)
# Model outputs (logits) and the valid length of each output sequence:
logits = tf.zeros(shape=[batch_size, max_logit_length, num_tokens], dtype=tf.float32)
logit_length = tf.constant([5, 4], dtype=tf.int32)

with tf.GradientTape(persistent=True) as tape1:
    tape1.watch([logits])
    with tf.GradientTape() as tape2:
        tape2.watch([logits])
        loss = tf.reduce_sum(classic_ctc_loss(
            labels=labels,
            logits=logits,
            label_length=label_length,
            logit_length=logit_length,
            blank_index=0,
        ))
    gradient = tape2.gradient(loss, sources=logits)
hessian = tape1.batch_jacobian(gradient, source=logits, experimental_use_pfor=False)
# hessian shape: [batch_size, logit_length, num_tokens, logit_length, num_tokens] = [2, 5, 3, 5, 3]

3. Numerical Stability

  1. The proposed implementation is more numerically stable, producing reasonable outputs even for logits on the order of 1e+10 and for -tf.inf entries (see the sketch below).
  2. If the logit length is too short to produce the label, the loss is tf.inf for that sample, whereas tf.nn.ctc_loss can output a finite value such as 707.13184.
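
As an illustration of the first point, a call of the following form is expected to stay well-behaved (the logit values here are made up for illustration, and the exact loss value is not shown):

import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss

# Extreme logits: huge magnitudes plus -inf entries in the first frame.
logits = tf.constant(
    [[[0.0, -float("inf"), -float("inf")],
      [1e10, 0.0, -1e10],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0]]],
    dtype=tf.float32,
)
loss = classic_ctc_loss(
    labels=tf.constant([[1, 2]], dtype=tf.int32),
    logits=logits,
    label_length=tf.constant([2], dtype=tf.int32),
    logit_length=tf.constant([5], dtype=tf.int32),
    blank_index=0,
)  # expected to be finite rather than NaN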

4. Pure Python Implementation

This is a pure Python/TensorFlow implementation, eliminating the need to build or compile any C++/CUDA components.

Usage

The interface is identical to tf.nn.ctc_loss with logits_time_major=False.

Example:

import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss

batch_size = 1
num_tokens = 3 # = 2 tokens + 1 blank token
logit_length = 5
loss = classic_ctc_loss(
    labels=tf.constant([[1, 2, 2, 1]], dtype=tf.int32),
    logits=tf.zeros(shape=[batch_size, logit_length, num_tokens], dtype=tf.float32),
    label_length=tf.constant([4], dtype=tf.int32),
    logit_length=tf.constant([logit_length], dtype=tf.int32),
    blank_index=0,
)

Under the Hood

The implementation uses TensorFlow operations such as tf.while_loop and tf.TensorArray. The main computational bottleneck is the iteration over the logit length to calculate α and β (as described in the original CTC paper). The expected GPU time for the gradient calculation therefore grows linearly with the logit length.
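
For intuition, here is a minimal, self-contained sketch of such a forward (α) recursion built from tf.while_loop and tf.TensorArray. It is not the package's actual code: it covers only the simplified (blank-removal) variant, for a single unbatched sample, without length masking and without the β pass, and the function and constant names are made up for this sketch.

import tensorflow as tf

LOG_0 = tf.constant(-1e30, dtype=tf.float32)  # finite stand-in for log(0)

def simple_ctc_forward(log_probs, labels, blank_index=0):
    """Alpha recursion for the simplified CTC variant (blanks removed, repeats kept).

    log_probs: [time, num_tokens] log-softmax of the logits for one sample.
    labels:    [label_length] int32 label ids.
    Returns (log_prob_of_labels, alphas), alphas having shape [time, label_length + 1].
    """
    time = tf.shape(log_probs)[0]
    label_length = tf.shape(labels)[0]

    # alpha[s] = log-probability that exactly the first s labels have been emitted so far
    alpha = tf.concat([[0.0], tf.fill([label_length], LOG_0)], axis=0)
    alphas = tf.TensorArray(tf.float32, size=time)  # stored for a later beta/gradient pass

    def step(t, alpha, alphas):
        stay = alpha + log_probs[t, blank_index]             # emit a blank, stay at position s
        emit = alpha[:-1] + tf.gather(log_probs[t], labels)  # emit label s, advance to s + 1
        alpha = tf.reduce_logsumexp(
            tf.stack([stay, tf.concat([[LOG_0], emit], axis=0)]), axis=0)
        return t + 1, alpha, alphas.write(t, alpha)

    _, alpha, alphas = tf.while_loop(lambda t, *_: t < time, step, [0, alpha, alphas])
    return alpha[label_length], alphas.stack()

The real classic_ctc_loss additionally handles, among other things, batching, length masking, the repeat-collapsing rule, and the β recursion used for the analytic gradient and Hessian.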

Known Issues

1. Warning:

AutoGraph could not transform <function classic_ctc_loss at ...> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output.

Observed with TensorFlow version 2.4.1. This warning does not affect performance and is caused by the use of Union in type annotations.

2. UnimplementedError:

Using GradientTape.jacobian or GradientTape.batch_jacobian for the second derivative of classic_ctc_loss with experimental_use_pfor=False may cause an unexpected UnimplementedError in TensorFlow version 2.4.1 or later. This can be avoided either by setting experimental_use_pfor=True or by using ClassicCtcLossData.hessian directly, without tf.GradientTape (see the snippet below).
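
For example, the second-derivative example above can keep the same tape setup and simply enable pfor:

# Same tapes, labels, and logits as in the Hessian example above:
hessian = tape1.batch_jacobian(gradient, source=logits, experimental_use_pfor=True)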

Feel free to reach out if you have any questions or need further clarification.

