FlashAttention-3 forward

Project description

Flash-Attention-3 Forward-Only Kernel

This repository bundles the Flash-Attention-3 forward-only kernel and the tooling required to build a lightweight Python wheel. It is intended for inference scenarios where backward operators and optional features are unnecessary.

Highlights

  • Ships only the Flash-Attention-3 forward path while disabling backward kernels, local attention, paged KV cache, FP16 kernels, and other extras to minimize the wheel size.
  • Applies a patch that renames the public interface to fa3_fwd_interface, making the forward kernel easy to import from Python.
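Because the patch only changes the module name, code that must run against either this wheel or an upstream Flash-Attention-3 build can try both names in order. The following is a minimal sketch. It assumes the upstream hopper build exposes its interface as `flash_attn_interface`, so verify that against your install. `load_attr` is a hypothetical helper and is not part of this package.

```python
import importlib
from types import ModuleType
from typing import Any, Iterable


def load_attr(module_names: Iterable[str], attr: str) -> Any:
    """Return `attr` from the first module in `module_names` that imports cleanly."""
    tried = []
    for name in module_names:
        try:
            module: ModuleType = importlib.import_module(name)
        except ImportError:  # also covers ModuleNotFoundError
            tried.append(name)
            continue
        return getattr(module, attr)
    raise ImportError(f"none of {tried} could be imported")


# This wheel's renamed module first, then the upstream hopper name (assumed).
FA3_MODULES = ("fa3_fwd_interface", "flash_attn_interface")

if __name__ == "__main__":
    try:
        flash_attn_func = load_attr(FA3_MODULES, "flash_attn_func")
        print("loaded from", flash_attn_func.__module__)
    except ImportError as exc:
        print(exc)
```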

Prerequisites (same as upstream)

  • Python: 3.9 or later
  • PyTorch: 2.10
  • Build dependencies: ninja, packaging, wheel
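A quick way to confirm these prerequisites before building is to check the interpreter version and look up each package's import spec. The following is a minimal sketch. It only checks presence, not the PyTorch version, so compare the printed `torch.__version__` against the list above yourself. It assumes each package's import name matches its pip name, which is true for these four.

```python
import importlib.util
import sys
from typing import Dict


def check_prerequisites() -> Dict[str, bool]:
    """Report which prerequisites listed above are present in this interpreter."""
    report = {"python>=3.9": sys.version_info >= (3, 9)}
    # PyTorch plus the build dependencies; presence only, no version checks.
    for module in ("torch", "ninja", "packaging", "wheel"):
        report[module] = importlib.util.find_spec(module) is not None
    return report


if __name__ == "__main__":
    for name, ok in check_prerequisites().items():
        print(f"{name:12} {'ok' if ok else 'MISSING'}")
    if importlib.util.find_spec("torch") is not None:
        import torch
        print("torch", torch.__version__, "built with CUDA", torch.version.cuda)
```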

Quick Start

  1. Clone the repository and initialize submodules:

    git clone --recursive <repo-url>
    cd fa3-fwd
    # If --recursive was omitted during clone, run:
    git submodule update --init --recursive
    
  2. Create a Python virtual environment and install dependencies:

    uv venv --python 3.12 --seed
    source .venv/bin/activate
    uv pip install -r requirements.txt
    
  3. Build the forward-only wheel:

    bash build_fa3.sh
    

    The script:

    • Sources set_compile_env.sh to compute MAX_JOBS and NVCC_THREADS
    • Applies the custom patch and interface rename inside the Flash-Attention submodule
    • Runs python setup.py bdist_wheel under flash-attention/hopper

  4. Install the generated wheel (example):

    pip install build/*.whl
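The contents of set_compile_env.sh are not reproduced here. The sketch below is a hypothetical illustration of the kind of heuristic such a script might use: it caps parallel compile jobs both by CPU count and by physical memory. The function name, the 8 GB-per-job budget, and the default of 2 nvcc threads are all assumptions, not values taken from the script.

```python
import os
from typing import Tuple


def compile_env(cpus: int, mem_gb: float, gb_per_job: float = 8.0,
                nvcc_threads: int = 2) -> Tuple[int, int]:
    """Hypothetical heuristic returning (MAX_JOBS, NVCC_THREADS).

    Each of the MAX_JOBS jobs runs NVCC_THREADS threads and is assumed
    to need about `gb_per_job` GB of RAM at that thread count.
    """
    nvcc_threads = max(1, nvcc_threads)
    by_cpu = cpus // nvcc_threads          # do not oversubscribe cores
    by_mem = int(mem_gb // gb_per_job)     # do not exhaust RAM
    return max(1, min(by_cpu, by_mem)), nvcc_threads


if __name__ == "__main__":
    # Total physical RAM on Linux; a real script might use available memory instead.
    total_gb = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 2**30
    jobs, threads = compile_env(os.cpu_count() or 1, total_gb)
    print(f"MAX_JOBS={jobs} NVCC_THREADS={threads}")
```

Setting MAX_JOBS=1 NVCC_THREADS=1, as suggested under Troubleshooting, is the most conservative end of this trade-off. It gives the slowest build but the lowest peak memory.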
    

Python Usage Example

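import torch
from fa3_fwd_interface import flash_attn_func

# Inputs must already live on CUDA and satisfy Flash-Attention-3 constraints:
# shape (batch, seqlen, nheads, headdim); BF16 here, since this build disables FP16 kernels.
# The sizes below are illustrative.
q, k, v = (torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = flash_attn_func(q, k, v, causal=True)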

This package exposes only the forward kernel. For backward support or additional features, depend on the upstream Flash-Attention project instead.
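To see what causal=True computes, or to spot-check tiny outputs after moving them to the CPU (for example with `.tolist()`), a plain-Python reference of single-head causal scaled dot-product attention can help. This is a readability sketch of the standard softmax(QK^T / sqrt(d)) V formulation with the default 1/sqrt(head_dim) scale, not the kernel's tiled algorithm. It assumes equal query and key lengths, so query i sees keys 0..i.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head causal attention: row i attends to keys 0..i only."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)  # default softmax scale 1/sqrt(head_dim)
    out = []
    for i, qi in enumerate(q):
        # Scores against visible keys only (the causal mask).
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    q = k = [[1.0, 0.0], [0.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0]]
    print(causal_attention(q, k, v))  # first row equals v[0]
```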

Troubleshooting

  • Out-of-memory during compilation: The build script already throttles concurrency, but you can force the most conservative setting by running MAX_JOBS=1 NVCC_THREADS=1 bash build_fa3.sh.
  • CUDA mismatch errors: Confirm that the CUDA version reported by nvcc --version matches torch.version.cuda.
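The CUDA comparison above can be scripted. The following is a minimal sketch. It assumes the usual nvcc --version line of the form "Cuda compilation tools, release 12.4, V12.4.131", and the helper names are hypothetical. It reports "match" for equal major.minor versions and "major-only" when only the major version agrees.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

Version = Optional[Tuple[int, int]]


def parse_nvcc_release(text: str) -> Version:
    """Extract (major, minor) from `nvcc --version` output, e.g. 'release 12.4,'."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def parse_torch_cuda(version: Optional[str]) -> Version:
    """Parse torch.version.cuda such as '12.4'; None means a CPU-only build."""
    if not version:
        return None
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def compare(nvcc: Version, torch_cuda: Version) -> str:
    if nvcc is None or torch_cuda is None:
        return "unknown"
    if nvcc == torch_cuda:
        return "match"
    return "major-only" if nvcc[0] == torch_cuda[0] else "mismatch"


if __name__ == "__main__":
    nvcc_path = shutil.which("nvcc")
    nvcc_out = (subprocess.run([nvcc_path, "--version"], capture_output=True,
                               text=True).stdout if nvcc_path else "")
    try:
        import torch
        torch_cuda = parse_torch_cuda(torch.version.cuda)
    except ImportError:
        torch_cuda = None
    print(compare(parse_nvcc_release(nvcc_out), torch_cuda))
```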

Repository Layout

Customize further by editing environment variables in the build script or modifying the submodule before the patch is applied (for example to re-enable additional datatypes or kernels).
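The Troubleshooting note implies the build script honors MAX_JOBS and NVCC_THREADS when they are already set. If so, overrides can also be passed from Python without editing the script. The following is a minimal sketch: build_env and run_build are hypothetical helpers. Any variable other than MAX_JOBS and NVCC_THREADS should be checked against build_fa3.sh and the submodule's setup.py before relying on it.

```python
import os
import subprocess
from typing import Dict, Mapping


def build_env(overrides: Mapping[str, object]) -> Dict[str, str]:
    """Copy the current environment and apply string-valued overrides on top."""
    env = dict(os.environ)
    env.update({name: str(value) for name, value in overrides.items()})
    return env


def run_build(overrides: Mapping[str, object], script: str = "build_fa3.sh") -> None:
    """Run the build script with the overrides; raises CalledProcessError on failure."""
    subprocess.run(["bash", script], env=build_env(overrides), check=True)
```

From the repository root, `run_build({"MAX_JOBS": 1, "NVCC_THREADS": 1})` reproduces the conservative Troubleshooting setting. The parent process's `os.environ` is left unchanged.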

Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


fa3_fwd-0.0.3-cp39-abi3-manylinux_2_24_x86_64.whl (27.6 MB)

Built for CPython 3.9+ (abi3), manylinux with glibc 2.24+, x86-64

fa3_fwd-0.0.3-cp39-abi3-manylinux_2_24_aarch64.whl (27.9 MB)

Built for CPython 3.9+ (abi3), manylinux with glibc 2.24+, ARM64

File hashes

Hashes for fa3_fwd-0.0.3-cp39-abi3-manylinux_2_24_x86_64.whl:

  • SHA256: 44d26c6cfc69dde5a0f9c6bebbccbb96ed20e97bc8b9ddda638b0bf62e91b93d
  • MD5: dc536dd185e33a8eeb88d3247e53a88f
  • BLAKE2b-256: eda6a4a31f4e500a3ac787cadf4c172aaf9fc74fb81f6ed78778fc8158384f00

Hashes for fa3_fwd-0.0.3-cp39-abi3-manylinux_2_24_aarch64.whl:

  • SHA256: 1e0e51f2b4b094e8dd0526a5e4291a0f2d0366cc55e89b27c6b8bb469dd05cfa
  • MD5: d034d9592bde974317a6d4c5c0b7a267
  • BLAKE2b-256: 75b061b7888b1699efc8299f2027c7d18130a35fc7f3819ad98c2f85317d3b1d
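To check a downloaded wheel against the SHA256 digests published for this release, stream the file through hashlib and compare the result. The following is a minimal sketch. `sha256_of` is a hypothetical helper name, and the `EXPECTED` table simply restates the two SHA256 values for the 0.0.3 wheels.

```python
import hashlib
import os
import sys

EXPECTED = {
    "fa3_fwd-0.0.3-cp39-abi3-manylinux_2_24_x86_64.whl":
        "44d26c6cfc69dde5a0f9c6bebbccbb96ed20e97bc8b9ddda638b0bf62e91b93d",
    "fa3_fwd-0.0.3-cp39-abi3-manylinux_2_24_aarch64.whl":
        "1e0e51f2b4b094e8dd0526a5e4291a0f2d0366cc55e89b27c6b8bb469dd05cfa",
}


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # Usage: python check_wheel.py fa3_fwd-0.0.3-...whl [more wheels]
    for path in sys.argv[1:]:
        expected = EXPECTED.get(os.path.basename(path))
        ok = expected is not None and sha256_of(path) == expected
        print(f"{path}: {'OK' if ok else 'MISMATCH or unknown file name'}")
```

pip can enforce the same check at install time via hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file).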
