Flash Attention Triton kernel with support for second-order derivatives
Description
Flash Attention Triton kernel with support for second-order derivatives, such as Jacobian-Vector Products (JVPs) and Hessian-Vector Products (HVPs)
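For readers new to the term, a JVP is the directional derivative of a function's output along a chosen input perturbation (the tangent). The sketch below illustrates that quantity for single-query softmax attention in plain Python and checks it against a central finite difference. It is not the Triton kernel's algorithm or API: the function names (`attention_jvp`, `finite_difference_jvp`, `shifted`) and the toy inputs are made up for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_jvp(q, K, V, dq, dK, dV):
    """Return (output, tangent) of single-query softmax attention.

    q is a query vector; K and V are lists of key/value vectors.
    dq, dK, dV are the perturbation directions (tangents), same shapes.
    """
    scale = 1.0 / math.sqrt(len(q))
    # Scores s_j = scale * <q, k_j> and their tangents via the product rule.
    s = [scale * dot(q, k) for k in K]
    ds = [scale * (dot(dq, k) + dot(q, dk)) for k, dk in zip(K, dK)]
    # Numerically stable softmax.
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    p = [x / z for x in e]
    # Softmax tangent: dp_j = p_j * (ds_j - sum_i p_i * ds_i).
    mean_ds = dot(p, ds)
    dp = [pj * (dsj - mean_ds) for pj, dsj in zip(p, ds)]
    cols = range(len(V[0]))
    out = [sum(pj * v[c] for pj, v in zip(p, V)) for c in cols]
    # Output tangent: sum_j (dp_j * v_j + p_j * dv_j).
    tan = [
        sum(dpj * v[c] + pj * dv[c] for dpj, pj, v, dv in zip(dp, p, V, dV))
        for c in cols
    ]
    return out, tan


def shifted(x, dx, eps):
    """Compute x + eps * dx for a vector or a list of vectors."""
    if isinstance(x[0], list):
        return [shifted(r, dr, eps) for r, dr in zip(x, dx)]
    return [a + eps * b for a, b in zip(x, dx)]


def finite_difference_jvp(q, K, V, dq, dK, dV, eps=1e-5):
    """Central-difference estimate of the same tangent."""
    plus, _ = attention_jvp(
        shifted(q, dq, eps), shifted(K, dK, eps), shifted(V, dV, eps), dq, dK, dV
    )
    minus, _ = attention_jvp(
        shifted(q, dq, -eps), shifted(K, dK, -eps), shifted(V, dV, -eps), dq, dK, dV
    )
    return [(a - b) / (2 * eps) for a, b in zip(plus, minus)]


# Toy inputs: one query, three keys/values, head dimension 2.
q = [0.3, -1.2]
K = [[1.0, 0.5], [-0.4, 0.9], [0.2, -0.7]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
dq = [1.0, 0.0]
dK = [[0.0, 1.0], [0.5, 0.0], [0.0, 0.0]]
dV = [[0.1, 0.0], [0.0, 0.2], [-0.3, 0.0]]

out, tangent = attention_jvp(q, K, V, dq, dK, dV)
estimate = finite_difference_jvp(q, K, V, dq, dK, dV)
max_gap = max(abs(a - b) for a, b in zip(tangent, estimate))
print("tangent:", tangent)
print("max gap vs finite difference:", max_gap)
```

The closed-form tangent and the finite-difference estimate should agree to many digits, which is the same kind of check a correctness test for a JVP kernel would perform at larger scale.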
Installation
Using pip, one can install jvp_flash_attention as follows.
# Install package
pip install jvp_flash_attention
# [OPTIONAL, for development] Install package and pre-commit hooks
pip install -e .
pre-commit install
Usage
Once installed, one can use jvp_flash_attention in place of PyTorch's scaled_dot_product_attention as follows.
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

from jvp_flash_attention.jvp_attention import attention as jvp_attention

with sdpa_kernel(SDPBackend.MATH):
    # Regular attention
    # x = F.scaled_dot_product_attention(
    #     q,
    #     k,
    #     v,
    #     attn_mask=attn_mask,
    #     dropout_p=attn_dropout_p if self.training else 0.0,
    # )

    # Flash attention
    x = jvp_attention(
        q,
        k,
        v,
        # attn_mask=attn_mask,  # NOTE: Attention masking is not yet supported
    )
Contributions or enhancements are welcome!
Tests
To run the unit tests that verify the correctness of the JVP Flash Attention Triton kernel, run the following command, replacing the braced list with one of float16, bfloat16, or float32.
python tests/test_jvp_attention.py --dtype {float16,bfloat16,float32}
In principle, the kernel should support ROCm systems as well, though it has not yet been tested on them. macOS is currently unsupported.
Results for float16:
==========================================================================================
BENCHMARK SUMMARY
==========================================================================================
Seq Len   Causal  Method    Time (ms)   Mem (MB)    TFLOP/s        Max Error   Grad Check
------------------------------------------------------------------------------------------
32        False   sdpa      0.551       0.64        0.0 TFLOP/s    baseline    N/A
32        False   jvp_attn  0.483       0.22        0.0 TFLOP/s    1.95e-03    ✓
32        True    sdpa      1.067       0.65        0.0 TFLOP/s    baseline    N/A
32        True    jvp_attn  0.465       0.22        0.0 TFLOP/s    1.95e-03    ✓
64        False   sdpa      0.552       1.41        0.0 TFLOP/s    baseline    N/A
64        False   jvp_attn  0.469       0.43        0.0 TFLOP/s    9.77e-04    ✓
64        True    sdpa      0.875       1.42        0.0 TFLOP/s    baseline    N/A
64        True    jvp_attn  0.469       0.43        0.0 TFLOP/s    1.95e-03    ✓
128       False   sdpa      0.533       3.28        0.0 TFLOP/s    baseline    N/A
128       False   jvp_attn  0.467       0.86        0.1 TFLOP/s    9.77e-04    ✓
128       True    sdpa      0.860       3.35        0.0 TFLOP/s    baseline    N/A
128       True    jvp_attn  0.494       0.86        0.0 TFLOP/s    1.95e-03    ✓
256       False   sdpa      0.538       9.69        0.2 TFLOP/s    baseline    N/A
256       False   jvp_attn  0.473       1.72        0.4 TFLOP/s    9.77e-04    ✓
256       True    sdpa      0.870       9.94        0.0 TFLOP/s    baseline    N/A
256       True    jvp_attn  0.468       1.72        0.2 TFLOP/s    1.95e-03    ✓
512       False   sdpa      0.575       31.88       0.6 TFLOP/s    baseline    N/A
512       False   jvp_attn  0.466       3.45        1.5 TFLOP/s    4.88e-04    ✓
512       True    sdpa      0.914       32.88       0.2 TFLOP/s    baseline    N/A
512       True    jvp_attn  0.467       3.45        0.7 TFLOP/s    1.95e-03    ✓
1024      False   sdpa      1.291       113.77      1.1 TFLOP/s    baseline    N/A
1024      False   jvp_attn  0.463       6.89        5.9 TFLOP/s    4.88e-04    ✓
1024      True    sdpa      1.467       117.77      0.5 TFLOP/s    baseline    N/A
1024      True    jvp_attn  0.470       6.89        2.9 TFLOP/s    1.95e-03    ✓
2048      False   sdpa      3.669       427.54      1.5 TFLOP/s    baseline    N/A
2048      False   jvp_attn  0.462       13.79       23.7 TFLOP/s   2.44e-04    ✓
2048      True    sdpa      4.287       443.54      0.6 TFLOP/s    baseline    N/A
2048      True    jvp_attn  0.463       13.79       11.8 TFLOP/s   1.95e-03    ✓
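As a reading aid for the float16 results above, the sketch below recomputes two things from three non-causal rows copied from that table: the ratio of sdpa to jvp_attn in time and memory, and how much each method's memory grows when the sequence length doubles. The helper names are illustrative only and not part of the package. The source does not say exactly what the Time and Mem columns measure, so the ratios are only as meaningful as those columns.

```python
# (seq_len, sdpa_time_ms, sdpa_mem_mb, jvp_attn_time_ms, jvp_attn_mem_mb)
# Copied from the float16, Causal=False rows of the benchmark summary.
ROWS = [
    (512, 0.575, 31.88, 0.466, 3.45),
    (1024, 1.291, 113.77, 0.463, 6.89),
    (2048, 3.669, 427.54, 0.462, 13.79),
]


def ratios(rows):
    """Map seq_len -> (time ratio, memory ratio) of sdpa over jvp_attn."""
    return {
        n: (t_ref / t_new, m_ref / m_new)
        for n, t_ref, m_ref, t_new, m_new in rows
    }


def growth(rows, column):
    """Factor by which a column grows between consecutive rows (length doubles)."""
    values = [row[column] for row in rows]
    return [b / a for a, b in zip(values, values[1:])]


for n, (time_ratio, mem_ratio) in ratios(ROWS).items():
    print(f"seq_len={n}: {time_ratio:.1f}x less time, {mem_ratio:.1f}x less memory")

print("sdpa memory growth per doubling:    ", [round(g, 2) for g in growth(ROWS, 2)])
print("jvp_attn memory growth per doubling:", [round(g, 2) for g in growth(ROWS, 4)])
```

In these rows the sdpa memory grows by a factor between 3.5 and 4 per doubling, which is close to quadratic, while the jvp_attn memory grows by about 2, which is close to linear. At sequence length 2048 that works out to roughly 8x less time and 31x less memory for jvp_attn.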
Results for bfloat16:
==========================================================================================
BENCHMARK SUMMARY
==========================================================================================
Seq Len   Causal  Method    Time (ms)   Mem (MB)    TFLOP/s        Max Error   Grad Check
------------------------------------------------------------------------------------------
32        False   sdpa      0.527       0.64        0.0 TFLOP/s    baseline    N/A
32        False   jvp_attn  0.461       0.22        0.0 TFLOP/s    1.56e-02    ✓
32        True    sdpa      0.854       0.65        0.0 TFLOP/s    baseline    N/A
32        True    jvp_attn  0.462       0.22        0.0 TFLOP/s    1.56e-02    ✓
64        False   sdpa      0.671       1.41        0.0 TFLOP/s    baseline    N/A
64        False   jvp_attn  0.459       0.43        0.0 TFLOP/s    7.81e-03    ✓
64        True    sdpa      0.846       1.42        0.0 TFLOP/s    baseline    N/A
64        True    jvp_attn  0.459       0.43        0.0 TFLOP/s    1.56e-02    ✓
128       False   sdpa      0.539       3.28        0.0 TFLOP/s    baseline    N/A
128       False   jvp_attn  0.463       0.86        0.1 TFLOP/s    7.81e-03    ✓
128       True    sdpa      0.860       3.35        0.0 TFLOP/s    baseline    N/A
128       True    jvp_attn  0.484       0.86        0.0 TFLOP/s    1.56e-02    ✓
256       False   sdpa      0.530       9.69        0.2 TFLOP/s    baseline    N/A
256       False   jvp_attn  0.468       1.72        0.4 TFLOP/s    3.91e-03    ✓
256       True    sdpa      0.856       9.94        0.0 TFLOP/s    baseline    N/A
256       True    jvp_attn  0.468       1.72        0.2 TFLOP/s    1.56e-02    ✓
512       False   sdpa      0.573       31.88       0.6 TFLOP/s    baseline    N/A
512       False   jvp_attn  0.469       3.45        1.5 TFLOP/s    3.91e-03    ✓
512       True    sdpa      0.869       32.88       0.2 TFLOP/s    baseline    N/A
512       True    jvp_attn  0.468       3.45        0.7 TFLOP/s    1.56e-02    ✓
1024      False   sdpa      1.290       113.77      1.1 TFLOP/s    baseline    N/A
1024      False   jvp_attn  0.462       6.89        5.9 TFLOP/s    3.91e-03    ✓
1024      True    sdpa      1.466       117.77      0.5 TFLOP/s    baseline    N/A
1024      True    jvp_attn  0.461       6.89        3.0 TFLOP/s    1.56e-02    ✓
2048      False   sdpa      3.673       427.54      1.5 TFLOP/s    baseline    N/A
2048      False   jvp_attn  0.462       13.79       23.7 TFLOP/s   1.95e-03    ✓
2048      True    sdpa      4.286       443.54      0.6 TFLOP/s    baseline    N/A
2048      True    jvp_attn  0.452       13.79       12.1 TFLOP/s   3.12e-02    ✓
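One possible way to read the Max Error column in the float16 and bfloat16 results above: every reported value prints as a power of two (for example, 9.77e-04 is 2^-10 and 7.81e-03 is 2^-7). That is what one would expect if the outputs differ from the baseline by a single representable step. The sketch below computes the gap between adjacent representable values, assuming 10 explicit mantissa bits for float16 and 7 for bfloat16. This is an interpretation rather than something the benchmark states, the magnitude of the outputs is not given in the source, and the helper names are illustrative.

```python
FLOAT16_MANTISSA_BITS = 10  # IEEE 754 binary16 (assumed format)
BFLOAT16_MANTISSA_BITS = 7  # bfloat16 (assumed format)


def spacing(mantissa_bits, exponent):
    """Gap between adjacent representable values in [2**exponent, 2**(exponent + 1))."""
    return 2.0 ** (exponent - mantissa_bits)


def matching_exponents(reported, mantissa_bits, exponents=range(-8, 4)):
    """Exponents e whose spacing prints identically to a reported error string."""
    return [e for e in exponents if f"{spacing(mantissa_bits, e):.2e}" == reported]


for reported in ("9.77e-04", "1.95e-03"):
    e = matching_exponents(reported, FLOAT16_MANTISSA_BITS)
    print(f"float16  {reported}: one step for values in [2**e, 2**(e+1)), e = {e}")

for reported in ("7.81e-03", "3.91e-03"):
    e = matching_exponents(reported, BFLOAT16_MANTISSA_BITS)
    print(f"bfloat16 {reported}: one step for values in [2**e, 2**(e+1)), e = {e}")
```

Under these assumptions, a float16 error of 9.77e-04 and a bfloat16 error of 7.81e-03 both correspond to one step for values between 1 and 2. This would also explain why the bfloat16 errors are about eight times larger than the float16 ones.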
Results for float32:
==========================================================================================
BENCHMARK SUMMARY
==========================================================================================
Seq Len   Causal  Method    Time (ms)   Mem (MB)    TFLOP/s        Max Error   Grad Check
------------------------------------------------------------------------------------------
32        False   sdpa      0.456       0.51        0.0 TFLOP/s    baseline    N/A
32        False   jvp_attn  0.454       0.43        0.0 TFLOP/s    7.22e-03    ✓
32        True    sdpa      0.779       0.51        0.0 TFLOP/s    baseline    N/A
32        True    jvp_attn  0.458       0.43        0.0 TFLOP/s    6.18e-03    ✓
64        False   sdpa      0.460       1.09        0.0 TFLOP/s    baseline    N/A
64        False   jvp_attn  0.462       0.86        0.0 TFLOP/s    7.03e-03    ✓
64        True    sdpa      0.787       1.11        0.0 TFLOP/s    baseline    N/A
64        True    jvp_attn  0.462       0.86        0.0 TFLOP/s    6.18e-03    ✓
128       False   sdpa      0.460       2.81        0.0 TFLOP/s    baseline    N/A
128       False   jvp_attn  0.461       1.72        0.1 TFLOP/s    5.07e-03    ✓
128       True    sdpa      0.782       2.88        0.0 TFLOP/s    baseline    N/A
128       True    jvp_attn  0.472       1.72        0.0 TFLOP/s    6.18e-03    ✓
256       False   sdpa      0.457       8.75        0.2 TFLOP/s    baseline    N/A
256       False   jvp_attn  0.465       3.44        0.4 TFLOP/s    3.67e-03    ✓
256       True    sdpa      0.798       9.00        0.1 TFLOP/s    baseline    N/A
256       True    jvp_attn  0.465       3.44        0.2 TFLOP/s    5.78e-03    ✓
512       False   sdpa      0.530       30.01       0.6 TFLOP/s    baseline    N/A
512       False   jvp_attn  0.469       6.88        1.5 TFLOP/s    2.88e-03    ✓
512       True    sdpa      0.784       31.01       0.2 TFLOP/s    baseline    N/A
512       True    jvp_attn  0.460       6.88        0.7 TFLOP/s    5.13e-03    ✓
1024      False   sdpa      1.207       110.02      1.1 TFLOP/s    baseline    N/A
1024      False   jvp_attn  0.467       13.77       5.9 TFLOP/s    2.61e-03    ✓
1024      True    sdpa      1.379       115.02      0.5 TFLOP/s    baseline    N/A
1024      True    jvp_attn  0.465       13.77       2.9 TFLOP/s    5.61e-03    ✓
2048      False   sdpa      3.435       420.04      1.6 TFLOP/s    baseline    N/A
2048      False   jvp_attn  0.496       27.54       22.1 TFLOP/s   1.56e-03    ✓
2048      True    sdpa      4.051       436.04      0.7 TFLOP/s    baseline    N/A
2048      True    jvp_attn  0.486       27.54       11.3 TFLOP/s   6.47e-03    ✓
License
This project is covered under the MIT License.
Citing this work
If you use the code associated with this package or otherwise find this work useful, please use GitHub's Cite this repository feature or the BibTeX below.
@software{Morehead_JVP_Flash_Attention_2025,
  author = {Morehead, Alex},
  doi = {10.5281/zenodo.17050188},
  license = {MIT},
  month = sep,
  title = {{JVP Flash Attention}},
  url = {https://github.com/amorehead/jvp_flash_attention},
  version = {0.0.1},
  year = {2025}
}
Acknowledgements
jvp_flash_attention builds upon the contributions and insights from the following sources:
We thank each and every contributor!