
torch-cif

A fast, parallel, pure-PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" (https://arxiv.org/abs/1905.11235).

Installation

PyPI

pip install torch-cif

Locally

git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
pip install .

Usage

from typing import Dict, List, Optional

from torch import Tensor

def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235

    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length

    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. Expected to be the output of a sigmoid (i.e., in [0, 1]).
        beta (float): The threshold used to determine firing. Default: 1.0
        tail_thres (float): The threshold used to determine firing during
            tail handling. Default: 0.5
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): If True, alpha is not required to
            lie in [0, 1]. Default: False

    Returns (Dict[str, List[Tensor]]): Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in batch.
        alpha_sum: (N,) The sum of alpha for each element in batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) The leftover tail weight for each element in
            batch, returned during inference.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """

Note

  • This implementation uses cumsum and floor to determine the firing positions, and uses scatter to merge the weighted source features. The figure below demonstrates this concept using the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4).
[figure: scaled weights scattered into output positions via cumsum and floor]
  • Running the tests requires pip install hypothesis expecttest.
  • If beta != 1, our implementation differs slightly from Algorithm 1 in the paper [1]:
    • When a boundary is located, the original algorithm adds the last feature to the current integration with weight 1 - accumulation (line 11 in Algorithm 1), which causes a negative weight in the next integration when alpha < 1 - accumulation.
    • We instead use the weight beta - accumulation, so the weight carried into the next integration, alpha - (beta - accumulation), is always non-negative (firing implies alpha >= beta - accumulation).
  • Feel free to contact me if there are bugs in the code.
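For intuition, the firing rule can be sketched as a sequential loop in plain Python (scalar features, no batching, no tail handling). This is an illustrative reference only, not the parallel cumsum/floor/scatter implementation; it uses the beta - accumulation split described above.

```python
def cif_reference(features, alpha, beta=1.0):
    """Sequential integrate-and-fire over scalar features (illustration only)."""
    outputs = []
    acc = 0.0    # accumulated weight of the current (unfired) output
    integ = 0.0  # integrated feature of the current output
    for f, a in zip(features, alpha):
        while a > 1e-9:
            # the frame contributes at most beta - acc before a boundary fires
            take = min(a, beta - acc)
            integ += take * f
            acc += take
            a -= take
            if acc >= beta - 1e-9:  # boundary located: fire
                outputs.append(integ)
                integ, acc = 0.0, 0.0
    return outputs

# the scaled weights from the figure sum to 6.0, so 6 outputs fire
print(len(cif_reference([1., 2., 3., 4., 5.], [0.4, 1.8, 1.2, 1.2, 1.4])))  # prints 6
```

Note how the weight 1.8 > beta makes its frame fire twice, which is exactly the case the floor-of-cumsum trick handles in parallel.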

References

  1. CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition
  2. Exploring Continuous Integrate-and-Fire for Adaptive Simultaneous Speech Translation
