
PyTorch bindings for Baidu Warp-CTC


torch-baidu-ctc


PyTorch bindings for Baidu's Warp-CTC. These bindings were inspired by SeanNaren's, but they include some bug fixes and offer additional features.

import torch
from torch_baidu_ctc import ctc_loss, CTCLoss

# Activations. Shape T x N x D.
# T -> max number of frames/timesteps
# N -> minibatch size
# D -> number of output labels (including the CTC blank)
x = torch.rand(10, 3, 6)
# Target labels
y = torch.tensor(
  [
    # 1st sample
    1, 1, 2, 5, 2,
    # 2nd
    1, 5, 2,
    # 3rd
    4, 4, 2, 3,
  ],
  dtype=torch.int,
)
# Activations lengths
xs = torch.tensor([10, 6, 9], dtype=torch.int)
# Target lengths
ys = torch.tensor([5, 3, 4], dtype=torch.int)

# By default, the costs (negative log-likelihood) of all samples are summed.
# This is equivalent to:
#   ctc_loss(x, y, xs, ys, average_frames=False, reduction="sum")
loss1 = ctc_loss(x, y, xs, ys)

# You can also average the cost of each sample among the number of frames.
# The averaged costs are then summed.
loss2 = ctc_loss(x, y, xs, ys, average_frames=True)

# Instead of summing the costs of each sample, you can perform
# other `reductions`: "none", "sum", or "mean"
#
# Return an array with the loss of each individual sample
losses = ctc_loss(x, y, xs, ys, reduction="none")
#
# Compute the mean of the individual losses
loss3 = ctc_loss(x, y, xs, ys, reduction="mean")
#
# First normalize each sample's loss by its number of frames, then average the losses
loss4 = ctc_loss(x, y, xs, ys, average_frames=True, reduction="mean")


# Finally, there's also a nn.Module to use this loss.
ctc = CTCLoss(average_frames=True, reduction="mean", blank=0)
loss4_2 = ctc(x, y, xs, ys)

# Note: the `blank` option is also available for `ctc_loss`.
# By default it is 0.
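For reference, the quantity these bindings compute is the standard CTC negative log-likelihood: the total probability, summed over every frame-level alignment, of the paths that collapse to the target label sequence. The sketch below is a minimal pure-Python implementation of the CTC forward algorithm for a single sample; it is illustrative only (independent of warp-ctc, assumes a non-empty target, and takes per-frame log-probabilities rather than raw activations):

```python
import math

def ctc_neg_log_likelihood(log_probs, labels, blank=0):
    """CTC forward algorithm for one sample.

    log_probs: list of T frames, each a list of per-label log-probabilities.
    labels: non-empty target label sequence (without blanks).
    Returns -log P(labels | log_probs).
    """
    # Extend the target with blanks: blank, l1, blank, l2, ..., lK, blank
    ext = [blank]
    for l in labels:
        ext += [l, blank]
    S = len(ext)
    NEG_INF = float("-inf")

    def logsumexp(*xs):
        m = max(xs)
        if m == NEG_INF:
            return NEG_INF
        return m + math.log(sum(math.exp(x - m) for x in xs))

    # alpha[s]: log-prob of all prefixes ending in ext[s] after the current frame
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    alpha[1] = log_probs[0][ext[1]]
    for t in range(1, len(log_probs)):
        new = [NEG_INF] * S
        for s in range(S):
            a = alpha[s]                       # stay on the same symbol
            if s > 0:
                a = logsumexp(a, alpha[s - 1])  # advance by one
            # Skipping a blank is allowed unless it would merge repeated labels
            if s > 1 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = logsumexp(a, alpha[s - 2])
            new[s] = a + log_probs[t][ext[s]]
        alpha = new
    # Valid paths end on the last label or the trailing blank
    return -logsumexp(alpha[-1], alpha[-2])

if __name__ == "__main__":
    lp = math.log(0.5)
    # 2 frames, 2 symbols (blank=0, label 1), uniform probabilities.
    # Paths 01, 10, and 11 all collapse to [1], so P = 0.75.
    print(ctc_neg_log_likelihood([[lp, lp], [lp, lp]], [1]))  # -log(0.75)
```

The real bindings run this recursion (and its gradient) in C++/CUDA over the whole minibatch, which is why they take raw activations and length tensors instead of Python lists.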

Requirements

  • C++11 compiler (tested with GCC 4.9).
  • Python: 2.7, 3.5, 3.6, 3.7 (tested with versions 2.7, 3.5 and 3.6).
  • PyTorch >= 1.1.0 (tested with version 1.1.0).
  • For GPU support: CUDA Toolkit.

Installation

The installation process should be pretty straightforward, assuming that you have correctly installed the required libraries and tools.

The setup process compiles the package from source, and will compile it with CUDA support if your PyTorch installation supports CUDA.

From PyPI (recommended)

pip install torch-baidu-ctc

From GitHub

git clone --recursive https://github.com/jpuigcerver/pytorch-baidu-ctc.git
cd pytorch-baidu-ctc
python setup.py build
python setup.py install

AVX512 related issues

Some compilation problems may arise when using CUDA and newer host compilers with AVX512 instructions. In that case, install GCC 4.9 and use it as the host compiler for NVCC. You can simply set the CC and CXX environment variables before the build/install commands:

CC=gcc-4.9 CXX=g++-4.9 pip install torch-baidu-ctc

or (if you are using the GitHub source code):

CC=gcc-4.9 CXX=g++-4.9 python setup.py build

Testing

Once the library is installed, you can test it with unittest by running the following command:

python -m unittest torch_baidu_ctc.test

All tests should pass (CUDA tests are only executed if supported).

Download files

Download the file for your platform.

Source Distribution

  • torch-baidu-ctc-0.3.0.tar.gz (54.9 kB, Source)
    SHA256: 1675682ca93d0915bfd54591b978bdd4bb78f12f73b9190326be29825b5e6e36
    MD5: 438a84138d42b063fb2e0468e346255b
    BLAKE2b-256: 46aab9a7c86f1fae5754f8fe108e05e9f4340a5aec1ff0d5afefcc55c91a6ff4

Built Distributions

  • torch_baidu_ctc-0.3.0-cp37-cp37m-manylinux1_x86_64.whl (2.7 MB, CPython 3.7m)
    SHA256: 2518a0a5d651f474e817e88d1b20093499e701067290d445154693b31c3dfeec
    MD5: 54f1892e8f7e3a7874645085379f29b6
    BLAKE2b-256: 182717d47d7eded8cfcbf6fc75b9281d5a0c5a6cb6da43161be539cd16eca154

  • torch_baidu_ctc-0.3.0-cp36-cp36m-manylinux1_x86_64.whl (2.7 MB, CPython 3.6m)
    SHA256: 84cfdcc044f044a1c71a120eac84e58f0e7a67d636433c2295bc0f15f3c2a97a
    MD5: fd7e267da5ab8537f6f3f6b67e693881
    BLAKE2b-256: 547feadf62c7e541430ecb7dab16adabb7bdb757dafcbfc584e61c20187ed407

  • torch_baidu_ctc-0.3.0-cp35-cp35m-manylinux1_x86_64.whl (2.7 MB, CPython 3.5m)
    SHA256: 7a0e4da260d6d6e84351095ae00a8e333bf9a579d3825b200d7c9de383364edc
    MD5: d379789ed013493635c9292ed3728fe1
    BLAKE2b-256: 5836989ef546ae94e276f4f31ec55d1ffa7d3b790b6f8d69cbe8d63aec137e5a

  • torch_baidu_ctc-0.3.0-cp27-cp27mu-manylinux1_x86_64.whl (2.7 MB, CPython 2.7mu)
    SHA256: db81e32f5e630ae91586328333222a9127eee747c9b72f3cca79c8b57729381b
    MD5: 87ff4f047aaf83cb9a779798c4d7e4e1
    BLAKE2b-256: 7d890275e493bcf97542363dba7481dd4e979a39a01753a605bd3de2b940516a

All files were uploaded via twine/1.13.0 from CPython 3.6.8.
