PyTorch bindings for CUDA-Warp Recurrent Neural Aligner

def rna_loss(
        log_probs,  # type: torch.FloatTensor
        labels,  # type: torch.IntTensor
        frames_lengths,  # type: torch.IntTensor
        labels_lengths,  # type: torch.IntTensor
        average_frames=False,  # type: bool
        reduction=None,  # type: Optional[AnyStr]
        blank=0,  # type: int
):
    """The CUDA-Warp Recurrent Neural Aligner loss.

    Args:
      log_probs (torch.FloatTensor): Input tensor of log-probabilities with
        shape (T, N, U, V), where T is the maximum number of input frames, N
        is the minibatch size, U is the maximum number of output labels and V
        is the vocabulary size (including the blank).
      labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
        reference labels for all samples in the minibatch.
      frames_lengths (torch.IntTensor): Tensor with shape (N,) representing the
        number of frames for each sample in the minibatch.
      labels_lengths (torch.IntTensor): Tensor with shape (N,) representing the
        length of the transcription for each sample in the minibatch.
      average_frames (bool, optional): Specifies whether the loss of each
        sample should be divided by its number of frames. Default: ``False``.
      reduction (string, optional): Specifies the type of reduction.
        Default: None.
      blank (int, optional): Index of the label that represents the blank
        symbol. Default: 0.
    """
    # type: (...) -> torch.Tensor
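
To make the quantity behind the docstring concrete, here is a minimal pure-Python sketch of the forward recursion for a single sample. It assumes the usual Recurrent Neural Aligner lattice, in which every frame emits either the blank (the label position stays the same) or the next reference label (the position advances by one). The names `rna_loss_reference` and `log_add` are hypothetical, and the sketch has not been checked against the CUDA kernel, whose boundary handling may differ.

```python
import math

NEG_INF = float("-inf")


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rna_loss_reference(log_probs, labels, blank=0):
    """Negative log-likelihood of `labels` for ONE sample (hypothetical helper).

    log_probs: nested list with shape (T, U, V), where U == len(labels) + 1.
    labels:    list of U - 1 label indices.
    """
    T = len(log_probs)
    U = len(labels) + 1
    if U - 1 > T:
        raise ValueError("at most one label per frame: need len(labels) <= T")
    if any(len(frame) != U for frame in log_probs):
        raise ValueError("every frame must hold U = len(labels) + 1 rows")

    # alpha[u]: log-probability of having emitted the first u labels
    # after consuming the frames processed so far.
    alpha = [NEG_INF] * U
    alpha[0] = 0.0
    for t in range(T):
        new_alpha = [NEG_INF] * U
        for u in range(U):
            # Blank at frame t: the label position does not move.
            stay = alpha[u] + log_probs[t][u][blank]
            # Label u-1 at frame t: the position advances from u-1 to u.
            emit = NEG_INF
            if u > 0:
                emit = alpha[u - 1] + log_probs[t][u - 1][labels[u - 1]]
            new_alpha[u] = log_add(stay, emit)
        alpha = new_alpha
    # All frames consumed and all labels emitted.
    return -alpha[U - 1]


# Uniform distribution over V = 3 symbols, T = 3 frames, 2 reference labels.
T, V = 3, 3
labels = [1, 2]
uniform = [[[math.log(1.0 / V)] * V for _ in range(len(labels) + 1)]
           for _ in range(T)]
print(round(rna_loss_reference(uniform, labels), 4))  # 2.1972
```

In the uniform case there are three valid alignments (choose which two of the three frames emit a label), each with probability (1/3)^3, so the total is 1/9 and the loss is log 9. A tiny check like this can be handy for comparing against the GPU result on small inputs, keeping in mind the assumptions stated above.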

Requirements

  • C++11 compiler (tested with GCC 5.4).
  • Python: 3.5, 3.6, 3.7 (tested with version 3.6).
  • PyTorch >= 1.0.0 (tested with version 1.1.0).
  • CUDA Toolkit (tested with version 10.0).

Install

No precompiled binaries are currently provided. Both installation methods below compile the package locally from source, so the requirements listed above must already be installed.

From PyPI

pip install warp_rna

From GitHub

git clone https://github.com/1ytic/warp-rna
cd warp-rna/pytorch_binding
python setup.py install
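
After either installation route, a quick way to confirm that the build succeeded is to check that the module can be found by the interpreter. The sketch below uses only the standard library. It assumes the package is importable under the name `warp_rna` (the name used with pip above), and `is_installed` is a hypothetical helper, not part of the package.

```python
import importlib.util


def is_installed(module_name="warp_rna"):
    """Return True if a top-level module with this name can be located."""
    # find_spec returns None when no such top-level module exists.
    return importlib.util.find_spec(module_name) is not None


print("warp_rna found:", is_installed())
```

This only verifies that the module is discoverable. It does not load the compiled extension or confirm that a CUDA device is available.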
