torch-cif
A fast, parallel, pure-PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" (https://arxiv.org/abs/1905.11235).
Installation
PyPI
pip install torch-cif
Locally
git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
python setup.py install
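To confirm the install worked, a minimal smoke test (the importable module name `torch_cif` is assumed from the repository name):

```python
# Minimal import check; the module name torch_cif is assumed from the repository name.
import torch_cif
print(torch_cif.__file__)
```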
Usage
def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235

    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length

    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. They are expected to be post-sigmoid (i.e. in [0, 1]).
        beta (float): The threshold used to determine firing.
        tail_thres (float): The threshold used to determine firing for tail handling.
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): Whether alpha may fall outside [0, 1];
            if False, alpha is checked to satisfy 0 <= alpha <= 1.

    Returns -> Dict[str, List[Tensor]]: Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in the batch.
        alpha_sum: (N,) The sum of alpha for each element in the batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) The tail weights; returned only during inference.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, i.e. floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """
Note
- This implementation uses `cumsum` and `floor` to determine the firing positions, and uses `scatter` to merge the weighted source features. The figure below demonstrates this concept using the scaled weight sequence `(0.4, 1.8, 1.2, 1.2, 1.4)`; a short numeric sketch of the same bookkeeping follows this list.
- Running the tests requires `pip install hypothesis expecttest`.
- If `beta != 1`, our implementation differs slightly from Algorithm 1 in the paper [1]:
    - When a boundary is located, the original algorithm adds the last feature to the current integration with weight `1 - accumulation` (line 11 in Algorithm 1), which causes negative weights in the next integration when `alpha < 1 - accumulation`.
    - We use `beta - accumulation` instead, so the weight carried into the next integration, `alpha - (beta - accumulation)`, is always positive.
- Feel free to contact me if there are bugs in the code.
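To make the cumsum/floor bookkeeping from the first note concrete, here is a small illustrative sketch; it is not the library's internal code, and the epsilon added before the floor is only to keep this toy example stable when a cumulative sum lands exactly on a boundary.

```python
import torch

beta = 1.0
alpha = torch.tensor([0.4, 1.8, 1.2, 1.2, 1.4])   # scaled weights from the note above
csum = torch.cumsum(alpha, dim=0)                 # ~[0.4, 2.2, 3.4, 4.6, 6.0]

# A tiny epsilon keeps the final sum (exactly 6.0 here) from being floored down
# by floating-point error; purely for this illustration.
right_indices = torch.floor(csum / beta + 1e-4)   # [0., 2., 3., 4., 6.]
right_weights = csum - right_indices * beta       # ~[0.4, 0.2, 0.4, 0.6, 0.0] carried to the next output

# Number of boundaries fired inside each source frame: the second frame (weight 1.8
# plus the carried 0.4) completes two outputs on its own.
fires = right_indices.diff(prepend=torch.zeros(1))
print(fires)  # tensor([0., 2., 1., 1., 2.])
```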
References
[1] Linhao Dong and Bo Xu. "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition." https://arxiv.org/abs/1905.11235