FSA/FST algorithms, intended to (eventually) be interoperable with PyTorch and similar
Project description
k2
The vision of k2 is to be able to seamlessly integrate Finite State Automaton (FSA) and Finite State Transducer (FST) algorithms into autograd-based machine learning toolkits like PyTorch and TensorFlow. For speech recognition applications, this should make it easy to interpolate and combine various training objectives such as cross-entropy, CTC and MMI and to jointly optimize a speech recognition system with multiple decoding passes including lattice rescoring and confidence estimation. We hope k2 will have many other applications as well.
One of the key algorithms that we have implemented is pruned composition of a generic FSA with a "dense" FSA (i.e. one that corresponds to log-probs of symbols at the output of a neural network). This can be used as a fast implementation of decoding for ASR, and for CTC and LF-MMI training. This won't give a direct advantage in terms of Word Error Rate (WER) compared with existing technology, but the point is to do this in a much more general and extensible framework that allows further development of ASR technology.
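The sketch below makes "composing a graph with a dense FSA" concrete in plain Python. It is not k2's implementation, which is batched, runs on GPU and handles details this sketch omits. It rests on several assumptions: a single sequence, an epsilon-free graph in which every arc consumes exactly one frame, log-domain scores combined with max (Viterbi), and a simple per-frame beam. The names `Arc` and `pruned_intersect_best_score` are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, List, NamedTuple


class Arc(NamedTuple):
    src: int      # source state
    dest: int     # destination state
    label: int    # column of the dense log-prob matrix consumed by this arc
    score: float  # graph score (log domain)


def pruned_intersect_best_score(arcs: List[Arc], start: int, final: int,
                                log_probs: List[List[float]],
                                beam: float) -> float:
    """Best (Viterbi) score of a path from `start` to `final` that consumes
    one frame of `log_probs` per arc, pruning states outside `beam`."""
    out_arcs: Dict[int, List[Arc]] = defaultdict(list)
    for a in arcs:
        out_arcs[a.src].append(a)

    active: Dict[int, float] = {start: 0.0}
    for frame in log_probs:
        nxt: Dict[int, float] = {}
        for state, total in active.items():
            for a in out_arcs[state]:
                s = total + a.score + frame[a.label]
                if s > nxt.get(a.dest, -math.inf):
                    nxt[a.dest] = s
        if not nxt:
            return -math.inf  # every path died
        best = max(nxt.values())
        # Beam pruning: keep only states close to the best one on this frame.
        active = {q: s for q, s in nxt.items() if s >= best - beam}
    return active.get(final, -math.inf)


# Toy CTC-like graph for a single symbol: label 0 = blank, label 1 = "a".
graph = [Arc(0, 0, 0, 0.0), Arc(0, 1, 1, 0.0),
         Arc(1, 1, 1, 0.0), Arc(1, 1, 0, 0.0)]
frames = [[-0.1, -2.0], [-2.0, -0.1], [-0.1, -2.0]]
print(pruned_intersect_best_score(graph, 0, 1, frames, beam=1.0))  # about -0.3
```

Here the best path is blank, "a", blank. With `beam=1.0`, state 0 is pruned on the second frame without changing the result. A narrower beam trades accuracy for less work.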
Implementation
A few key points on our implementation strategy.
Most of the code is in C++ and CUDA. We implement a templated class Ragged, which is quite like TensorFlow's RaggedTensor (actually we came up with the design independently, and were later told that TensorFlow was using the same ideas). Despite a close similarity at the level of data structures, the design is quite different from TensorFlow and PyTorch. Most of the time we don't use composition of simple operations, but rely on C++11 lambdas defined directly in the C++ implementations of algorithms. The code in these lambdas operates directly on data pointers and, if the backend is CUDA, the lambdas can run in parallel for each element of a tensor. (The C++ and CUDA code is mixed together and the CUDA kernels get instantiated via templates.)
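The following Python sketch illustrates the row-splits idea behind a two-axis ragged array. It is our simplification, not k2's `Ragged` class, and the names `RaggedSketch` and `subtract_row_first` are hypothetical. All elements live in one flat `values` list. `row_splits` marks where each row starts, and `row_ids` records which row each element belongs to. Because every element can look up its own row, an operation can be written as a per-element body with no dependency between iterations. That property is what lets such a body be launched as a lambda with one CUDA thread per element.

```python
from typing import List


class RaggedSketch:
    """Two-axis ragged array stored as flat values plus row_splits/row_ids."""

    def __init__(self, rows: List[List[int]]):
        self.values = [x for row in rows for x in row]
        self.row_splits = [0]
        for row in rows:
            self.row_splits.append(self.row_splits[-1] + len(row))
        # row_ids[j] is the index of the row containing values[j].
        self.row_ids = [i for i, row in enumerate(rows) for _ in row]

    def num_rows(self) -> int:
        return len(self.row_splits) - 1

    def row(self, i: int) -> List[int]:
        return self.values[self.row_splits[i]:self.row_splits[i + 1]]


def subtract_row_first(r: RaggedSketch) -> List[int]:
    """For every element, subtract the first element of its own row.

    Each iteration reads only index j and the row metadata, so the body
    could run as one GPU thread per element."""
    return [r.values[j] - r.values[r.row_splits[r.row_ids[j]]]
            for j in range(len(r.values))]


r = RaggedSketch([[3, 1, 4], [], [1, 5]])
print(r.row_splits)           # [0, 3, 3, 5]
print(r.row_ids)              # [0, 0, 0, 2, 2]
print(subtract_row_first(r))  # [0, -2, 1, 0, 4]
```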
It is difficult to adequately describe what we are doing with these Ragged objects without going through the code in detail. The algorithms look very different from the way you would code them on CPU because of the need to avoid sequential processing. We use coding patterns that make the most expensive parts of the computations "embarrassingly parallelizable"; the only somewhat nontrivial CUDA operations are generally reduction and scan operations such as exclusive prefix sum, for which we use NVIDIA's cub library. Our design is not too specific to NVIDIA hardware, and the bulk of the code we write is fairly normal-looking C++; the nontrivial CUDA programming is mostly done via the cub library, parts of which we wrap with our own convenient interface.
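Exclusive prefix sum matters here because it converts row sizes into `row_splits`, and from there into `row_ids`. The sketch below uses a Hillis-Steele scan, a standard parallel scan chosen for readability. We are not claiming cub uses this algorithm; on a GPU the scan itself would be done by cub's `DeviceScan::ExclusiveSum`. The scan runs in log2(n) rounds, and within each round the element updates are independent. The function names are hypothetical.

```python
from typing import List


def exclusive_sum(xs: List[int]) -> List[int]:
    """Exclusive prefix sum via a Hillis-Steele scan: log2(n) rounds, and
    within each round every element update is independent of the others."""
    n = len(xs)
    inclusive = list(xs)
    step = 1
    while step < n:
        inclusive = [inclusive[i] + (inclusive[i - step] if i >= step else 0)
                     for i in range(n)]
        step *= 2
    # Shift right by one to turn the inclusive scan into an exclusive one.
    return [0] + inclusive[:-1] if n > 0 else []


def row_splits_from_sizes(sizes: List[int]) -> List[int]:
    # Appending a dummy 0 makes the last output equal to the total size.
    return exclusive_sum(sizes + [0])


def row_ids_from_row_splits(row_splits: List[int]) -> List[int]:
    row_ids = [0] * row_splits[-1]
    for i in range(len(row_splits) - 1):  # rows are independent: parallel over i
        for j in range(row_splits[i], row_splits[i + 1]):
            row_ids[j] = i
    return row_ids


splits = row_splits_from_sizes([3, 0, 2])
print(splits)                           # [0, 3, 3, 5]
print(row_ids_from_row_splits(splits))  # [0, 0, 0, 2, 2]
```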
The Finite State Automaton object is then implemented as a Ragged tensor templated on a specific data type (a struct representing an arc in the automaton).
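The following sketch shows an arc record and how grouping arcs by source state yields a two-axis ragged layout indexed `[state][arc]`. The field names follow k2's C++ `Arc` struct as we understand it (`src_state`, `dest_state`, `label`, `score`). The helper `arcs_to_ragged` is hypothetical, and the real library builds this structure on the device rather than with a Python sort.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Arc:
    src_state: int
    dest_state: int
    label: int
    score: float


def arcs_to_ragged(arcs: List[Arc], num_states: int) -> Tuple[List[int], List[Arc]]:
    """Return (row_splits, values) such that the arcs leaving state s are
    values[row_splits[s]:row_splits[s + 1]]."""
    sizes = [0] * num_states
    for a in arcs:
        sizes[a.src_state] += 1
    row_splits = [0]
    for n in sizes:
        row_splits.append(row_splits[-1] + n)
    # sorted() is stable, so arcs keep their relative order within a state.
    values = sorted(arcs, key=lambda a: a.src_state)
    return row_splits, values


arcs = [Arc(0, 1, 1, 0.5), Arc(1, 2, 3, 0.0), Arc(0, 0, 2, 0.1)]
row_splits, values = arcs_to_ragged(arcs, num_states=3)
print(row_splits)  # [0, 2, 3, 3]
print([(a.src_state, a.dest_state) for a in values[row_splits[0]:row_splits[1]]])  # [(0, 1), (0, 0)]
```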
Autograd
If you look at the code as it exists now, you won't find any references to autograd. The design is quite different from TensorFlow and PyTorch (which is why we didn't simply extend one of those toolkits). Instead of making autograd come from the bottom up (by making individual operations differentiable), we are implementing it from the top down, which is much more efficient in this case (and will tend to have better roundoff properties).
An example: suppose we are finding the best path of an FSA, and we need derivatives. We implement this by keeping track of, for each arc in the output best-path, which input arc it corresponds to. (For more complex algorithms an arc in the output might correspond to a sum of probabilities of a list of input arcs). We can make this compatible with PyTorch/TensorFlow autograd at the Python level, by, for example, defining a Function class in PyTorch that remembers this relationship between the arcs and does the appropriate (sparse) operations to propagate back the derivatives w.r.t. the weights.
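Here is a minimal plain-Python sketch of the top-down idea, under our own simplifying assumptions. The forward pass of a best-path operation returns the path score, and the best path comes with an `arc_map` giving the input arc behind each output arc. The backward pass scatters the incoming gradient back through that map. The third function covers the case where an output arc stands for a sum of probabilities of several input arcs. Its score is a log-sum-exp, so its gradient is the softmax of those input scores. All function names are hypothetical. In PyTorch, such a forward/backward pair would be wrapped in a `torch.autograd.Function` subclass.

```python
import math
from typing import List


def best_path_score(arc_scores: List[float], arc_map: List[int]) -> float:
    """Forward: total score of a best path whose arcs are copies of the
    input arcs listed in arc_map (an arc can repeat, e.g. a self-loop)."""
    return sum(arc_scores[i] for i in arc_map)


def best_path_backward(grad_output: float, num_input_arcs: int,
                       arc_map: List[int]) -> List[float]:
    """Backward: sparse scatter-add of the gradient onto the input arcs."""
    grad = [0.0] * num_input_arcs
    for i in arc_map:
        grad[i] += grad_output
    return grad


def logsumexp_arc_backward(grad_output: float, arc_scores: List[float],
                           sources: List[int]) -> List[float]:
    """Backward for an output arc whose score is log(sum(exp(s_i))) over the
    distinct input arcs in `sources`: each receives a softmax-weighted share."""
    m = max(arc_scores[i] for i in sources)  # subtract max for stability
    z = sum(math.exp(arc_scores[i] - m) for i in sources)
    grad = [0.0] * len(arc_scores)
    for i in sources:
        grad[i] += grad_output * math.exp(arc_scores[i] - m) / z
    return grad


scores = [0.5, -1.0, 2.0, 0.25]
path = [2, 0, 2]                          # arc 2 is traversed twice
print(best_path_score(scores, path))      # 4.5
print(best_path_backward(1.0, 4, path))   # [1.0, 0.0, 2.0, 0.0]
```

Only the arcs on the path receive gradient. This is why the backward step is a cheap sparse operation rather than a replay of every primitive operation in the forward pass.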
Current state of the code
We have wrapped all the C++ code for Python using pybind11 and have finished the integration with PyTorch.
We are currently writing speech recognition recipes using k2, which are hosted in a separate repository. Please see https://github.com/k2-fsa/icefall.
Plans after initial release
We are currently trying to make k2 ready for production use (see the branch v2.0-pre).
Quick start
Want to try it out without installing anything? We have set up a Google Colab notebook. You can find more Colab notebooks using k2 in speech recognition at https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html.
Download files
Built Distributions
File details
Details for the file k2-1.19-py38-none-any.whl.
File metadata
- Download URL: k2-1.19-py38-none-any.whl
- Upload date:
- Size: 72.8 MB
- Tags: Python 3.8
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3c33085c3ff9b08d8a280177858bac8527bb727b640c8f568f37940afe52fc83
MD5 | dee0f53d4c60f09856b2d31ccd37ec11
BLAKE2b-256 | 954934cbb974190388b0ef541ac5589a915b73c97fee5e6dfc831efeb894ce3c
File details
Details for the file k2-1.19-py37-none-any.whl.
File metadata
- Download URL: k2-1.19-py37-none-any.whl
- Upload date:
- Size: 72.8 MB
- Tags: Python 3.7
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.7.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5d4b08c61c8bd60072d3c8099acf0b986f051422fbd1eecb6629622668c19b0b
MD5 | e0203359b1ba73960a1e5aaf0bf25114
BLAKE2b-256 | 66eef70ed4e0389e02b4e556270021fa5c2203e12a37e97a78fd8baa15b0ded4
File details
Details for the file k2-1.19-py36-none-any.whl.
File metadata
- Download URL: k2-1.19-py36-none-any.whl
- Upload date:
- Size: 72.8 MB
- Tags: Python 3.6
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.3 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.12 tqdm/4.64.0 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.15
File hashes
Algorithm | Hash digest
---|---
SHA256 | 704986a035c2e977a68422ff80d1c5ab114231f998af6a700d9cf09acc0c89fb
MD5 | 9f969af78f7d3bd85e3fc08c8f2f07fc
BLAKE2b-256 | 3f58bef2db2a56c935615d447b8195f639e457a7a9118a6cdd408622d7fb055e
File details
Details for the file k2-1.19-cp38-cp38-macosx_10_15_x86_64.whl.
File metadata
- Download URL: k2-1.19-cp38-cp38-macosx_10_15_x86_64.whl
- Upload date:
- Size: 1.9 MB
- Tags: CPython 3.8, macOS 10.15+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 511b4f5ebf09fdc5c3addf4d2bb4ae84506ac1bbed9a3512fbc33163648cb171
MD5 | a250d7abdf4c40b7fa28ec5d59e58b58
BLAKE2b-256 | 695d076d9e6aba327c5ffc574b062564b2e84d2bc7b6843211d22d998488a27f
File details
Details for the file k2-1.19-cp37-cp37m-macosx_10_15_x86_64.whl.
File metadata
- Download URL: k2-1.19-cp37-cp37m-macosx_10_15_x86_64.whl
- Upload date:
- Size: 1.9 MB
- Tags: CPython 3.7m, macOS 10.15+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.7.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | b5ec97cf37db80ab838851fe6c9b3b360828124dd7a5fb769b561c5631f2c1c2
MD5 | d5e47cb372fe2cee0c12e19e70db74f4
BLAKE2b-256 | 5964bc08e6a67b9bcfff212d34f9f0a3d6d054ba2d40bd60f63dcff2d4ed186d
File details
Details for the file k2-1.19-cp36-cp36m-macosx_10_15_x86_64.whl.
File metadata
- Download URL: k2-1.19-cp36-cp36m-macosx_10_15_x86_64.whl
- Upload date:
- Size: 1.9 MB
- Tags: CPython 3.6m, macOS 10.15+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.3 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.12 tqdm/4.64.0 importlib-metadata/4.8.3 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.15
File hashes
Algorithm | Hash digest
---|---
SHA256 | 01fc3a67fb411ee0e49b7b22eef87ea74113b2b5f0f4e85a4b57ea8d48eb0146
MD5 | 310c1af11fd26907ccd51d7b4907ba8b
BLAKE2b-256 | 5ed7490b328cf81d3afa2bd32b6bfb3f7569fe0e97f6fafde98f7d0039e2ff65
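To check a downloaded wheel against one of the SHA256 digests listed for these files, a few lines of standard-library Python are enough. From the command line, `pip hash <file>` prints the same SHA256 digest. The function names below are our own.

```python
import hashlib
from typing import Iterable


def sha256_of_chunks(chunks: Iterable[bytes]) -> str:
    """Hash data incrementally so a large wheel need not fit in memory."""
    h = hashlib.sha256()
    for chunk in chunks:
        h.update(chunk)
    return h.hexdigest()


def wheel_matches(path: str, expected_sha256: str, chunk_size: int = 1 << 20) -> bool:
    """Compare the SHA256 of the file at `path` with a published hex digest."""
    with open(path, "rb") as f:
        digest = sha256_of_chunks(iter(lambda: f.read(chunk_size), b""))
    return digest == expected_sha256.lower()


# Example (requires the wheel to have been downloaded first):
# wheel_matches("k2-1.19-cp38-cp38-macosx_10_15_x86_64.whl",
#               "511b4f5ebf09fdc5c3addf4d2bb4ae84506ac1bbed9a3512fbc33163648cb171")
```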