Warp RNNT loss ported to Numba for faster experimentation
Project description
RNNT loss in PyTorch - Numba JIT compiled (warprnnt_numba)
Warp RNN Transducer loss for ASR in PyTorch, ported from HawkAaron/warp-transducer and replicating the stable version in the NVIDIA Neural Modules repository (NVIDIA NeMo).
NOTE: The code here contains experimental extensions and may be unstable; use the version in NeMo for a long-term supported RNNT loss in PyTorch.
Supported Features
Currently supports:
- WarpRNNT loss in PyTorch for CPU / CUDA (JIT compiled)
- FastEmit
Installation
You will need PyTorch (preferably the latest version) and Numba installed in a Conda environment (a pip-only environment is untested but may work).
# Follow the installation instructions on the PyTorch website to install torch (with CUDA if required)
conda install -c conda-forge numba
# or, to get the latest version:
conda update -c conda-forge numba
# Then install this library
pip install --upgrade git+https://github.com/titu1994/warprnnt_numba.git
Usage
Import warprnnt_numba and use RNNTLossNumba. If attempting to use the CUDA version of the loss, it is advisable to first test that your installed CUDA version is compatible with your Numba version using numba_utils.
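As a quick, library-independent sanity check (using Numba's own CUDA API rather than the numba_utils helpers, whose exact function names are not listed here), you can verify that Numba can see a usable CUDA installation:

from numba import cuda

# True if this Numba build can locate a compatible CUDA driver / toolkit
print("Numba CUDA available:", cuda.is_available())

# Lists the detected GPUs and whether this Numba build supports them
cuda.detect()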
A very slow, explicit-loop NumPy/PyTorch loss implementation is also included for verifying that results are exactly correct (a sketch of the underlying recursion follows the example below).
import torch
import numpy as np
import warprnnt_numba
# Define the loss function
fastemit_lambda = 0.001 # any float >= 0.0
loss_pt = warprnnt_numba.RNNTLossNumba(blank=4, reduction='sum', fastemit_lambda=fastemit_lambda)
# --------------
# Example usage
device = "cuda"
torch.random.manual_seed(0)
# Assume Batchsize=2, Acoustic Timesteps = 8, Label Timesteps = 5 (including EOS token),
# and Vocabulary size of 5 tokens (including RNNT BLANK)
acts = torch.randn(2, 8, 5, 5, device=device, requires_grad=True)
sequence_length = torch.tensor([5, 8], dtype=torch.int32, device=device) # acoustic sequence length. One element must be == acts.shape[1].
# Let 0 be the MASK value, 1-3 be token ids, and 4 represent the RNNT BLANK token.
# The BLANK token is overloaded as the EOS token here as well, but it can be a different token.
# Let the first sample be padded with 0 (actual length = 3). Loss is computed according to the supplied `label_lengths`,
# so padded positions (from the 4th index onwards, 0-based indexing) do not contribute to the loss or gradients.
labels = torch.tensor([[1, 1, 3, 4, 0], [2, 2, 3, 1, 4]], dtype=torch.int32, device=device)
label_lengths = torch.tensor([3, 4], dtype=torch.int32, device=device) # Lengths here must be WITHOUT the EOS token.
# On CUDA, log_softmax is computed internally by the loss, efficiently (preserving memory and speed).
# On CPU, compute it explicitly beforehand, sacrificing some memory.
if device == 'cpu':
    acts = torch.log_softmax(acts, dim=-1)
loss_func = warprnnt_numba.RNNTLossNumba(blank=4, reduction='none', fastemit_lambda=0.0)  # the last (-1-th) vocab index is the RNNT blank token in this example
loss = loss_func(acts, labels, sequence_length, label_lengths)
print("Loss :", loss)
loss.sum().backward()
grads = acts.grad
print("Gradients of activations :")
print(grads)
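For reference, the quantity that the explicit-loop implementation verifies is the standard RNNT negative log-likelihood, computed via the forward (alpha) recursion. Below is a minimal single-sample, log-space sketch of that recursion, written here purely for illustration (it is not the reference implementation shipped with this package, assumes the inputs are already log-softmaxed, and ignores FastEmit):

import torch

def rnnt_forward_nll(log_probs, labels, blank):
    # log_probs: (T, U + 1, V) log-softmaxed joint network outputs for ONE sample
    # labels:    (U,) target token ids, without any blank / EOS appended
    T, U1, _ = log_probs.shape
    alpha = torch.full((T, U1), float("-inf"), dtype=log_probs.dtype, device=log_probs.device)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U1):
            if t > 0:
                # reach (t, u) by emitting a blank at (t - 1, u)
                alpha[t, u] = torch.logaddexp(alpha[t, u], alpha[t - 1, u] + log_probs[t - 1, u, blank])
            if u > 0:
                # reach (t, u) by emitting label u - 1 at (t, u - 1)
                alpha[t, u] = torch.logaddexp(alpha[t, u], alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]])
    # every valid alignment ends with a final blank emitted at (T - 1, U)
    return -(alpha[T - 1, U1 - 1] + log_probs[T - 1, U1 - 1, blank])

For the batched tensors above, this would be applied per sample to acts[b, :sequence_length[b], :label_lengths[b] + 1] (after log_softmax) together with labels[b, :label_lengths[b]].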
Tests
Tests perform CPU-only checks if no GPU is present. If a GPU is available, all tests are also run once on cuda:0.
pytest tests/
Requirements
- PyTorch >= 1.10 (older versions might work but are untested).
- Numba - minimum required version is 0.53.0, preferred is 0.54+.
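A quick sanity check of the installed versions (a minimal sketch; the packaging module used here ships with most Python environments but is not a dependency of this library):

import numba
import torch
from packaging import version  # not required by warprnnt_numba; commonly available

print("torch:", torch.__version__, "| numba:", numba.__version__)
assert version.parse(numba.__version__) >= version.parse("0.53.0"), "Numba >= 0.53.0 required, 0.54+ preferred"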
Hashes for warprnnt_numba-0.1.0-py2.py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | fa721fe720356fe376b4ff493abf1eee088ddb3d6bfc9b4bbaf5102d7e5ce987 |
| MD5 | 3541366cdc76ba5370807f4a477f8c18 |
| BLAKE2b-256 | 9eac466596f02f1f64949fb47f08f2ef223bb9db3d7daa24275be90123a09229 |