
PyTorch bindings for CUDA-Warp RNN-Transducer

import torch
from typing import AnyStr, Optional

# Signature and docstring of warp_rnnt.rnnt_loss (implementation omitted).
def rnnt_loss(log_probs: torch.FloatTensor,
              labels: torch.IntTensor,
              frames_lengths: torch.IntTensor,
              labels_lengths: torch.IntTensor,
              average_frames: bool = False,
              reduction: Optional[AnyStr] = None,
              blank: int = 0,
              gather: bool = False,
              fastemit_lambda: float = 0.0,
              compact: bool = False) -> torch.Tensor:
    """The CUDA-Warp RNN-Transducer loss.

    Args:
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V),
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels plus one,
            and V is the size of the label vocabulary (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.
        frames_lengths (torch.IntTensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        labels_lengths (torch.IntTensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        average_frames (bool, optional): Specifies whether the loss of each
            sample should be divided by its number of frames.
            Default: False.
        reduction (string, optional): Specifies the type of reduction.
            Default: None.
        blank (int, optional): Label used to represent the blank symbol.
            Default: 0.
        gather (bool, optional): Reduce memory consumption.
            Default: False.
        fastemit_lambda (float, optional): Weight of the FastEmit
            regularization; 0.0 disables it.
            Default: 0.0.
        compact (bool, optional): Use the compact layout. If True, the inputs
            must have the following shapes:
            log_probs: (STU, V)
            labels:    (SU,)
            where STU = sum(frames_lengths * (labels_lengths + 1))
                  SU  = sum(labels_lengths)
            Default: False.
    """
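To make the compact layout concrete, the sketch below computes STU and SU from the per-sample lengths and packs padded nested lists into compact rows. The helpers `compact_sizes` and `pack_compact` are hypothetical and not part of warp_rnnt; they use plain Python lists so they run without CUDA. The row ordering (each sample's T_i x (U_i + 1) valid cells in t-major order, samples concatenated in batch order) is an assumption that the docstring does not spell out, so verify it against the library before relying on it.

```python
from typing import List, Sequence, Tuple


def compact_sizes(frames_lengths: Sequence[int],
                  labels_lengths: Sequence[int]) -> Tuple[int, int]:
    """Return (STU, SU) as defined for the compact layout."""
    stu = sum(t * (u + 1) for t, u in zip(frames_lengths, labels_lengths))
    su = sum(labels_lengths)
    return stu, su


def pack_compact(log_probs: Sequence, labels: Sequence[Sequence[int]],
                 frames_lengths: Sequence[int],
                 labels_lengths: Sequence[int]) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical helper: flatten padded inputs into the compact layout.

    log_probs is a nested list [N][T][U][V]; labels is a nested list [N][U-1].
    ASSUMPTION: every sample contributes its T_i x (U_i + 1) valid cells in
    row-major (t, then u) order, and samples are concatenated in batch order.
    """
    rows: List[List[float]] = []   # becomes shape (STU, V)
    flat_labels: List[int] = []    # becomes shape (SU,)
    for n, (t_len, u_len) in enumerate(zip(frames_lengths, labels_lengths)):
        # Keep only the valid (t, u) cells; padding is dropped.
        for t in range(t_len):
            for u in range(u_len + 1):
                rows.append(list(log_probs[n][t][u]))
        # Keep only the real labels of this sample.
        flat_labels.extend(labels[n][:u_len])
    return rows, flat_labels
```

For example, with `frames_lengths = [3, 2]` and `labels_lengths = [2, 1]`, STU is 3 * 3 + 2 * 2 = 13 and SU is 3, regardless of how much padding the (N, T, U, V) tensor carries.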


Requirements

  • C++11 or C++14 compiler (tested with GCC 5.4).
  • Python: 3.5, 3.6, 3.7 (tested with version 3.6).
  • PyTorch >= 1.0.0 (tested with version 1.1.0).
  • CUDA Toolkit (tested with version 10.0).


Install

The following setup instructions compile the package locally from source.

From PyPI

pip install warp_rnnt

From GitHub

git clone
cd warp-rnnt/pytorch_binding
python setup.py install


Test

The package ships a unittest suite that checks both argument handling and outputs. Run it with:

python -m warp_rnnt.test
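When cross-checking outputs on tiny inputs, a slow CPU reference can help. The sketch below assumes the standard RNN-T formulation: a forward (alpha) recursion over the T x U lattice in log space, where a blank advances one frame, a label advances one position in the transcription, and a final blank at (T-1, U-1) terminates every alignment. The functions `rnnt_nll_reference` and `rnnt_loss_reference` are hypothetical and not part of warp_rnnt; they assume `log_probs` is already log-normalized over V, and they return unreduced per-sample losses (reduction is left to the caller).

```python
import math
from typing import List, Sequence


def _logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_nll_reference(log_probs: Sequence[Sequence[Sequence[float]]],
                       labels: Sequence[int], blank: int = 0) -> float:
    """Negative log-likelihood of `labels` for ONE sample.

    log_probs has shape [T][U][V] with U == len(labels) + 1.
    """
    T = len(log_probs)
    U = len(labels) + 1
    alpha = [[-math.inf] * U for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank at (t-1, u): advance one frame.
            from_blank = (alpha[t - 1][u] + log_probs[t - 1][u][blank]
                          if t > 0 else -math.inf)
            # Arrive by emitting label u-1 at (t, u-1): advance one label.
            from_label = (alpha[t][u - 1] + log_probs[t][u - 1][labels[u - 1]]
                          if u > 0 else -math.inf)
            alpha[t][u] = _logaddexp(from_blank, from_label)
    # A final blank at the last lattice cell ends every alignment.
    return -(alpha[T - 1][U - 1] + log_probs[T - 1][U - 1][blank])


def rnnt_loss_reference(log_probs, labels, frames_lengths, labels_lengths,
                        blank: int = 0, average_frames: bool = False) -> List[float]:
    """Per-sample losses for padded [N][T][U][V] inputs (no reduction)."""
    losses = []
    for n, (t_len, u_len) in enumerate(zip(frames_lengths, labels_lengths)):
        # Crop the padding using the per-sample lengths.
        lp = [row[:u_len + 1] for row in log_probs[n][:t_len]]
        nll = rnnt_nll_reference(lp, labels[n][:u_len], blank)
        losses.append(nll / t_len if average_frames else nll)
    return losses
```

As a sanity check with uniform probabilities, T = 2 frames, one label and V = 3 give two alignments of three emissions each, so the loss is -log(2 / 27) = log(13.5).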


Source Distribution

warp_rnnt-0.7.0.tar.gz (15.6 kB)
