
Flash Attention Triton kernel with support for second-order derivatives


JVP Flash Attention


Description

Flash Attention Triton kernel with support for higher-order automatic differentiation, including Jacobian-vector products (JVPs, i.e., forward-mode derivatives) and second-order derivatives such as Hessian-vector products (HVPs).
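In plain terms, a JVP pushes a tangent direction forward through a function (a first-order, forward-mode derivative), and an HVP can be obtained as the JVP of a gradient ("forward-over-reverse"), which is why forward-mode support matters for second-order use cases. The sketch below illustrates both with `torch.func` on a hypothetical plain-PyTorch `reference_attention` and a hypothetical scalar `loss`. It does not use this package's Triton kernel, so it runs on CPU.

```python
import math

import torch
from torch.func import grad, jvp


def reference_attention(q, k, v):
    """Plain softmax(q @ k^T / sqrt(d)) @ v; a stand-in for the Triton kernel."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v


def loss(q, k, v):
    """Hypothetical scalar objective, so that a gradient and a Hessian exist."""
    return reference_attention(q, k, v).pow(2).sum()


def attention_jvp(q, k, v, tq, tk, tv):
    """First order, forward mode: the Jacobian at (q, k, v) applied to (tq, tk, tv)."""
    _, out_tangent = jvp(reference_attention, (q, k, v), (tq, tk, tv))
    return out_tangent


def loss_hvp_wrt_q(q, k, v, direction):
    """Second order, forward-over-reverse: H_q @ direction as the JVP of the q-gradient."""

    def grad_q(q_):
        return grad(loss)(q_, k, v)  # gradient w.r.t. the first argument (q)

    _, hvp = jvp(grad_q, (q,), (direction,))
    return hvp


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 3, dtype=torch.float64) for _ in range(3))
direction = torch.randn_like(q)
print(loss_hvp_wrt_q(q, k, v, direction).shape)  # torch.Size([1, 2, 4, 3])
```

The HVP never forms the full Hessian; it costs roughly one extra forward-mode pass over the backward computation, which is what a JVP-capable attention kernel makes possible.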

Installation

Using pip, one can install jvp_flash_attention as follows.

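# Install package
pip install jvp_flash_attention

# [OPTIONAL, for development] Clone the repository, then install the package in editable mode and set up pre-commit hooks
git clone https://github.com/amorehead/jvp_flash_attention
cd jvp_flash_attention
pip install -e .
pre-commit install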

Usage

Once installed, one can use jvp_flash_attention in place of PyTorch's scaled_dot_product_attention as follows.

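import torch
import torch.nn.functional as F

from torch.nn.attention import SDPBackend, sdpa_kernel
from jvp_flash_attention.jvp_attention import attention as jvp_attention

# Example inputs on a CUDA device; the (batch, heads, seq_len, head_dim) layout is
# assumed to match that of scaled_dot_product_attention
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# sdpa_kernel only affects F.scaled_dot_product_attention (the commented-out baseline below)
with sdpa_kernel(SDPBackend.MATH):
    # Regular attention
    # x = F.scaled_dot_product_attention(
    #     q,
    #     k,
    #     v,
    #     attn_mask=attn_mask,
    #     dropout_p=attn_dropout_p if self.training else 0.0,
    # )

    # Flash attention
    x = jvp_attention(
        q,
        k,
        v,
        # attn_mask=attn_mask,  # NOTE: Attention masking is not yet supported
    )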
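A hedged sketch of taking a JVP through an attention call with PyTorch's forward-mode AD API (`torch.autograd.forward_ad`). The helper only needs a callable, so on a CUDA machine one could try passing `jvp_attention` in place of the hypothetical `reference_attention` used here; that `jvp_attention` accepts dual tensors this way is an assumption we have not verified. With the plain-PyTorch reference, the example runs on CPU.

```python
import math

import torch
import torch.autograd.forward_ad as fwAD


def reference_attention(q, k, v):
    """Plain softmax(q @ k^T / sqrt(d)) @ v over (batch, heads, seq_len, head_dim) tensors."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v


def attention_with_jvp(attn_fn, primals, tangents):
    """Run attn_fn on dual tensors and return (output, output_tangent) via forward-mode AD."""
    with fwAD.dual_level():
        duals = [fwAD.make_dual(p, t) for p, t in zip(primals, tangents)]
        out, out_tangent = fwAD.unpack_dual(attn_fn(*duals))
        # Clone so the results remain ordinary tensors after the dual level closes
        return out.clone(), out_tangent.clone()


torch.manual_seed(0)
shape = (2, 4, 16, 8)  # hypothetical (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(shape) for _ in range(3))
tq, tk, tv = (torch.randn(shape) for _ in range(3))
out, out_tangent = attention_with_jvp(reference_attention, (q, k, v), (tq, tk, tv))
```

A single forward pass yields both the attention output and its directional derivative, with no backward pass involved.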

Contributions or enhancements are welcome!

Tests

To run the unit tests that verify the correctness of the JVP Flash Attention Triton kernel (which also print the benchmark summaries shown below), run the following command with one of the listed data types.

python tests/test_jvp_attention.py --dtype {float16,bfloat16,float32}

In principle, the kernel should support ROCm systems as well, though it has not yet been tested on them. macOS is currently unsupported.

Results for float16:

==========================================================================================
BENCHMARK SUMMARY
==========================================================================================
Seq Len    Causal   Method     Time (ms)    Mem (MB)     TFLOP/s      Max Error    Grad Check
------------------------------------------------------------------------------------------
32         False    sdpa       0.551        0.64           0.0 TFLOP/s baseline     N/A       
32         False    jvp_attn   0.483        0.22           0.0 TFLOP/s 1.95e-03     ✓         

32         True     sdpa       1.067        0.65           0.0 TFLOP/s baseline     N/A       
32         True     jvp_attn   0.465        0.22           0.0 TFLOP/s 1.95e-03     ✓         

64         False    sdpa       0.552        1.41           0.0 TFLOP/s baseline     N/A       
64         False    jvp_attn   0.469        0.43           0.0 TFLOP/s 9.77e-04     ✓         

64         True     sdpa       0.875        1.42           0.0 TFLOP/s baseline     N/A       
64         True     jvp_attn   0.469        0.43           0.0 TFLOP/s 1.95e-03     ✓         

128        False    sdpa       0.533        3.28           0.0 TFLOP/s baseline     N/A       
128        False    jvp_attn   0.467        0.86           0.1 TFLOP/s 9.77e-04     ✓         

128        True     sdpa       0.860        3.35           0.0 TFLOP/s baseline     N/A       
128        True     jvp_attn   0.494        0.86           0.0 TFLOP/s 1.95e-03     ✓         

256        False    sdpa       0.538        9.69           0.2 TFLOP/s baseline     N/A       
256        False    jvp_attn   0.473        1.72           0.4 TFLOP/s 9.77e-04     ✓         

256        True     sdpa       0.870        9.94           0.0 TFLOP/s baseline     N/A       
256        True     jvp_attn   0.468        1.72           0.2 TFLOP/s 1.95e-03     ✓         

512        False    sdpa       0.575        31.88          0.6 TFLOP/s baseline     N/A       
512        False    jvp_attn   0.466        3.45           1.5 TFLOP/s 4.88e-04     ✓         

512        True     sdpa       0.914        32.88          0.2 TFLOP/s baseline     N/A       
512        True     jvp_attn   0.467        3.45           0.7 TFLOP/s 1.95e-03     ✓         

1024       False    sdpa       1.291        113.77         1.1 TFLOP/s baseline     N/A       
1024       False    jvp_attn   0.463        6.89           5.9 TFLOP/s 4.88e-04     ✓         

1024       True     sdpa       1.467        117.77         0.5 TFLOP/s baseline     N/A       
1024       True     jvp_attn   0.470        6.89           2.9 TFLOP/s 1.95e-03     ✓         

2048       False    sdpa       3.669        427.54         1.5 TFLOP/s baseline     N/A       
2048       False    jvp_attn   0.462        13.79         23.7 TFLOP/s 2.44e-04     ✓         

2048       True     sdpa       4.287        443.54         0.6 TFLOP/s baseline     N/A       
2048       True     jvp_attn   0.463        13.79         11.8 TFLOP/s 1.95e-03     ✓   
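The Mem (MB) column scales differently for the two methods: from sequence length 1024 to 2048, sdpa grows from 113.77 MB to 427.54 MB (about 3.8×), while jvp_attn grows from 6.89 MB to 13.79 MB (2×). That is consistent with a reference path that materializes the full N × N score matrix (quadratic in N) versus a flash-style kernel that keeps only per-row softmax statistics (linear in N). The arithmetic below sketches this under a hypothetical batch size and head count, since the benchmark's actual configuration is not stated.

```python
def score_matrix_mib(batch, heads, seq_len, bytes_per_elem=2):
    """MiB needed to materialize a full (batch, heads, seq_len, seq_len) score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem / 2**20


def row_stats_mib(batch, heads, seq_len, bytes_per_elem=4):
    """MiB for one float32 statistic per query row (e.g. a softmax log-sum-exp)."""
    return batch * heads * seq_len * bytes_per_elem / 2**20


# Hypothetical batch size and head count; the benchmark's actual values are not stated here
batch, heads = 1, 8
for seq_len in (1024, 2048):
    print(seq_len, score_matrix_mib(batch, heads, seq_len), row_stats_mib(batch, heads, seq_len))

# Doubling seq_len quadruples the score matrix but only doubles the per-row statistics
print(score_matrix_mib(batch, heads, 2048) / score_matrix_mib(batch, heads, 1024))  # 4.0
print(row_stats_mib(batch, heads, 2048) / row_stats_mib(batch, heads, 1024))  # 2.0
```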

Results for bfloat16:

==========================================================================================
BENCHMARK SUMMARY
==========================================================================================
Seq Len    Causal   Method     Time (ms)    Mem (MB)     TFLOP/s      Max Error    Grad Check
------------------------------------------------------------------------------------------
32         False    sdpa       0.527        0.64           0.0 TFLOP/s baseline     N/A       
32         False    jvp_attn   0.461        0.22           0.0 TFLOP/s 1.56e-02     ✓         

32         True     sdpa       0.854        0.65           0.0 TFLOP/s baseline     N/A       
32         True     jvp_attn   0.462        0.22           0.0 TFLOP/s 1.56e-02     ✓         

64         False    sdpa       0.671        1.41           0.0 TFLOP/s baseline     N/A       
64         False    jvp_attn   0.459        0.43           0.0 TFLOP/s 7.81e-03     ✓         

64         True     sdpa       0.846        1.42           0.0 TFLOP/s baseline     N/A       
64         True     jvp_attn   0.459        0.43           0.0 TFLOP/s 1.56e-02     ✓         

128        False    sdpa       0.539        3.28           0.0 TFLOP/s baseline     N/A       
128        False    jvp_attn   0.463        0.86           0.1 TFLOP/s 7.81e-03     ✓         

128        True     sdpa       0.860        3.35           0.0 TFLOP/s baseline     N/A       
128        True     jvp_attn   0.484        0.86           0.0 TFLOP/s 1.56e-02     ✓         

256        False    sdpa       0.530        9.69           0.2 TFLOP/s baseline     N/A       
256        False    jvp_attn   0.468        1.72           0.4 TFLOP/s 3.91e-03     ✓         

256        True     sdpa       0.856        9.94           0.0 TFLOP/s baseline     N/A       
256        True     jvp_attn   0.468        1.72           0.2 TFLOP/s 1.56e-02     ✓         

512        False    sdpa       0.573        31.88          0.6 TFLOP/s baseline     N/A       
512        False    jvp_attn   0.469        3.45           1.5 TFLOP/s 3.91e-03     ✓         

512        True     sdpa       0.869        32.88          0.2 TFLOP/s baseline     N/A       
512        True     jvp_attn   0.468        3.45           0.7 TFLOP/s 1.56e-02     ✓         

1024       False    sdpa       1.290        113.77         1.1 TFLOP/s baseline     N/A       
1024       False    jvp_attn   0.462        6.89           5.9 TFLOP/s 3.91e-03     ✓         

1024       True     sdpa       1.466        117.77         0.5 TFLOP/s baseline     N/A       
1024       True     jvp_attn   0.461        6.89           3.0 TFLOP/s 1.56e-02     ✓         

2048       False    sdpa       3.673        427.54         1.5 TFLOP/s baseline     N/A       
2048       False    jvp_attn   0.462        13.79         23.7 TFLOP/s 1.95e-03     ✓         

2048       True     sdpa       4.286        443.54         0.6 TFLOP/s baseline     N/A       
2048       True     jvp_attn   0.452        13.79         12.1 TFLOP/s 3.12e-02     ✓   
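The Max Error values are powers of two (e.g., 1.95e-03 = 2^-9, 7.81e-03 = 2^-7), and the bfloat16 errors are mostly about 8× the float16 ones for the same configuration (e.g., 7.81e-03 vs 9.77e-04 at sequence length 64, non-causal). That matches the ratio of the two formats' machine epsilons, since bfloat16 keeps 7 explicit mantissa bits versus 10 for float16. The snippet below simply reads those epsilons from `torch.finfo`. For float32, the reported errors (roughly 1e-03 to 7e-03) are far above float32 epsilon; one possible, unverified explanation is that Triton's `tl.dot` defaults to TF32 precision for float32 inputs on NVIDIA GPUs.

```python
import torch


def machine_eps(dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Map each floating-point dtype to its machine epsilon (spacing of values just above 1.0)."""
    return {dtype: torch.finfo(dtype).eps for dtype in dtypes}


eps = machine_eps()
for dtype, value in eps.items():
    print(f"{dtype}: eps = {value:.4e}")

# bfloat16 has 3 fewer explicit mantissa bits than float16, hence 2**3 = 8x coarser spacing
print(eps[torch.bfloat16] / eps[torch.float16])  # 8.0
```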

Results for float32:

==========================================================================================
BENCHMARK SUMMARY
==========================================================================================
Seq Len    Causal   Method     Time (ms)    Mem (MB)     TFLOP/s      Max Error    Grad Check
------------------------------------------------------------------------------------------
32         False    sdpa       0.456        0.51           0.0 TFLOP/s baseline     N/A       
32         False    jvp_attn   0.454        0.43           0.0 TFLOP/s 7.22e-03     ✓         

32         True     sdpa       0.779        0.51           0.0 TFLOP/s baseline     N/A       
32         True     jvp_attn   0.458        0.43           0.0 TFLOP/s 6.18e-03     ✓         

64         False    sdpa       0.460        1.09           0.0 TFLOP/s baseline     N/A       
64         False    jvp_attn   0.462        0.86           0.0 TFLOP/s 7.03e-03     ✓         

64         True     sdpa       0.787        1.11           0.0 TFLOP/s baseline     N/A       
64         True     jvp_attn   0.462        0.86           0.0 TFLOP/s 6.18e-03     ✓         

128        False    sdpa       0.460        2.81           0.0 TFLOP/s baseline     N/A       
128        False    jvp_attn   0.461        1.72           0.1 TFLOP/s 5.07e-03     ✓         

128        True     sdpa       0.782        2.88           0.0 TFLOP/s baseline     N/A       
128        True     jvp_attn   0.472        1.72           0.0 TFLOP/s 6.18e-03     ✓         

256        False    sdpa       0.457        8.75           0.2 TFLOP/s baseline     N/A       
256        False    jvp_attn   0.465        3.44           0.4 TFLOP/s 3.67e-03     ✓         

256        True     sdpa       0.798        9.00           0.1 TFLOP/s baseline     N/A       
256        True     jvp_attn   0.465        3.44           0.2 TFLOP/s 5.78e-03     ✓         

512        False    sdpa       0.530        30.01          0.6 TFLOP/s baseline     N/A       
512        False    jvp_attn   0.469        6.88           1.5 TFLOP/s 2.88e-03     ✓         

512        True     sdpa       0.784        31.01          0.2 TFLOP/s baseline     N/A       
512        True     jvp_attn   0.460        6.88           0.7 TFLOP/s 5.13e-03     ✓         

1024       False    sdpa       1.207        110.02         1.1 TFLOP/s baseline     N/A       
1024       False    jvp_attn   0.467        13.77          5.9 TFLOP/s 2.61e-03     ✓         

1024       True     sdpa       1.379        115.02         0.5 TFLOP/s baseline     N/A       
1024       True     jvp_attn   0.465        13.77          2.9 TFLOP/s 5.61e-03     ✓         

2048       False    sdpa       3.435        420.04         1.6 TFLOP/s baseline     N/A       
2048       False    jvp_attn   0.496        27.54         22.1 TFLOP/s 1.56e-03     ✓         

2048       True     sdpa       4.051        436.04         0.7 TFLOP/s baseline     N/A       
2048       True     jvp_attn   0.486        27.54         11.3 TFLOP/s 6.47e-03     ✓   
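The TFLOP/s column presumably divides a FLOP count by the measured time. A common convention, used for example by the benchmark in Triton's fused-attention tutorial, is sketched below; whether this package's benchmark counts the same FLOPs (or also counts the JVP work) is an assumption. Under this convention a causal run is credited with half the FLOPs, which is consistent with the causal rows reporting roughly half the TFLOP/s of the non-causal rows at nearly the same time (e.g., 23.7 vs 11.8 TFLOP/s at sequence length 2048 for float16).

```python
def attention_fwd_flops(batch, heads, seq_len, head_dim, causal=False):
    """FLOPs of the two attention matmuls (Q @ K^T and P @ V), counting a multiply-add as 2."""
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * head_dim
    total = 2 * flops_per_matmul
    # Under this convention, a causal mask is credited with half of the work
    return 0.5 * total if causal else total


def tflops_per_s(flops, time_ms):
    """Convert a FLOP count and a wall-clock time in milliseconds into TFLOP/s."""
    return flops / (time_ms / 1e3) / 1e12


# Hypothetical shapes; the benchmark's batch size, head count, and head_dim are not stated here
full = attention_fwd_flops(4, 8, 2048, 64)
causal = attention_fwd_flops(4, 8, 2048, 64, causal=True)
print(tflops_per_s(full, 0.462), tflops_per_s(causal, 0.462))
print(causal / full)  # 0.5
```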

License

This project is covered under the MIT License.

Citing this work

If you use the code associated with this package or otherwise find this work useful, please use GitHub's Cite this repository feature or the BibTeX below.

@software{Morehead_JVP_Flash_Attention_2025,
  author = {Morehead, Alex},
  doi = {10.5281/zenodo.17050188},
  license = {MIT},
  month = sep,
  title = {{JVP Flash Attention}},
  url = {https://github.com/amorehead/jvp_flash_attention},
  version = {0.0.1},
  year = {2025}
}

Acknowledgements

jvp_flash_attention builds upon the contributions and insights from the following sources:

We thank each and every contributor!


Download files


Source Distribution

jvp_flash_attention-0.0.1.tar.gz (26.7 kB)

Built Distribution


jvp_flash_attention-0.0.1-py3-none-any.whl (17.6 kB)

File details

Details for the file jvp_flash_attention-0.0.1.tar.gz.

File metadata

  • Download URL: jvp_flash_attention-0.0.1.tar.gz
  • Size: 26.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for jvp_flash_attention-0.0.1.tar.gz
Algorithm Hash digest
SHA256 704771efe0adad19719181e473bd7220ffdd24c8c45020abe82fce2fc18c1eab
MD5 99e083429101e18d248f966abc42b64f
BLAKE2b-256 d386b42b1af6865de6128b6f4e435b90a5140b3fc1e56d877016203ca9192a90


File details

Details for the file jvp_flash_attention-0.0.1-py3-none-any.whl.

File hashes

Hashes for jvp_flash_attention-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 8b50028d0b9daf533caaa104de0de1323bbcfa5092e50648c09a8eb19658511c
MD5 7409b471fbcdc671a2b86458cdbe7b78
BLAKE2b-256 53d4b0c54228b02d48df9c9167d86bfeabd2e172e760f8e9ec6ba890295af57a

