Warp RNNT loss ported to Numba for faster experimentation
RNNT loss in PyTorch - Numba JIT compiled (warprnnt_numba)
Warp RNN Transducer Loss for ASR in PyTorch, ported from HawkAaron/warp-transducer and a replica of the stable version in the NVIDIA NeMo repository.
NOTE: The code here contains experimental extensions and may be unstable; use the version in NeMo for a long-term supported version of the RNNT loss for PyTorch.
Supported Features
Currently supports:
- WarpRNNT loss in PyTorch for CPU / CUDA (JIT compiled)
- FastEmit
- Gradient clipping (from TorchAudio); both FastEmit and clipping are shown in the sketch below
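The sketch below illustrates how these features map onto the RNNTLossNumba constructor arguments used later in this README. The values are illustrative, not recommendations; the clamp semantics are assumed to follow TorchAudio (a positive value clamps gradients to [-clamp, clamp], while 0.0, as in the Usage example, leaves them unclipped).
import warprnnt_numba

# Illustrative construction only.
loss_func = warprnnt_numba.RNNTLossNumba(
    blank=0,                # index of the RNNT blank token in your vocabulary
    reduction='sum',        # 'none' is also used in the Usage example below
    fastemit_lambda=0.001,  # FastEmit regularization strength; any float >= 0.0
    clamp=1.0,              # assumed to clamp gradients to [-1.0, 1.0] (TorchAudio-style)
)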
Installation
You will need PyTorch (usually the latest version should be used), plus an installation of Numba in a Conda environment (a pip-only environment is untested, but may work).
# Follow installation instructions to install pytorch from website (with cuda if required)
conda install -c conda-forge numba
# or, to update to the latest version:
conda update -c conda-forge numba
# Then install this library
pip install --upgrade git+https://github.com/titu1994/warprnnt_numba.git
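As an optional smoke test (not part of the official instructions), check that the package imports cleanly:
python -c "import warprnnt_numba"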
Usage
Import warprnnt_numba and use RNNTLossNumba. If attempting to use the CUDA version of the loss, it is advisable to first test that your installed CUDA version is compatible with your numba version using numba_utils.
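A minimal pre-flight check using Numba's own API (the package's numba_utils module offers related helpers, but their exact names are not documented here):
from numba import cuda

# Verify numba can detect and initialize CUDA before constructing the CUDA loss.
print("Numba CUDA available:", cuda.is_available())
if cuda.is_available():
    cuda.detect()  # lists detected GPUs and whether numba supports them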
A very slow numpy/pytorch explicit-loop based loss implementation is also included, for verifying that results are exactly correct.
import torch
import numpy as np
import warprnnt_numba
# Define the loss function
fastemit_lambda = 0.001 # any float >= 0.0
loss_pt = warprnnt_numba.RNNTLossNumba(blank=4, reduction='sum', fastemit_lambda=fastemit_lambda)
# --------------
# Example usage
device = "cuda"
torch.random.manual_seed(0)
# Assume Batchsize=2, Acoustic Timesteps = 8, Label Timesteps = 5 (including BLANK=BOS token),
# and Vocabulary size of 5 tokens (including RNNT BLANK)
acts = torch.randn(2, 8, 5, 5, device=device, requires_grad=True)
sequence_length = torch.tensor([5, 8], dtype=torch.int32, device=device)  # acoustic sequence length. One element must be == acts.shape[1].
# Let 0 be MASK/PAD value, 1-3 be token ids, and 4 represent RNNT BLANK token
# The BLANK token is overloaded for BOS token as well here, but can be different token.
# Let the first sample be padded with 0 (actual length = 3). The loss is computed according to the supplied `label_lengths`,
# and gradients for the 4th label index onwards (0-based indexing) will be zero.
labels = torch.tensor([[4, 1, 1, 3, 0], [4, 2, 2, 3, 1]], dtype=torch.int32, device=device)
label_lengths = torch.tensor([3, 4], dtype=torch.int32, device=device)  # Lengths here must be WITHOUT the BOS token.
# On CUDA, log_softmax is computed internally and efficiently (preserving memory and speed).
# On CPU, log_softmax is computed explicitly; this is done automatically for you inside forward() of the loss.
# The -1-th (i.e., last) vocab index is the RNNT blank token here.
loss_func = warprnnt_numba.RNNTLossNumba(blank=4, reduction='none', fastemit_lambda=0.0, clamp=0.0)
loss = loss_func(acts, labels, sequence_length, label_lengths)
print("Loss :", loss)
loss.sum().backward()
# When inspecting the gradients, look at grads[0] -
# Since it was padded in T (sequence_length=5 < T=8), there are gradients only for grads[0, :5, :, :].
# Since it was padded in U (label_lengths=3+1 < U=5), there are gradients only for grads[0, :5, :3+1, :].
grads = acts.grad
print("Gradients of activations :")
print(grads)
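The same loss can be computed on CPU with an identical call signature; a minimal sketch, reusing the shapes above and assuming the constructor defaults for fastemit_lambda and clamp:
# CPU sketch: log_softmax handling happens inside forward(), as noted above.
acts_cpu = torch.randn(2, 8, 5, 5, requires_grad=True)  # defaults to CPU
sequence_length_cpu = torch.tensor([5, 8], dtype=torch.int32)
labels_cpu = torch.tensor([[4, 1, 1, 3, 0], [4, 2, 2, 3, 1]], dtype=torch.int32)
label_lengths_cpu = torch.tensor([3, 4], dtype=torch.int32)

cpu_loss_func = warprnnt_numba.RNNTLossNumba(blank=4, reduction='sum')
cpu_loss = cpu_loss_func(acts_cpu, labels_cpu, sequence_length_cpu, label_lengths_cpu)
cpu_loss.backward()  # scalar (reduction='sum'), so backward() needs no argument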
Tests
If no GPU is available, the tests perform CPU-only checks. If a GPU is present, all tests are also run once on cuda:0.
pytest tests/
Requirements
- pytorch >= 1.10. Older versions might work, but are untested.
- numba - minimum required version is 0.53.0; 0.54+ is preferred.
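To confirm the installed versions meet these minimums (a convenience check, not part of the project's documentation):
python -c "import torch, numba; print('torch', torch.__version__, '| numba', numba.__version__)"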