fast-perceiver
Fast and memory-efficient PyTorch implementation of the Perceiver [1, 2, 3] attention architecture with FlashAttention [4, 5].
Features:
- ⚡ More than 2x speedup over a naive implementation
- ⚡ Sub-linear¹ memory usage with respect to the input sequence length and linear usage with respect to the number of latent vectors
- ⚡ Out-of-the-box support for rotary positional embeddings [6]
- ⚡ Uses the new and improved FlashAttention-2 implementation
- ⚡ Supports arbitrary masking of inputs
¹ For the attention components. See Performance for more information.
Installation
pip install fast-perceiver
FlashAttention will be installed as part of the dependencies and may have to be compiled locally. Note that this can take a while and may fail on unsupported GPUs or CUDA versions. Please refer to the linked repository for further information and help with the installation.
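As a quick sanity check that the extension built and imports correctly (assuming the installed flash_attn package exposes a `__version__` attribute, as recent releases do):

```python
# Optional sanity check that the FlashAttention extension built correctly.
import flash_attn

print(flash_attn.__version__)
```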
Usage
```python
import torch

from fast_perceiver import Perceiver

in_dim = 256
out_dim = 128

latent_dim = 512
num_latents = 512

model = Perceiver(
    input_dim=in_dim,
    depth=8,
    out_dim=out_dim,
    num_latents=num_latents,
    latent_dim=latent_dim,
    cross_heads=1,
    cross_head_dim=64,
    latent_heads=8,
    latent_head_dim=64,
    cross_attn_dropout=0.0,
    self_attn_dropout=0.0,
    weight_tie_layers=False,
    gated_mlp=True,
).cuda()

x = torch.randn(32, 128, in_dim).cuda()
mask = torch.rand(32, 128) > 0.5

# FlashAttention only works with half-precision.
# Don't forget to autocast!
with torch.autocast('cuda'):
    # `out_dim` specified; averages and projects the output
    out = model(x)

    assert out.shape == (32, out_dim)

    # An element-wise input mask can be provided;
    # all non-True elements will be ignored.
    out = model(x, mask=mask)

    # The raw final latents are returned when `return_embeddings=True`
    embeds = model(x, return_embeddings=True)

    assert embeds.shape == (32, num_latents, latent_dim)
```
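Continuing from the snippet above, training follows the usual PyTorch mixed-precision recipe. The following is only a minimal sketch; the optimizer, dummy targets, and GradScaler are standard PyTorch picked for illustration, not prescribed by this package:

```python
# Minimal mixed-precision training step, continuing from the usage example above.
# Purely illustrative: dummy regression targets and an arbitrary optimizer choice.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()

targets = torch.randn(32, out_dim).cuda()

with torch.autocast('cuda'):
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, targets)

# Scale the loss so fp16 gradients don't underflow, then step and update.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```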
Performance
The Perceiver is already designed as an attention architecture with sub-quadratic compute and memory complexity, in contrast to the quadratic requirements of a vanilla Transformer.
A naive implementation has $\mathcal{O}(nm)$ memory requirements for the cross-attention modules and $\mathcal{O}(n^2)$ for the self-attention (latent) blocks, where $n$ is the number of latent vectors (a fixed hyperparameter), $m$ is the number of input elements, and generally $m \gg n$.
FlashAttention reduces the memory usage of both the cross-attention and the self-attention layers to $\mathcal{O}(n)$. However, this only accounts for the attention computation itself: the input sequence and the corresponding keys and values within the cross-attention modules still grow with $m$.
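As a rough back-of-the-envelope illustration (the numbers below are made up for the example, not measured results):

```python
# Rough size of the naive n x m cross-attention score matrix in fp16.
# Illustrative configuration, not a benchmark result.
batch, heads = 32, 1
n = 512            # latent vectors
m = 65_536         # input elements
bytes_per_el = 2   # fp16

naive_scores = batch * heads * n * m * bytes_per_el
print(f"materialized scores: {naive_scores / 2**30:.1f} GiB")  # ~2.0 GiB

# FlashAttention computes the same attention in tiles without ever
# materializing this matrix, so the attention memory no longer grows with m;
# only the keys and values of the input sequence still scale with m.
```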
Until the latter starts to dominate memory usage, this implementation allows the input sequence length to be scaled up considerably. For instance, 16x larger input lengths can be achieved compared to perceiver-pytorch on an RTX 4090, keeping the other hyperparameters fixed (see run_benchmarks.py for the exact configuration).
Benchmarks
Benchmarks against other implementations (currently only perceiver-pytorch) can be performed with:
python run_benchmarks.py
The script will create a benchmark_results.csv file. The create_plots.py script can then be used to create plots from it.
The following plots have been created using a single RTX 4090 with 24GB of VRAM.
Note: The batch size for each configuration corresponds to the smallest value that works for all implementations. Especially for longer sequence lengths, this leads to decreasing GPU utilization and thus a lower speedup than theoretically possible. There are some ways to fix this, but my attempts so far have led to distorted results.
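For a quick look at the raw numbers, the results file can be inspected like any other CSV; a minimal sketch using pandas (not a dependency of this project):

```python
# Quick look at the raw benchmark results (column names depend on the script).
import pandas as pd

df = pd.read_csv('benchmark_results.csv')
print(df.head())
print(df.describe())
```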
Acknowledgements
The implementation is inspired by lucidrains' Perceiver implementation and would not have been possible without Tri Dao's FlashAttention.
Planned features
These are a few features that are either planned or work in progress. If you have an urgent need for any of them, feel free to open an issue:
- Perceiver IO [2] [WIP]
- Perceiver AR [3] (or an AR demo in general)
- Demos [WIP]
- Tests [WIP]
- Allow more flexible cross-attention configurations
- Benchmarks against other Perceiver implementations, e.g. DeepMind's or Krasser's
References
[1] Jaegle, Andrew, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, and Joao Carreira. “Perceiver: General Perception with Iterative Attention.” arXiv, June 22, 2021. http://arxiv.org/abs/2103.03206.
[2] Jaegle, Andrew, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, et al. “Perceiver IO: A General Architecture for Structured Inputs & Outputs.” arXiv, March 15, 2022. http://arxiv.org/abs/2107.14795.
[3] Hawthorne, Curtis, Andrew Jaegle, Cătălina Cangea, Sebastian Borgeaud, Charlie Nash, Mateusz Malinowski, Sander Dieleman, et al. “General-Purpose, Long-Context Autoregressive Modeling with Perceiver AR.” arXiv, June 14, 2022. http://arxiv.org/abs/2202.07765.
[4] Dao, Tri, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv, June 23, 2022. https://doi.org/10.48550/arXiv.2205.14135.
[5] Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv, July 17, 2023. https://doi.org/10.48550/arXiv.2307.08691.
[6] Su, Jianlin, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” arXiv, August 8, 2022. https://doi.org/10.48550/arXiv.2104.09864.