PyTorch bindings for Baidu Warp-CTC
torch-baidu-ctc
PyTorch bindings for Baidu's Warp-CTC. These bindings were inspired by SeanNaren's, but they include some bug fixes and offer additional features.
import torch
from torch_baidu_ctc import ctc_loss, CTCLoss
# Activations. Shape T x N x D.
# T -> max number of frames/timesteps
# N -> minibatch size
# D -> number of output labels (including the CTC blank)
x = torch.rand(10, 3, 6)
# Target labels
y = torch.tensor(
    [
        # 1st sample
        1, 1, 2, 5, 2,
        # 2nd
        1, 5, 2,
        # 3rd
        4, 4, 2, 3,
    ],
    dtype=torch.int,
)
# Activations lengths
xs = torch.tensor([10, 6, 9], dtype=torch.int)
# Target lengths
ys = torch.tensor([5, 3, 4], dtype=torch.int)
# By default, the costs (negative log-likelihood) of all samples are summed.
# This is equivalent to:
# ctc_loss(x, y, xs, ys, average_frames=False, reduction="sum")
loss1 = ctc_loss(x, y, xs, ys)
# You can also average the cost of each sample among the number of frames.
# The averaged costs are then summed.
loss2 = ctc_loss(x, y, xs, ys, average_frames=True)
# Instead of summing the costs of each sample, you can perform
# other reductions by setting `reduction` to "none", "sum", or "mean".
#
# Return an array with the loss of each individual sample
losses = ctc_loss(x, y, xs, ys, reduction="none")
#
# Compute the mean of the individual losses
loss3 = ctc_loss(x, y, xs, ys, reduction="mean")
#
# First normalize each sample's loss by its number of frames, then average the losses
loss4 = ctc_loss(x, y, xs, ys, average_frames=True, reduction="mean")
# Finally, there's also an nn.Module wrapping this loss.
ctc = CTCLoss(average_frames=True, reduction="mean", blank=0)
loss4_2 = ctc(x, y, xs, ys)
# Note: the `blank` option is also available for `ctc_loss`.
# By default it is 0.
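To put the pieces together, here is a minimal, illustrative training step using the CTCLoss module. The Linear model and SGD optimizer below are hypothetical choices made only for this sketch (they are not part of torch_baidu_ctc); any model producing T x N x D activations can be plugged in the same way, and Warp-CTC normalizes the activations internally, so the raw outputs are passed directly to the loss.
import torch
from torch_baidu_ctc import CTCLoss

# Hypothetical toy model: maps 8-dimensional frame features to D = 6 output labels.
model = torch.nn.Linear(8, 6)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ctc = CTCLoss(average_frames=True, reduction="mean")

# Reuse y, xs, ys from the example above.
features = torch.rand(10, 3, 8)   # T x N x feature_dim
activations = model(features)     # T x N x D, as expected by the loss
loss = ctc(activations, y, xs, ys)

optimizer.zero_grad()
loss.backward()
optimizer.step()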
Requirements
- C++11 compiler (tested with GCC 4.9).
- Python: 2.7, 3.5, 3.6, 3.7 (tested with versions 2.7, 3.5 and 3.6).
- PyTorch >= 1.1.0 (tested with version 1.1.0).
- For GPU support: CUDA Toolkit.
Installation
The installation process should be straightforward, assuming you have correctly installed the required libraries and tools.
The package is compiled from source, and CUDA support is enabled automatically if your PyTorch installation supports it.
From PyPI (recommended)
pip install torch-baidu-ctc
From GitHub
git clone --recursive https://github.com/jpuigcerver/pytorch-baidu-ctc.git
cd pytorch-baidu-ctc
python setup.py build
python setup.py install
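Either way, a quick sanity check is to import the documented entry points (nothing here exercises CUDA):
python -c "from torch_baidu_ctc import ctc_loss, CTCLoss; print('ok')"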
AVX512 related issues
Some compilation problems may arise when using CUDA and newer host compilers with AVX512 instructions. Please install GCC 4.9 and use it as the host compiler for NVCC. You can simply set the CC and CXX environment variables before the build/install commands:
CC=gcc-4.9 CXX=g++-4.9 pip install torch-baidu-ctc
or (if you are using the GitHub source code):
CC=gcc-4.9 CXX=g++-4.9 python setup.py build
Testing
Once installed, you can test the library using unittest. In particular, run the following command:
python -m unittest torch_baidu_ctc.test
All tests should pass (CUDA tests are only executed if supported).
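If you built with CUDA support, a small GPU smoke test may also be useful. The sketch below assumes the GPU code path is picked when the activations live on a CUDA device; the labels and lengths are kept on the CPU as int tensors, which is what Baidu's Warp-CTC traditionally expects, so adjust if your setup differs.
import torch
from torch_baidu_ctc import ctc_loss

if torch.cuda.is_available():
    # Single sample: T = 10 frames, N = 1, D = 6 labels (including the blank).
    x = torch.rand(10, 1, 6, device="cuda")
    y = torch.tensor([1, 2, 3], dtype=torch.int)   # target labels (CPU)
    xs = torch.tensor([10], dtype=torch.int)       # activation lengths
    ys = torch.tensor([3], dtype=torch.int)        # target lengths
    print(ctc_loss(x, y, xs, ys))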
Download files
Download the file for your platform.
Source distribution: torch-baidu-ctc-0.3.0.tar.gz
Built distributions: manylinux1 x86_64 wheels for CPython 2.7mu, 3.5m, 3.6m, and 3.7m
File details
Details for the file torch-baidu-ctc-0.3.0.tar.gz.
File metadata
- Download URL: torch-baidu-ctc-0.3.0.tar.gz
- Size: 54.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1675682ca93d0915bfd54591b978bdd4bb78f12f73b9190326be29825b5e6e36
MD5 | 438a84138d42b063fb2e0468e346255b
BLAKE2b-256 | 46aab9a7c86f1fae5754f8fe108e05e9f4340a5aec1ff0d5afefcc55c91a6ff4
File details
Details for the file torch_baidu_ctc-0.3.0-cp37-cp37m-manylinux1_x86_64.whl.
File metadata
- Download URL: torch_baidu_ctc-0.3.0-cp37-cp37m-manylinux1_x86_64.whl
- Size: 2.7 MB
- Tags: CPython 3.7m
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2518a0a5d651f474e817e88d1b20093499e701067290d445154693b31c3dfeec
MD5 | 54f1892e8f7e3a7874645085379f29b6
BLAKE2b-256 | 182717d47d7eded8cfcbf6fc75b9281d5a0c5a6cb6da43161be539cd16eca154
File details
Details for the file torch_baidu_ctc-0.3.0-cp36-cp36m-manylinux1_x86_64.whl.
File metadata
- Download URL: torch_baidu_ctc-0.3.0-cp36-cp36m-manylinux1_x86_64.whl
- Size: 2.7 MB
- Tags: CPython 3.6m
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 84cfdcc044f044a1c71a120eac84e58f0e7a67d636433c2295bc0f15f3c2a97a
MD5 | fd7e267da5ab8537f6f3f6b67e693881
BLAKE2b-256 | 547feadf62c7e541430ecb7dab16adabb7bdb757dafcbfc584e61c20187ed407
File details
Details for the file torch_baidu_ctc-0.3.0-cp35-cp35m-manylinux1_x86_64.whl.
File metadata
- Download URL: torch_baidu_ctc-0.3.0-cp35-cp35m-manylinux1_x86_64.whl
- Size: 2.7 MB
- Tags: CPython 3.5m
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7a0e4da260d6d6e84351095ae00a8e333bf9a579d3825b200d7c9de383364edc
MD5 | d379789ed013493635c9292ed3728fe1
BLAKE2b-256 | 5836989ef546ae94e276f4f31ec55d1ffa7d3b790b6f8d69cbe8d63aec137e5a
File details
Details for the file torch_baidu_ctc-0.3.0-cp27-cp27mu-manylinux1_x86_64.whl.
File metadata
- Download URL: torch_baidu_ctc-0.3.0-cp27-cp27mu-manylinux1_x86_64.whl
- Size: 2.7 MB
- Tags: CPython 2.7mu
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | db81e32f5e630ae91586328333222a9127eee747c9b72f3cca79c8b57729381b
MD5 | 87ff4f047aaf83cb9a779798c4d7e4e1
BLAKE2b-256 | 7d890275e493bcf97542363dba7481dd4e979a39a01753a605bd3de2b940516a