A fast parallel implementation of continuous integrate-and-fire (CIF) https://arxiv.org/abs/1905.11235
torch-cif
A fast, parallel, pure PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" https://arxiv.org/abs/1905.11235.
Installation
PyPI
pip install torch-cif
Locally
git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
python setup.py install
Usage
def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235
    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length
    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. Expected to be the output of a sigmoid function.
        beta (float): The threshold used to determine firing.
        tail_thres (float): The threshold used to determine firing during
            tail handling.
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): Whether to check if 0 <= alpha <= 1.
    Returns -> Dict[str, List[Tensor]]: Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in the batch.
        alpha_sum: (N,) The sum of alpha for each element in the batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) During inference, return the tail.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """
Note
- This implementation uses `cumsum` and `floor` to determine the firing positions, and uses `scatter` to merge the weighted source features. The accompanying figure demonstrates this concept using the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4).
- Running the tests requires `pip install hypothesis expecttest`.
- If `beta != 1`, our implementation differs slightly from Algorithm 1 in the paper [1]:
  - When a boundary is located, the original algorithm adds the last feature to the current integration with weight `1 - accumulation` (line 11 in Algorithm 1), which causes negative weights in the next integration when `alpha < 1 - accumulation`.
  - We use `beta - accumulation`, which means the weight in the next integration, `alpha - (beta - accumulation)`, is never negative.
- Feel free to contact me if there are bugs in the code.
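The cumsum/floor/scatter idea and the `beta - accumulation` split can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the package's code: `fire_parallel` and `fire_sequential` are hypothetical names, exact `Fraction` arithmetic stands in for tensors (so no `eps` is needed), only the weights are tracked rather than the features, and tail handling with `tail_thres` is left out. `fire_parallel` finds firing positions with a cumulative sum and floor division, then scatters each source weight into its target slots. `fire_sequential` is a step-by-step loop that splits a weight at each boundary with `beta - accumulation`.

```python
from fractions import Fraction
from itertools import accumulate


def fire_parallel(alpha, beta=Fraction(1)):
    """Cumsum/floor/scatter sketch. Returns slots[t][s]: weight of source s in target t."""
    csum = list(accumulate(alpha))               # running sum of the weights
    right = [c // beta for c in csum]            # floor(cumsum / beta)
    left = [0] + right[:-1]                      # slot that is open when s starts
    slots = [[Fraction(0)] * len(alpha) for _ in range(right[-1] + 1)]
    for s, a in enumerate(alpha):
        fires = right[s] - left[s]               # boundaries crossed by source s
        # Weight left over after the last boundary crossed by s.
        right_w = csum[s] - right[s] * beta if fires > 0 else Fraction(0)
        # Weight that goes to the slot that was already open.
        left_w = a - right_w - max(fires - 1, 0) * beta
        slots[left[s]][s] += left_w
        slots[right[s]][s] += right_w
        for t in range(left[s] + 1, right[s]):   # slots filled entirely by s
            slots[t][s] += beta
    return slots


def fire_sequential(alpha, beta=Fraction(1)):
    """Step-by-step reference that splits at a boundary with `beta - accumulation`."""
    slots = [[Fraction(0)] * len(alpha)]
    acc = Fraction(0)
    for s, a in enumerate(alpha):
        remaining = a
        while acc + remaining >= beta:           # a boundary is located
            used = beta - acc                    # weight that completes the slot
            slots[-1][s] += used
            remaining -= used                    # stays >= 0 by the loop condition
            acc = Fraction(0)
            slots.append([Fraction(0)] * len(alpha))
        slots[-1][s] += remaining
        acc += remaining
    return slots


# Scaled weight sequence from the first note.
alpha = [Fraction(x) for x in ("0.4", "1.8", "1.2", "1.2", "1.4")]
# Six full slots (weight 1 each) plus an empty trailing slot for the tail.
row_sums = [sum(row) for row in fire_parallel(alpha)]

# Third note: at a boundary with beta != 1, compare the two split rules.
b, acc, a = Fraction(1, 2), Fraction(3, 10), Fraction(2, 5)  # acc + a >= b
carry_paper = a - (1 - acc)      # `1 - accumulation` rule: -3/10
carry_variant = a - (b - acc)    # `beta - accumulation` rule: 1/5
```

In this sketch the parallel and sequential versions produce the same slot weights, and the `1 - accumulation` rule carries a negative weight into the next integration while the `beta - accumulation` rule does not.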
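References
- [1] CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition. https://arxiv.org/abs/1905.11235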