
PyTorch bindings for CUDA-Warp Recurrent Neural Aligner


import torch
from typing import AnyStr, Optional


def rna_loss(log_probs: torch.FloatTensor,
             labels: torch.IntTensor,
             frames_lengths: torch.IntTensor,
             labels_lengths: torch.IntTensor,
             average_frames: bool = False,
             reduction: Optional[AnyStr] = None,
             blank: int = 0) -> torch.Tensor:
    """The CUDA-Warp Recurrent Neural Aligner loss.

    Args:
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V)
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels plus one
            (matching the (N, U-1) shape of labels) and V is the size of the
            label vocabulary (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.
        frames_lengths (torch.IntTensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        labels_lengths (torch.IntTensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        average_frames (bool, optional): Specifies whether the loss of each
            sample should be divided by its number of frames.
            Default: False.
        reduction (string, optional): Specifies the type of reduction.
            Default: None.
        blank (int, optional): Label used to represent the blank symbol.
            Default: 0.
    """
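The docstring lists the inputs but not the recursion being evaluated. Below is a minimal pure-Python reference sketch for a single sample, assuming the standard Recurrent Neural Aligner formulation: every frame emits exactly one symbol, where blank keeps the label position and a label advances it, and the loss is the negative log of the summed probability of all such alignments. `rna_nll_reference` and `log_add` are hypothetical helpers, not part of `warp_rna`, and the boundary conventions here may differ from the CUDA kernel; the sketch is only meant for reasoning about small cases.

```python
import math
from typing import Sequence

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b)) that tolerates -inf."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rna_nll_reference(log_probs: Sequence[Sequence[Sequence[float]]],
                      labels: Sequence[int],
                      frames_length: int,
                      labels_length: int,
                      blank: int = 0,
                      average_frames: bool = False) -> float:
    """Hypothetical reference RNA loss for ONE sample (assumed formulation).

    log_probs[t][u][v]: log-probability of emitting symbol v at frame t
    after u labels have been emitted (shape T x U x V, U >= labels_length + 1).
    """
    T, L = frames_length, labels_length
    # alpha[u]: log-probability of having emitted the first u labels
    # after the frames processed so far (starts with zero frames).
    alpha = [0.0] + [NEG_INF] * L
    for t in range(T):
        new_alpha = [NEG_INF] * (L + 1)
        for u in range(L + 1):
            # Frame t emits blank: the label position stays at u.
            stay = alpha[u] + log_probs[t][u][blank]
            # Frame t emits label u-1: the label position advances to u.
            emit = (alpha[u - 1] + log_probs[t][u - 1][labels[u - 1]]
                    if u > 0 else NEG_INF)
            new_alpha[u] = log_add(stay, emit)
        alpha = new_alpha
    # All T frames consumed and all L labels emitted; inf if T < L.
    nll = -alpha[L]
    if average_frames:
        # Mirrors the documented average_frames option (assumes T >= 1).
        nll /= T
    return nll
```

For example, with T = 3 frames, a single label and every log-probability equal to log 0.5, there are three alignments (the label can be emitted at any of the three frames), each with probability 1/8, so the sketch returns -log(3/8), roughly 0.98.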

Requirements

  • C++11 or C++14 compiler (tested with GCC 5.4).
  • Python: 3.5, 3.6, 3.7 (tested with version 3.6).
  • PyTorch >= 1.0.0 (tested with version 1.1.0).
  • CUDA Toolkit (tested with version 10.0).

Install

Both of the following methods compile the package locally from source, so a compiler, PyTorch and the CUDA Toolkit matching the requirements above must be available.

From PyPI

pip install warp_rna

From GitHub

git clone https://github.com/1ytic/warp-rna
cd warp-rna/pytorch_binding
python setup.py install
