PyTorch bindings for CUDA-Warp Recurrent Neural Aligner
def rna_loss(log_probs: torch.FloatTensor,
             labels: torch.IntTensor,
             frames_lengths: torch.IntTensor,
             labels_lengths: torch.IntTensor,
             average_frames: bool = False,
             reduction: Optional[AnyStr] = None,
             blank: int = 0) -> torch.Tensor:
    """The CUDA-Warp Recurrent Neural Aligner loss.

    Args:
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V),
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels plus one,
            and V is the size of the label vocabulary (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) holding the
            reference labels for all samples in the minibatch.
        frames_lengths (torch.IntTensor): Tensor with shape (N,) holding the
            number of frames of each sample in the minibatch.
        labels_lengths (torch.IntTensor): Tensor with shape (N,) holding the
            length of the transcription of each sample in the minibatch.
        average_frames (bool, optional): Whether the loss of each sample
            is divided by its number of frames.
            Default: False.
        reduction (string, optional): The reduction applied to the
            per-sample losses.
            Default: None.
        blank (int, optional): Label used to represent the blank symbol.
            Default: 0.
    """
Requirements
- C++11 or C++14 compiler (tested with GCC 5.4).
- Python: 3.5, 3.6, 3.7 (tested with version 3.6).
- PyTorch >= 1.0.0 (tested with version 1.1.0).
- CUDA Toolkit (tested with version 10.0).
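The version constraints above can be checked before building. The sketch below, built on the hypothetical helpers `parse_version` and `check_requirements`, compares the running interpreter and the installed PyTorch (if any) against the listed Python versions and the PyTorch >= 1.0.0 minimum. It does not check the C++ compiler or the CUDA Toolkit, and treating the tested Python versions as a hard range is an assumption.

```python
import re
import sys
from typing import List, Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.1.0', '1.1.0a0' or '2.1.0+cu118' into a tuple of ints."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_requirements(python: Tuple[int, int],
                       torch_version: Optional[str]) -> List[str]:
    """Return human-readable problems with respect to the listed requirements."""
    problems = []
    if not (3, 5) <= tuple(python) <= (3, 7):
        problems.append("Python %d.%d is not among the listed versions 3.5-3.7"
                        % tuple(python))
    if torch_version is None:
        problems.append("PyTorch is not installed")
    elif parse_version(torch_version) < (1, 0, 0):
        problems.append("PyTorch %s is older than 1.0.0" % torch_version)
    return problems


if __name__ == "__main__":
    try:
        import torch
        torch_version = str(torch.__version__)
    except ImportError:
        torch_version = None
    for problem in check_requirements(sys.version_info[:2], torch_version):
        print("warning:", problem)
```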
Install
The following setup instructions compile the package from the source code locally.
From PyPI
pip install warp_rna
From GitHub
git clone https://github.com/1ytic/warp-rna
cd warp-rna/pytorch_binding
python setup.py install
Source Distribution
warp_rna-0.3.0.tar.gz (8.4 kB)
File details
Details for the file warp_rna-0.3.0.tar.gz.
File metadata
- Download URL: warp_rna-0.3.0.tar.gz
- Upload date:
- Size: 8.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.6.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 79e9db3c46b21633f08e917c3af3b422541b2657f22397f90fa43b3753695df2
MD5 | e78eafdded12592397cc327965586c6b
BLAKE2b-256 | 99b1ddec69cddadf990c276536da586b773191b8de49be79630a9e17d61aa2e9
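To confirm that a downloaded archive matches the published SHA256 digest, a minimal standard-library sketch can hash the file in chunks and compare hex digests. The helper names `sha256_of` and `verify` are hypothetical, and the archive path is assumed to be wherever the file was saved.

```python
import hashlib
import hmac

# Published SHA256 digest for warp_rna-0.3.0.tar.gz (from the file hashes).
EXPECTED_SHA256 = "79e9db3c46b21633f08e917c3af3b422541b2657f22397f90fa43b3753695df2"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so it never has to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True only if the file's SHA256 hex digest equals the expected one."""
    return hmac.compare_digest(sha256_of(path), expected.lower())
```

For example, `verify("warp_rna-0.3.0.tar.gz")` should return True only for an unmodified copy of the source distribution.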