CTC with TensorFlow, simplified

Project description

TensorFlow CTC

Tools to simplify CTC in TensorFlow

Mocked Logits

'Perfect' logits that match any given labels: a large negative value ($-10^{9}$ by default) everywhere except $0$ at the appropriate index. Blank characters (index 0 by default) are interspersed between labels so that equal, consecutive labels are not collapsed.
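The construction can be sketched in plain Python for a single label sequence. This is a minimal illustration, not the tf_ctc implementation: the function name onehot_logits_sketch and its num_frames/num_classes parameters are hypothetical (the library call in the example below takes only the labels), and the sketch assumes blank index 0 and two frames per label (each label followed by one blank, then blank padding).

```python
# Hypothetical sketch, not the tf_ctc implementation.
def onehot_logits_sketch(labels, num_frames, num_classes, blank=0, neg=-1e9):
    """Return [num_frames][num_classes] logits whose argmax path spells `labels`."""
    if 2 * len(labels) > num_frames:
        raise ValueError("this sketch needs at least two frames per label")
    # Frame targets: label, blank, label, blank, ..., then blank padding.
    path = []
    for lab in labels:
        path += [lab, blank]
    path += [blank] * (num_frames - len(path))
    # 0 at each frame's target class, a large negative value everywhere else.
    return [[0.0 if c == target else neg for c in range(num_classes)] for target in path]

logits = onehot_logits_sketch([1, 2, 3], num_frames=10, num_classes=5)
print([row.index(max(row)) for row in logits])  # [1, 0, 2, 0, 3, 0, 0, 0, 0, 0]
```

The argmax path matches the first sample of the output shown below: label 1, blank, label 2, blank, label 3, then blanks.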

import tensorflow as tf
import tf.ctc as ctc

labs = tf.sparse.from_dense([[1, 2, 3, 0, 0],
                             [4, 3, 1, 2, 0],
                             [3, 1, 0, 0, 0]])

logits = ctc.onehot_logits(labs) # 'perfect' logits

# <tf.Tensor: shape=(3, 10, 5), dtype=float32, numpy=
#
# array([[[-1.e+09,  0.e+00, -1.e+09, -1.e+09, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
#         [-1.e+09, -1.e+09,  0.e+00, -1.e+09, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
#         [-1.e+09, -1.e+09, -1.e+09,  0.e+00, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09]],
# 
#        [[-1.e+09, -1.e+09, -1.e+09, -1.e+09,  0.e+00],
#           ...
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09]],
# 
#        [[-1.e+09, -1.e+09, -1.e+09,  0.e+00, -1.e+09],
#           ...
#         [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09]]], dtype=float32)>

ctc.loss(labs, logits) # something very close to [0, 0, 0]

Loss

Wrapper around tf.nn.ctc_loss, except that labels must be a SparseTensor (as they probably should be anyway) and the API is trivial: just labels and logits.

(See example above)
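To see why perfect logits give a loss close to zero, here is a minimal pure-Python version of the standard CTC forward (alpha) recursion for one sample. It is an illustrative sketch, not what tf.nn.ctc_loss or tf_ctc actually run; the name ctc_loss_sketch, the per-frame log-softmax and blank index 0 are assumptions.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_loss_sketch(labels, logits, blank=0):
    """CTC negative log-likelihood of `labels` for ONE sample.

    `logits` is [time][classes] (unnormalized); a log-softmax is applied per frame.
    """
    logp = []
    for row in logits:
        z = logsumexp(row)
        logp.append([x - z for x in row])
    # Extended labels: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for lab in labels:
        ext += [lab, blank]
    S = len(ext)
    # alpha[s]: log-probability of all valid path prefixes ending at ext[s]
    alpha = [-math.inf] * S
    alpha[0] = logp[0][blank]
    if S > 1:
        alpha[1] = logp[0][ext[1]]
    for t in range(1, len(logp)):
        new = []
        for s in range(S):
            cands = [alpha[s]]              # stay on the same symbol
            if s >= 1:
                cands.append(alpha[s - 1])  # advance by one
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(alpha[s - 2])  # skip a blank between different labels
            new.append(logsumexp(cands) + logp[t][ext[s]])
        alpha = new
    # A valid path ends on the last label or on the trailing blank.
    return -logsumexp(alpha[-2:])

# Perfect logits for labels [1, 2, 3] over 10 frames and 5 classes.
path = [1, 0, 2, 0, 3, 0, 0, 0, 0, 0]
perfect = [[0.0 if c == k else -1e9 for c in range(5)] for k in path]
loss = ctc_loss_sketch([1, 2, 3], perfect)
print(abs(loss) < 1e-6)  # True
```

With perfect logits, a single alignment carries essentially all of the probability mass, so the negative log-likelihood collapses to about 0.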

Decoding

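Wrappers around tf.nn.ctc_greedy_decoder and tf.nn.ctc_beam_search_decoder, with a few differences:

  • Beam search supports a blank index of 0 (the default here), whereas tf.nn.ctc_beam_search_decoder always treats the last class as the blank
  • Logits are batch-major, i.e. [batch, time, classes] (as you most likely already have them), while the TensorFlow decoders expect time-major input
  • Trivial API (just pass the logits and, optionally, a config)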
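For reference, the best-path rule behind greedy CTC decoding can be sketched in plain Python for a single sample: take the argmax class per frame, merge consecutive repeats, then drop blanks. The function greedy_decode_sketch is hypothetical and assumes blank index 0; it is not how tf_ctc or TensorFlow implement it.

```python
def greedy_decode_sketch(logits, blank=0):
    """Best-path CTC decoding for ONE sample; `logits` is [time][classes]."""
    # 1) Most likely class per frame (first index wins on ties).
    best = [max(range(len(row)), key=row.__getitem__) for row in logits]
    # 2) Merge consecutive repeats and 3) drop blanks, in a single pass.
    decoded, prev = [], None
    for k in best:
        if k != prev and k != blank:
            decoded.append(k)
        prev = k
    return decoded

# Frames 1, 1, -, 2, -, 2, -, - (blank = 0): the blank keeps the two 2s apart.
frames = [1, 1, 0, 2, 0, 2, 0, 0]
logits = [[0.0 if c == k else -1e9 for c in range(5)] for k in frames]
print(greedy_decode_sketch(logits))  # [1, 2, 2]
```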
import tensorflow as tf
import tf_ctc as ctc

labs = tf.sparse.from_dense([[1, 2, 3, 0, 0],
                             [4, 3, 1, 2, 0],
                             [3, 1, 0, 0, 0]])

logits = ctc.onehot_logits(labs) # 'perfect' logits

[top_path, *_], log_probs = ctc.beam_decode(logits)
# or
[top_path], log_probs = ctc.greedy_decode(logits)

tf.sparse.to_dense(top_path)
# <tf.Tensor: shape=(3, 4), dtype=int64, numpy=
# array([[1, 2, 3, 0],
#        [4, 3, 1, 2],
#        [3, 1, 0, 0]])>
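Beam search keeps several candidate label prefixes rather than a single best path. Below is a textbook CTC prefix beam search in plain Python for one sample, assuming blank index 0. prefix_beam_search_sketch is a hypothetical name, and tf.nn.ctc_beam_search_decoder may differ in details such as pruning and tie-breaking.

```python
import math
from collections import defaultdict

NEG_INF = -math.inf

def lse(*xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_softmax(row):
    z = lse(*row)
    return [x - z for x in row]

def prefix_beam_search_sketch(logits, beam_width=4, blank=0):
    """CTC prefix beam search for ONE sample; `logits` is [time][classes].

    Returns up to `beam_width` (prefix, log_prob) pairs, most probable first.
    """
    # prefix -> (log P(prefix, path ends in blank), log P(prefix, path ends in a label))
    beams = {(): (0.0, NEG_INF)}
    for row in logits:
        logp = log_softmax(row)
        nxt = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (pb, pnb) in beams.items():
            last = prefix[-1] if prefix else None
            for c, lp in enumerate(logp):
                if c == blank:
                    # A blank leaves the prefix unchanged; the path now ends in blank.
                    b, nb = nxt[prefix]
                    nxt[prefix] = (lse(b, pb + lp, pnb + lp), nb)
                elif c == last:
                    # Repeating the last label: after a blank it starts a new symbol...
                    ext = prefix + (c,)
                    b, nb = nxt[ext]
                    nxt[ext] = (b, lse(nb, pb + lp))
                    # ...otherwise it merges into the current one.
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b, lse(nb, pnb + lp))
                else:
                    ext = prefix + (c,)
                    b, nb = nxt[ext]
                    nxt[ext] = (b, lse(nb, pb + lp, pnb + lp))
        # Drop impossible prefixes, then keep the `beam_width` most probable ones.
        scored = [(p, v) for p, v in nxt.items() if lse(*v) > NEG_INF]
        scored.sort(key=lambda item: lse(*item[1]), reverse=True)
        beams = dict(scored[:beam_width])
    return [(list(p), lse(pb, pnb)) for p, (pb, pnb) in beams.items()]

# Perfect logits for the path 1, -, 2, -, 3, - (blank = 0, 5 classes).
path = [1, 0, 2, 0, 3, 0]
perfect = [[0.0 if c == k else -1e9 for c in range(5)] for k in path]
best_prefix, best_logp = prefix_beam_search_sketch(perfect)[0]
print(best_prefix)  # [1, 2, 3]
```

Tracking "ends in blank" and "ends in a label" separately is what lets the search tell a merged repeat (1 1 becomes 1) apart from a genuine repeated label (1, blank, 1 becomes 1 1).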

Metrics

Generalizations of accuracy and edit distance to the top-$k$ decoded paths (both default to $k=1$, i.e. the usual versions):

  • Top-$k$ accuracy: proportion of samples whose label matches one of the top-$k$ predictions
  • Top-$k$ edit distance: for each sample, the minimum edit distance between the label and any of its top-$k$ predictions, averaged across all samples
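Both metrics can be sketched in plain Python. This assumes each sample's predictions are already decoded into a best-first list of label sequences (as a beam search would return them). All function names are hypothetical, and the edit distance here is the unnormalized Levenshtein distance. tf.edit_distance normalizes by label length by default, and this page does not say which convention ctc.edit_distance follows.

```python
def edit_distance(a, b):
    """Levenshtein distance between two sequences (unit costs, unnormalized)."""
    prev = list(range(len(b) + 1))  # distances from the empty prefix of `a`
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,               # delete x
                           cur[j - 1] + 1,            # insert y
                           prev[j - 1] + (x != y)))   # substitute (free if equal)
        prev = cur
    return prev[-1]

def top_k_accuracy(labels, predictions, k=1):
    """Fraction of samples whose label appears among their first k predictions."""
    hits = [label in preds[:k] for label, preds in zip(labels, predictions)]
    return sum(hits) / len(hits)

def top_k_edit_distance(labels, predictions, k=1):
    """Per sample, min edit distance to any of the first k predictions; then the mean."""
    dists = [min(edit_distance(label, p) for p in preds[:k])
             for label, preds in zip(labels, predictions)]
    return sum(dists) / len(dists)

labels = [[1, 2, 3], [4, 3, 1, 2], [3, 1]]
# Best-first decoded paths per sample (made-up values for illustration).
predictions = [
    [[1, 2, 3], [1, 2]],
    [[4, 3, 2], [4, 3, 1, 2]],
    [[3], [3, 1, 1]],
]
print(top_k_accuracy(labels, predictions, k=1))       # 1/3: only sample 1 matches
print(top_k_accuracy(labels, predictions, k=2))       # 2/3: sample 2 matches at rank 2
print(top_k_edit_distance(labels, predictions, k=1))  # 2/3: distances 0, 1, 1
print(top_k_edit_distance(labels, predictions, k=2))  # 1/3: distances 0, 0, 1
```

Increasing $k$ can only raise top-$k$ accuracy and lower top-$k$ edit distance, since each sample gets more chances to match.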
import tensorflow as tf
import tf_ctc as ctc

labs = tf.sparse.from_dense([[1, 2, 3, 0, 0],
                             [4, 3, 1, 2, 0],
                             [3, 1, 0, 0, 0]])

logits = ctc.onehot_logits(labs) # 'perfect' logits

ctc.accuracy(labs, logits) # 1.0
ctc.edit_distance(labs, logits) # 0.0

Testing

Very simple randomized tests using ctc.onehot_logits

pip install "tf-ctc[test]"
python -m tf.ctc.test

(Note: the tests may print a warning about jaxtyping; it is harmless.)
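The package's actual test suite is not reproduced on this page. The following is only a hypothetical sketch of a randomized round-trip test in the same spirit: draw random labels (repeats included), build perfect logits with blanks interspersed, greedy-decode them, and check that the labels come back unchanged. All names are illustrative and assume blank index 0.

```python
import random

def perfect_logits(labels, num_frames, num_classes, blank=0, neg=-1e9):
    """Each label followed by a blank, padded with blanks; 0 at the target, `neg` elsewhere."""
    path = [x for lab in labels for x in (lab, blank)]
    path += [blank] * (num_frames - len(path))
    return [[0.0 if c == t else neg for c in range(num_classes)] for t in path]

def greedy_decode(logits, blank=0):
    """Argmax per frame, merge repeats, drop blanks."""
    out, prev = [], None
    for row in logits:
        k = row.index(max(row))
        if k != prev and k != blank:
            out.append(k)
        prev = k
    return out

def random_round_trip(trials=200, max_len=6, num_classes=5, seed=0):
    """Random labels (repeats allowed) must survive perfect logits + greedy decoding."""
    rng = random.Random(seed)
    for _ in range(trials):
        length = rng.randrange(max_len + 1)                              # 0..max_len labels
        labels = [rng.randrange(1, num_classes) for _ in range(length)]  # never the blank
        logits = perfect_logits(labels, 2 * max_len, num_classes)
        assert greedy_decode(logits) == labels, labels
    return True

print(random_round_trip())  # True
```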


Download files

Source Distribution

tf_ctc-0.1.8.tar.gz (7.4 kB)

Uploaded Source

Built Distribution

tf_ctc-0.1.8-py3-none-any.whl (8.5 kB)

Uploaded Python 3

File details

Details for the file tf_ctc-0.1.8.tar.gz.

File metadata

  • Download URL: tf_ctc-0.1.8.tar.gz
  • Upload date:
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.6

File hashes

Hashes for tf_ctc-0.1.8.tar.gz
Algorithm Hash digest
SHA256 31a39c1167c80dad6fe0af211773d421127eaf22a3439a9b810b4bf1ce33eeee
MD5 628a60faa2c1aabbb6a150c7bd343c35
BLAKE2b-256 9e20fd5fec67313d3fb2928b78607264b03e15ebb0b9a8aaf8982702d4da28ab

File details

Details for the file tf_ctc-0.1.8-py3-none-any.whl.

File metadata

  • Download URL: tf_ctc-0.1.8-py3-none-any.whl
  • Upload date:
  • Size: 8.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.6

File hashes

Hashes for tf_ctc-0.1.8-py3-none-any.whl
Algorithm Hash digest
SHA256 6682287015245aaa9848254f11e6d9c0bf1e68e5da494e7855d7e90aba511bd2
MD5 0b324dde86ec00f0d092ffeea92d74e4
BLAKE2b-256 44208b7ed4fa4792932e4a6252dd974fe43c02a4ede06473b49c954ccfa6c197
