
Project description

fast-perceiver

Fast and memory-efficient PyTorch implementation of the Perceiver [1, 2, 3] attention architecture with FlashAttention [4, 5].

Features:

  • ⚡ More than 2x speedup over a naive implementation
  • ⚡ Sub-linear¹ memory usage with respect to input sequence length and linear usage with respect to the number of latent vectors
  • ⚡ Out-of-the-box support for rotary positional embeddings [6]
  • ⚡ Uses new and improved FlashAttention-2 implementation
  • ⚡ Supports arbitrary masking of inputs

¹ For the attention components. See Performance for more information.

Installation

pip install fast-perceiver

FlashAttention is installed as part of the dependencies and may have to be compiled locally. Note that this can take a while and may fail for unsupported GPUs or CUDA versions. Please refer to the linked repository for further information and help with the installation.
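
A quick sanity check after installation is to import the flash_attn package directly; this is just a minimal verification sketch, not part of the fast-perceiver API:

import torch
import flash_attn

# Confirms the compiled extension can be loaded and reports its version.
print(flash_attn.__version__)
# FlashAttention requires a CUDA-capable GPU.
print(torch.cuda.is_available())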

Usage

import torch

from fast_perceiver import Perceiver

in_dim = 256
out_dim = 128

latent_dim = 512
num_latents = 512

model = Perceiver(
    input_dim=in_dim,
    depth=8,
    out_dim=out_dim,
    num_latents=num_latents,
    latent_dim=latent_dim,
    cross_heads=1,
    cross_head_dim=64,
    latent_heads=8,
    latent_head_dim=64,
    cross_attn_dropout=0.0,
    self_attn_dropout=0.0,
    weight_tie_layers=False,
    gated_mlp=True,
).cuda()

x = torch.randn(32, 128, in_dim).cuda()
mask = torch.rand(32, 128) > 0.5

# FlashAttention only works with half-precision
# Don't forget to autocast!
with torch.autocast('cuda'):
    # `out_dim` specified; averages and projects output
    out = model(x)

    assert out.shape == (32, out_dim)

    # An element-wise input mask can be provided
    # All non-True elements will be ignored
    out = model(x, mask=mask)

    # The raw final latents will be returned when `return_embeddings=True`
    embeds = model(x, return_embeddings=True)

    assert embeds.shape == (32, num_latents, latent_dim)
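
To illustrate how the pooled output might be used downstream, here is a minimal, hypothetical training step. The classifier head, optimizer, and loss below are not part of fast-perceiver; the snippet reuses `model`, `x`, `mask`, and `out_dim` from above and follows the autocast requirement:

import torch
from torch import nn

# Hypothetical classification head on top of the pooled Perceiver output.
num_classes = 10
head = nn.Linear(out_dim, num_classes).cuda()

optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(head.parameters()), lr=3e-4
)
scaler = torch.cuda.amp.GradScaler()

labels = torch.randint(0, num_classes, (32,)).cuda()

optimizer.zero_grad()
with torch.autocast('cuda'):
    logits = head(model(x, mask=mask))  # (32, num_classes)
    loss = nn.functional.cross_entropy(logits, labels)

# GradScaler keeps half-precision training numerically stable.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()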

Performance

The Perceiver is designed as an attention architecture with sub-quadratic compute and memory complexity, in contrast to the quadratic requirements of a vanilla Transformer.

A naive implementation has $\mathcal{O}(nm)$ memory requirements for the cross-attention modules and $\mathcal{O}(n^2)$ for the self-attention (latent) blocks, where $n$ is the number of latent vectors (a fixed hyperparameter), $m$ is the number of input elements, and generally $m \gg n$.

FlashAttention allows a memory usage reduction to $\mathcal{O}(n)$ for the cross-attention layers and $\mathcal{O}(n)$ for the self-attention layers. However, this only accounts for the computation of the attention mechanisms. The input sequence and corresponding keys and values within the cross-attention modules will still grow with $m$.
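
As a rough back-of-the-envelope illustration (not a measurement), the naive cross-attention materializes an $n \times m$ score matrix per head, while the keys and values only grow linearly in $m$; the numbers below are assumed example values:

# Element counts for a single cross-attention head in fp16 (2 bytes/element).
n, m, d = 512, 65_536, 64   # latents, input elements, head dimension

naive_scores = n * m        # n x m score matrix materialized by a naive implementation
kv_elements = 2 * m * d     # keys and values still grow linearly with m

print(f"naive score matrix: {naive_scores * 2 / 2**20:.0f} MiB")  # ~64 MiB
print(f"keys + values:      {kv_elements * 2 / 2**20:.0f} MiB")   # ~16 MiB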

Until the latter starts to dominate memory usage, this implementation allows the input sequence length to be scaled up considerably. For instance, 16x larger input lengths can be processed compared to perceiver-pytorch on an RTX 4090, with all other hyperparameters fixed (see run_benchmarks.py for the exact configuration).

Benchmarks

Benchmarks against other implementations (currently only perceiver-pytorch) can be performed with:

python run_benchmarks.py

The script writes a benchmark_results.csv file. The create_plots.py script can then be used to generate plots from it.
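
Assuming create_plots.py picks up benchmark_results.csv from the working directory (inferred from the description above, not verified here), a full run might look like:

python run_benchmarks.py   # writes benchmark_results.csv
python create_plots.py     # reads the CSV and renders the plots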

The following plots have been created using a single RTX 4090 with 24GB of VRAM.

Figure: Benchmark results on speedup

Figure: Benchmark results on memory usage reduction

Note: The batch size for each configuration corresponds to the smallest value that works for all implementations. Especially for longer sequence lengths, this leads to decreasing GPU utilization and thus a lower speedup than theoretically possible. There are some ways to fix this, but my attempts so far have led to distorted results.

Acknowledgements

The implementation is inspired by lucidrains' Perceiver implementation and would not have been possible without Tri Dao's FlashAttention.

Planned features

These are a few features that are either planned or work in progress. If you have urgent demand for any of them, feel free to open an issue:

  • Perceiver IO [2] [WIP]
  • Perceiver AR [3] (or an AR demo in general)
  • Demos [WIP]
  • Tests [WIP]
  • Allow more flexible cross-attention configurations
  • Benchmarks against other Perceiver implementations, e.g. DeepMind's or Krasser's

References

[1] Jaegle, Andrew, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, and Joao Carreira. “Perceiver: General Perception with Iterative Attention.” arXiv, June 22, 2021. http://arxiv.org/abs/2103.03206.

[2] Jaegle, Andrew, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, et al. “Perceiver IO: A General Architecture for Structured Inputs & Outputs.” arXiv, March 15, 2022. http://arxiv.org/abs/2107.14795.

[3] Hawthorne, Curtis, Andrew Jaegle, Cătălina Cangea, Sebastian Borgeaud, Charlie Nash, Mateusz Malinowski, Sander Dieleman, et al. “General-Purpose, Long-Context Autoregressive Modeling with Perceiver AR.” arXiv, June 14, 2022. http://arxiv.org/abs/2202.07765.

[4] Dao, Tri, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv, June 23, 2022. https://doi.org/10.48550/arXiv.2205.14135.

[5] Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv, July 17, 2023. https://doi.org/10.48550/arXiv.2307.08691.

[6] Su, Jianlin, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” arXiv, August 8, 2022. https://doi.org/10.48550/arXiv.2104.09864.



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fast_perceiver-0.1.2.tar.gz (5.8 kB, Source)

Built Distribution

fast_perceiver-0.1.2-py3-none-any.whl (6.4 kB, Python 3)

File details

Details for the file fast_perceiver-0.1.2.tar.gz.

File metadata

  • Download URL: fast_perceiver-0.1.2.tar.gz
  • Upload date:
  • Size: 5.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.10.9 Linux/5.15.0-76-generic

File hashes

Hashes for fast_perceiver-0.1.2.tar.gz
  • SHA256: 5de25558b4bae431d68a35f56a81439669db41fc4a8e4ff42be155f860503f64
  • MD5: ff49593405594834b24ea982c80ecc0b
  • BLAKE2b-256: 701ceec4c5b968be7970ed40f74d15d889486478a2e5d24514407761249a4954


File details

Details for the file fast_perceiver-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: fast_perceiver-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 6.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.10.9 Linux/5.15.0-76-generic

File hashes

Hashes for fast_perceiver-0.1.2-py3-none-any.whl
  • SHA256: 1bce471958828f27adc758e7843b8fc9b4096123db47fd217a2e472d632a3824
  • MD5: 37033590bc179203c4e33b4d1646b035
  • BLAKE2b-256: fc93cb792ada5db87c1dd3b30125889c561ab0b475246484f803c61dc2a9a6ff

