CTC with TensorFlow, simplified
Project description
Tensorflow CTC
Tools to simplify CTC in TensorFlow
Mocked Logits
'Perfect' logits to match any given labels: a large negative value ($-10^9$ by default) everywhere, except $0$ at the index of the correct label at each step, with blank characters (index 0 by default) interspersed to prevent collapse of equal, consecutive labels.
import tensorflow as tf
import tf_ctc as ctc
labs = tf.sparse.from_dense([[1, 2, 3, 0, 0],
[4, 3, 1, 2, 0],
[3, 1, 0, 0, 0]])
logits = ctc.onehot_logits(labs) # 'perfect' logits
# <tf.Tensor: shape=(3, 10, 5), dtype=float32, numpy=
#
# array([[[-1.e+09, 0.e+00, -1.e+09, -1.e+09, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
# [-1.e+09, -1.e+09, 0.e+00, -1.e+09, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
# [-1.e+09, -1.e+09, -1.e+09, 0.e+00, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09]],
#
# [[-1.e+09, -1.e+09, -1.e+09, -1.e+09, 0.e+00],
# ...
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09]],
#
# [[-1.e+09, -1.e+09, -1.e+09, 0.e+00, -1.e+09],
# ...
# [ 0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09]]], dtype=float32)>
ctc.loss(labs, logits) # something very close to [0, 0, 0]
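For intuition, the construction is easy to sketch by hand. The following hypothetical helper (not the library's actual implementation; the name and the choice of 2x time steps are assumptions) builds the same kind of tensor as the output shown above:

import numpy as np
import tensorflow as tf

def onehot_logits_sketch(labels: tf.SparseTensor, blank: int = 0, neg: float = -1e9):
    # Hypothetical sketch, not the library's implementation.
    dense = tf.sparse.to_dense(labels).numpy()
    num_classes = int(dense.max()) + 1
    num_steps = 2 * dense.shape[1]              # room for a blank after each label
    logits = np.full((dense.shape[0], num_steps, num_classes), neg, dtype=np.float32)
    for b, row in enumerate(dense):
        steps = []
        for lab in row:
            if lab != blank:
                steps += [int(lab), blank]      # label followed by a separating blank
        steps += [blank] * (num_steps - len(steps))   # pad the rest with blanks
        for t, cls in enumerate(steps):
            logits[b, t, cls] = 0.0             # 'perfect' score at the target class
    return tf.constant(logits)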
Loss
Wrapper around tf.nn.ctc_loss, except that labels must be a SparseTensor (as they should probably be) and the API is trivial: just labels and logits. (See the example above.)
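A wrapper like this presumably just fills in the remaining arguments of tf.nn.ctc_loss. A minimal sketch, assuming batch-major logits and blank index 0 (the real ctc.loss may differ):

import tensorflow as tf

def loss_sketch(labels: tf.SparseTensor, logits: tf.Tensor) -> tf.Tensor:
    # Hypothetical sketch of a thin wrapper around tf.nn.ctc_loss.
    batch_size = tf.shape(logits)[0]
    max_time = tf.shape(logits)[1]
    return tf.nn.ctc_loss(
        labels=labels,                          # must be a SparseTensor
        logits=logits,                          # batch-major: (batch, time, classes)
        label_length=None,                      # inferred from the SparseTensor
        logit_length=tf.fill([batch_size], max_time),
        logits_time_major=False,
        blank_index=0,
    )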
Decoding
Wrappers around tf.nn.ctc_greedy_decoder and tf.nn.ctc_beam_search_decoder, except:
- Beam Search supports setting the blank index to 0 (and does so by default)
- Logits are batch-major (as you most likely already have them)
- Trivial API (just pass the logits and optionally config)
import tensorflow as tf
import tf_ctc as ctc
labs = tf.sparse.from_dense([[1, 2, 3, 0, 0],
[4, 3, 1, 2, 0],
[3, 1, 0, 0, 0]])
logits = ctc.onehot_logits(labs) # 'perfect' logits
[top_path, *_], log_probs = ctc.beam_decode(logits)
# or
[top_path], log_probs = ctc.greedy_decode(logits)
tf.sparse.to_dense(top_path)
# <tf.Tensor: shape=(3, 4), dtype=int64, numpy=
# array([[1, 2, 3, 0],
# [4, 3, 1, 2],
# [3, 1, 0, 0]])>
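For intuition on the blank-index point: tf.nn.ctc_beam_search_decoder treats the last class as the blank, so a wrapper that puts the blank at index 0 can roll the class axis before decoding and shift the decoded indices back afterwards. A hypothetical sketch of that idea (not necessarily how ctc.beam_decode is implemented):

import tensorflow as tf

def beam_decode_sketch(logits: tf.Tensor, beam_width: int = 100, top_paths: int = 1):
    # Hypothetical sketch: batch-major logits with the blank at index 0.
    batch_size = tf.shape(logits)[0]
    max_time = tf.shape(logits)[1]
    # Move class 0 (blank) to the last position, where the TF decoder expects it.
    rolled = tf.roll(logits, shift=-1, axis=-1)
    # tf.nn.ctc_beam_search_decoder wants time-major inputs: (time, batch, classes).
    time_major = tf.transpose(rolled, [1, 0, 2])
    decoded, log_probs = tf.nn.ctc_beam_search_decoder(
        time_major, tf.fill([batch_size], max_time),
        beam_width=beam_width, top_paths=top_paths)
    # Undo the roll: decoded index i corresponds to original class i + 1.
    decoded = [tf.SparseTensor(d.indices, d.values + 1, d.dense_shape) for d in decoded]
    return decoded, log_probs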
Metrics
Generalizations of accuracy and edit distance (default to $k=1$, i.e. the usual, concrete versions):
- Top-$k$ accuracy: proportion of samples where the label is in the top-$k$ predictions
- Top-$k$ edit distance: minimum edit distance between the label and each of the top-$k$ predictions; averaged across all samples
import tensorflow as tf
import tf_ctc as ctc
labs = tf.sparse.from_dense([[1, 2, 3, 0, 0],
[4, 3, 1, 2, 0],
[3, 1, 0, 0, 0]])
logits = ctc.onehot_logits(labs) # 'perfect' logits
ctc.accuracy(labs, logits) # 1.0
ctc.edit_distance(labs, logits) # 0.0
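Conceptually, both metrics reduce to tf.edit_distance between the label and the $k$ decoded paths. A rough sketch of the definitions above, with a hypothetical helper name and the decoded paths passed in explicitly (the library's implementation may differ):

import tensorflow as tf

def topk_metrics_sketch(labels: tf.SparseTensor, decoded_paths):
    # Hypothetical sketch of the top-k metric definitions, not the library's code.
    # `decoded_paths`: list of k SparseTensors, e.g. the output of ctc.beam_decode.
    labels = tf.cast(labels, tf.int64)          # match the decoder's int64 values
    # Normalized edit distance of each of the k paths to the label: shape (k, batch).
    dists = tf.stack([tf.edit_distance(p, labels) for p in decoded_paths])
    # Top-k edit distance: per-sample minimum over the k paths, averaged over the batch.
    topk_ed = tf.reduce_mean(tf.reduce_min(dists, axis=0))
    # Top-k accuracy: fraction of samples matched exactly by at least one path.
    topk_acc = tf.reduce_mean(tf.cast(tf.reduce_any(dists == 0.0, axis=0), tf.float32))
    return topk_acc, topk_ed

With the example above, topk_metrics_sketch(labs, [top_path]) should give (1.0, 0.0), in line with ctc.accuracy and ctc.edit_distance.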
Testing
Very simple randomized tests using ctc.onehot_logits
pip install tf-ctc[test]
python -m tf_ctc.test
(Note: the tests may spit out a warning about jaxtyping. That's just fine.)
File details
Details for the file tf_ctc-0.1.8.tar.gz.
File metadata
- Download URL: tf_ctc-0.1.8.tar.gz
- Upload date:
- Size: 7.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 31a39c1167c80dad6fe0af211773d421127eaf22a3439a9b810b4bf1ce33eeee
MD5 | 628a60faa2c1aabbb6a150c7bd343c35
BLAKE2b-256 | 9e20fd5fec67313d3fb2928b78607264b03e15ebb0b9a8aaf8982702d4da28ab
File details
Details for the file tf_ctc-0.1.8-py3-none-any.whl.
File metadata
- Download URL: tf_ctc-0.1.8-py3-none-any.whl
- Upload date:
- Size: 8.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6682287015245aaa9848254f11e6d9c0bf1e68e5da494e7855d7e90aba511bd2
MD5 | 0b324dde86ec00f0d092ffeea92d74e4
BLAKE2b-256 | 44208b7ed4fa4792932e4a6252dd974fe43c02a4ede06473b49c954ccfa6c197