
FlashAttention2 with Custom Masks 🎭

Note: This is an unofficial implementation of FlashAttention2.

For efficiency reasons, the standard implementations of FlashAttention currently do not support arbitrary custom masks; specific masks such as the causal mask for language modeling are instead handled with branch logic to save memory. This repository is a modified version of the tutorial Triton implementation of FlashAttention2 that allows the user to define a (batch of) custom masks. Both the forward and backward passes are modified to handle custom masking, so you can define a different mask per head and per batch element.
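Concretely, with a custom mask the kernel computes standard softmax attention in which masked-out score entries are set to -inf before the softmax. A minimal CPU reference sketch in plain PyTorch (the function name here is illustrative, not part of this package):

```python
import torch

def masked_attention_reference(q, k, v, mask, sm_scale):
    # q, k, v: (B, H, L, D); mask: broadcastable to (B, H, L, L),
    # nonzero entries are attended to, zeros are masked out.
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)

B, H, L, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
# A causal mask is just one special case of a custom mask.
causal = torch.tril(torch.ones(L, L, dtype=torch.bool)).expand(B, H, L, L)
out = masked_attention_reference(q, k, v, causal, D ** -0.5)
```

The Triton kernel in this repository computes the same result, but without materializing the full (L, L) score matrix.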

Original Triton code: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html

See the original thread: https://github.com/Dao-AILab/flash-attention/issues/352

Example Setup

The relevant libraries needed to use the custom-mask FlashAttention2 kernel are below:

pip install "triton>=3.0.0"
pip install torch

For Viewing Benchmarking Results

Other libraries for evaluating the performance of the models are listed below. These are primarily for test_benchmark.py, which also verifies the correctness of the implementation.

pip install pytest
pip install matplotlib
pip install pandas

To compare with the official FlashAttention and xformers.ops.memory_efficient_attention implementations, make sure to install both libraries separately (following the instructions in their repositories).

pip install flash-attn --no-build-isolation
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121

Testing Correctness

There are two pytest functions in test_benchmark.py. The first tests whether a reference implementation of multi-head attention with a causal mask matches the Triton version in both the forward-pass outputs and backward-pass gradients; the second runs the same check with random masks. You can modify these tests to perform more rigorous correctness checks with pytest.
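The structure of such a forward/backward correctness test can be sketched on CPU by checking an explicit-mask reference against PyTorch's built-in causal attention (which stands in here for the Triton kernel, since the kernel itself needs a GPU):

```python
import torch
import torch.nn.functional as F

def ref_attention(q, k, v, mask, sm_scale):
    # Explicit-mask reference: nonzero mask entries are attended to.
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)

B, H, L, D = 1, 2, 16, 8
q, k, v = (torch.randn(B, H, L, D, requires_grad=True) for _ in range(3))
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))

out_ref = ref_attention(q, k, v, causal, D ** -0.5)
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Compare backward-pass gradients as well as forward outputs.
dout = torch.randn_like(out_ref)
grads_ref = torch.autograd.grad(out_ref, (q, k, v), dout)
grads_sdpa = torch.autograd.grad(out_sdpa, (q, k, v), dout)
```

The tests in test_benchmark.py follow this pattern, with the Triton kernel in place of scaled_dot_product_attention and torch.allclose tolerances appropriate for fp16.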

Simple Example

You can insert this module into your standard attention pipeline.

import torch
from fa2_custom_mask import flash_attention_custom_mask

B, H, L, D = 4, 16, 4096, 64
sm_scale = 1 / (D ** 0.5)

fp32_q = torch.randn(B, H, L, D).float().cuda()
fp32_k = torch.randn(B, H, L, D).float().cuda()
fp32_v = torch.randn(B, H, L, D).float().cuda()
# Nonzero mask entries are attended to; zeros are masked out.
mask = torch.randint(0, 2, (B, 1, L, L)).int().cuda()
mask = torch.broadcast_to(mask, (B, H, L, L))

out = flash_attention_custom_mask(fp32_q, fp32_k, fp32_v, mask=mask, sm_scale=sm_scale)
...
out.backward(dout)  # dout: gradient of the loss with respect to out

Benchmarking

A simple benchmark against the base Triton implementation. In our custom-mask version, we pass the canonical causal mask in as input (and hence store it in global device memory). We run test_benchmark.py with batch size 4, 16 heads, hidden dimension 64, and sequence length N_CTX ranging from 256 to 16384 in powers of 2. You can replicate the experiments by running

pytest
python test_benchmark.py

Causal Masks and No Masks Comparisons

We compare against the original experiments and original implementation, as well as the official FlashAttention and xformers implementations. (Note: there was a versioning issue here, so this plot uses a different xformers implementation; the version is corrected in the later benchmarking experiments.) [Figure: causal and no-mask benchmarks, with flash-attn]

Causal Masks and No Masks Comparisons (with Correct xformers Version)

We compare against the original experiments and original implementation, as well as the xformers implementation. Notably, the original implementation does well for causal masking because of some pipelining tricks and because it does not have to store masks. [Figure: causal and no-mask benchmarks]

Custom Masking Comparison

We compare directly to the xformers memory-efficient attention, which allows for custom masking. We generate random masks (fixed across the head dimension). [Figure: custom masking benchmark]

Notes and Bugs

  1. This implementation only works on Ampere devices and up. I originally tried running it on a V100 (Volta) and it failed.
  2. You need to be on triton>=3.0.0, or it'll complain about permutation indices on the value vector pointer. The torch and flash-attn libraries may force you to install triton==2.x.x, but you can simply re-install triton>=3.0.0 afterwards and it should work. I may fix this manually in the future.
    • This is oddly specific, but I wasn't able to have flash-attn and xformers installed at the same time; I had to run them separately to generate the plots.
  3. TODO: Add benchmarking for peak memory consumption and other efficiency metrics.

If time permits, I'm interested in making this implementation generalizable / changing the CUDA implementation for FA3 (if it's necessary of course). I also probably will run some more realistic workloads and see what happens.
