fast-perceiver

Fast and memory efficient PyTorch implementation of the Perceiver [1, 2, 3] attention architecture with FlashAttention [4, 5].

Features:

  • ⚡ More than 2x speedup over naive implementation
  • ⚡ Sub-linear¹ memory usage with respect to input sequence length and linear usage with respect to the number of latent vectors
  • ⚡ Out-of-the-box support for rotary positional embeddings [6]
  • ⚡ Uses new and improved FlashAttention-2 implementation
  • ⚡ Supports arbitrary masking of inputs

¹ For the attention components. See Performance for more information.

Installation

pip install fast-perceiver

FlashAttention will be installed as one of the dependencies and may have to be compiled locally. Note that this can take a while and may fail for unsupported GPUs or CUDA versions. Please refer to the FlashAttention repository for further information and help with the installation.
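Before or after installing, it can help to check what the local environment actually provides. The following is a minimal sketch using only the standard library plus PyTorch if it is present; `check_environment` is a hypothetical helper and not part of fast-perceiver. Compare the reported compute capability with the requirements listed in the FlashAttention repository.

```python
import importlib.util


def check_environment():
    """Collect basic facts that matter for a FlashAttention install.

    Hypothetical helper for illustration; not part of fast-perceiver.
    """
    report = {
        "torch": importlib.util.find_spec("torch") is not None,
        "flash_attn": importlib.util.find_spec("flash_attn") is not None,
        "cuda_available": False,
        "compute_capability": None,
    }
    if report["torch"]:
        import torch

        report["cuda_available"] = torch.cuda.is_available()
        if report["cuda_available"]:
            # (major, minor) tuple of the current CUDA device
            report["compute_capability"] = torch.cuda.get_device_capability()
    return report


print(check_environment())
```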

Usage

import torch

from fast_perceiver import Perceiver

in_dim = 256
out_dim = 128
seq_len = 128

latent_dim = 512
num_latents = 512

model = Perceiver(
    input_dim=in_dim,
    depth=8,
    out_dim=out_dim,
    num_latents=num_latents,
    latent_dim=latent_dim,
    cross_heads=1,
    cross_head_dim=64,
    cross_rotary_emb_dim=0,
    cross_attn_dropout=0.0,
    latent_heads=8,
    latent_head_dim=64,
    latent_rotary_emb_dim=0,
    latent_attn_dropout=0.0,
    weight_tie_layers=False,
    gated_mlp=True,
)

# Note: FlashAttention only supports half precision (float16 or bfloat16).
# Either cast the model explicitly, as done here, or alternatively use torch.autocast.
model.to('cuda', torch.float16)

x = torch.randn(32, seq_len, in_dim, dtype=torch.float16, device='cuda')

seq_lens = torch.randint(1, seq_len + 1, (32,), device=x.device)
mask = torch.arange(seq_len, device=x.device)[None, :] < seq_lens[:, None]


# With `out_dim` specified, the final latents are averaged and projected to `out_dim`
out = model(x)

assert out.shape == (32, out_dim)

# An element-wise input mask can be provided
# All elements that are not True in the mask will be ignored
out = model(x, mask=mask)

# The raw final latents will be returned when `return_embeddings=True`
embeds = model(x, return_embeddings=True)

assert embeds.shape == (32, num_latents, latent_dim)
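The comment in the example above mentions `torch.autocast` as an alternative to casting the model. A minimal sketch of that variant follows; `run_with_autocast` is a hypothetical wrapper, and it assumes the model's forward pass accepts `mask` as shown above. With autocast, the model parameters can stay in float32 while eligible operations run in the requested half-precision dtype.

```python
import torch


def run_with_autocast(model, x, mask=None, dtype=torch.float16):
    """Run a forward pass under torch.autocast instead of casting the model.

    Hypothetical wrapper for illustration; not part of fast-perceiver.
    """
    # The autocast device type is taken from the input ('cuda' for GPU tensors).
    with torch.autocast(device_type=x.device.type, dtype=dtype):
        return model(x, mask=mask)
```

With the model and inputs from the example above (model moved to 'cuda' without a dtype cast, `x` created in float32), this would be called as `out = run_with_autocast(model, x, mask=mask)`.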

Performance

The Perceiver is already designed as an attention architecture with sub-quadratic compute and memory complexity, in contrast to the quadratic requirements of a vanilla Transformer.

A naive implementation has $\mathcal{O}(nm)$ memory requirements for the cross-attention modules and $\mathcal{O}(n^2)$ for the self-attention (latent) blocks, where $n$ is the number of latent vectors (a fixed hyperparameter) and $m$ is the number of input elements. In general, $m \gg n$.

FlashAttention reduces memory usage to $\mathcal{O}(n)$ for the cross-attention layers and $\mathcal{O}(n)$ for the self-attention layers. However, this only accounts for the computation of the attention mechanisms. The input sequence and the corresponding keys and values within the cross-attention modules will still grow with $m$.

Until the latter starts to dominate memory usage, this implementation allows the input sequence length to be scaled considerably. For instance, 16x larger input lengths can be achieved in comparison to perceiver-pytorch on an RTX 4090, keeping the other hyperparameters fixed (see run_benchmarks.py for the exact configuration).
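To make the asymptotic terms above concrete, the following back-of-the-envelope sketch compares the three quantities for a single cross-attention layer. It is a rough estimate under stated assumptions, not a measurement: the function names are hypothetical, activations are assumed to be 2 bytes per element (float16), and the per-query softmax statistics kept by FlashAttention are assumed to be one float32 value per latent and head.

```python
def naive_scores_bytes(batch, heads, n, m, bytes_per_el=2):
    """Full (n x m) score matrix that a naive implementation materializes."""
    return batch * heads * n * m * bytes_per_el


def flash_stats_bytes(batch, heads, n, bytes_per_el=4):
    """Per-query softmax statistics (assumption: one float32 per latent and head)."""
    return batch * heads * n * bytes_per_el


def kv_bytes(batch, heads, head_dim, m, bytes_per_el=2):
    """Keys plus values of the input; required by either implementation."""
    return 2 * batch * heads * head_dim * m * bytes_per_el


MIB = 1024 ** 2
batch, heads, head_dim, n = 32, 1, 64, 512  # values from the usage example

for m in (128, 4096, 65536):
    print(
        f"m={m:>6}: "
        f"naive scores {naive_scores_bytes(batch, heads, n, m) / MIB:8.1f} MiB, "
        f"flash stats {flash_stats_bytes(batch, heads, n) / MIB:6.4f} MiB, "
        f"keys+values {kv_bytes(batch, heads, head_dim, m) / MIB:6.1f} MiB"
    )
```

For m = 65536, the naive score matrix alone takes 2048 MiB, while the statistics stay at 0.0625 MiB regardless of m. The keys and values take 512 MiB at that length, which illustrates why they eventually dominate.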

Benchmarks

Benchmarks against other implementations (currently only perceiver-pytorch) can be performed with:

python run_benchmarks.py

The script will create a benchmark_results.csv. The create_plots.py script can then be used to create plots.

The following plots have been created using a single RTX 4090 with 24GB of VRAM.

[Figure: Benchmark results on speedup]

[Figure: Benchmark results on memory usage reduction]

Note: The batch size for each configuration corresponds to the smallest value that works for all implementations. Especially for longer sequence lengths, this leads to decreasing GPU utilization and thus a lower speedup than theoretically possible. There are some ways to fix this, but my attempts so far have led to distorted results.
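For measuring a custom configuration outside of run_benchmarks.py, a small timing and peak-memory helper could look like the sketch below. It is not the project's benchmark script; `measure` is a hypothetical name, and the approach (warm-up calls, explicit CUDA synchronization, PyTorch's peak-memory counters) is one common way of doing this rather than the author's method.

```python
import time

import torch


def measure(fn, warmup=3, iters=10):
    """Return (mean seconds per call, peak CUDA memory in bytes or None).

    Hypothetical helper for illustration; not part of fast-perceiver.
    """
    use_cuda = torch.cuda.is_available()

    # Warm-up calls so that one-time costs are not included in the timing.
    for _ in range(warmup):
        fn()

    if use_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iters

    peak = torch.cuda.max_memory_allocated() if use_cuda else None
    return elapsed, peak
```

With the model and inputs from the usage example, a forward pass could be measured with `measure(lambda: model(x, mask=mask))`.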

Acknowledgements

The implementation is inspired by lucidrains' Perceiver implementation and would not have been possible without Tri Dao's FlashAttention.

Planned features

These are a few features that are either planned or WIP. If you urgently need any of them, feel free to open an issue:

  • Perceiver IO [2] [WIP]
  • Perceiver AR [3] (or an AR demo in general)
  • Demos [WIP]
  • Tests [WIP]
  • Allow more flexible cross-attention configurations
  • Benchmarks against other Perceiver implementations, e.g. DeepMind's or Krasser's
  • If FA2 is eventually merged into PyTorch, drop the flash-attn dependency

References

[1] Jaegle, Andrew, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, and Joao Carreira. “Perceiver: General Perception with Iterative Attention.” arXiv, June 22, 2021. http://arxiv.org/abs/2103.03206.

[2] Jaegle, Andrew, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, et al. “Perceiver IO: A General Architecture for Structured Inputs & Outputs.” arXiv, March 15, 2022. http://arxiv.org/abs/2107.14795.

[3] Hawthorne, Curtis, Andrew Jaegle, Cătălina Cangea, Sebastian Borgeaud, Charlie Nash, Mateusz Malinowski, Sander Dieleman, et al. “General-Purpose, Long-Context Autoregressive Modeling with Perceiver AR.” arXiv, June 14, 2022. http://arxiv.org/abs/2202.07765.

[4] Dao, Tri, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv, June 23, 2022. https://doi.org/10.48550/arXiv.2205.14135.

[5] Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv, July 17, 2023. https://doi.org/10.48550/arXiv.2307.08691.

[6] Su, Jianlin, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” arXiv, August 8, 2022. https://doi.org/10.48550/arXiv.2104.09864.
