PyTorch bindings for CUDA-Warp RNN-Transducer
def rnnt_loss(log_probs: torch.FloatTensor,
              labels: torch.IntTensor,
              frames_lengths: torch.IntTensor,
              labels_lengths: torch.IntTensor,
              average_frames: bool = False,
              reduction: Optional[AnyStr] = None,
              blank: int = 0,
              gather: bool = False,
              fastemit_lambda: float = 0.0,
              compact: bool = False) -> torch.Tensor:
"""The CUDA-Warp RNN-Transducer loss.
Args:
log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V)
where N is the minibatch size, T is the maximum number of
input frames, U is the maximum number of output labels and V is
the vocabulary of labels (including the blank).
labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
reference labels for all samples in the minibatch.
frames_lengths (torch.IntTensor): Tensor with shape (N,) representing the
number of frames for each sample in the minibatch.
labels_lengths (torch.IntTensor): Tensor with shape (N,) representing the
length of the transcription for each sample in the minibatch.
average_frames (bool, optional): Specifies whether the loss of each
sample should be divided by its number of frames.
Default: False.
reduction (string, optional): Specifies the type of reduction.
Default: None.
blank (int, optional): label used to represent the blank symbol.
Default: 0.
gather (bool, optional): Reduce memory consumption.
Default: False.
fastemit_lambda (float, optional): FastEmit regularization
(https://arxiv.org/abs/2010.11148).
Default: 0.0.
compact (bool, optional): Use compact layout, if True, shapes of inputs should be:
log_probs: (STU, V)
labels: (SU, )
where STU = sum(frames_lengths * (labels_lengths+1))
SU = sum(labels_lengths)
"""
Requirements
- C++11 or C++14 compiler (tested with GCC 5.4).
- Python: 3.5, 3.6, 3.7 (tested with version 3.6).
- PyTorch >= 1.0.0 (tested with version 1.1.0).
- CUDA Toolkit (tested with version 10.0).
Install
Both of the following methods compile the package from the source code on the local machine.
From PyPI
pip install warp_rnnt
From GitHub
git clone https://github.com/1ytic/warp-rnnt
cd warp-rnnt/pytorch_binding
python setup.py install
Test
There is a unit test suite that checks both the argument validation and the outputs:
python -m warp_rnnt.test