
Flash Attention Implementation with Multiple Backend Support and Sharding

This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX).

Project description

jax-flash-attn2

A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).

Features

  • 🚀 Multiple backend support: GPU, TPU, and CPU
  • 🔧 Multiple platform implementations: Triton, Pallas, and JAX
  • ⚡ Efficient caching of attention instances
  • 🔄 Support for Grouped Query Attention (GQA) and head dimensions up to 256
  • 📊 JAX sharding-friendly implementation
  • 🎯 Automatic platform selection based on backend
  • 🧩 Compatible with existing JAX mesh patterns

Installation

pip install jax-flash-attn2

Quick Start

from jax_flash_attn2 import get_cached_flash_attention

# Get a cached attention instance
# (`headdim` refers to your model's per-head feature dimension, e.g. 64 or 128)
attention = get_cached_flash_attention(
	backend="gpu",                 # 'gpu', 'tpu', or 'cpu'
	platform="triton",             # 'triton', 'pallas', or 'jax'
	blocksize_q=64,                # query block size
	blocksize_k=128,               # key block size
	softmax_scale=headdim ** -0.5, # optional softmax scaling factor
)

# Use with your tensors
outputs = attention(
	query=query_states,
	key=key_states,
	value=value_states,
	bias=attention_bias, # Optional
)
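
As a quick smoke test, the sketch below wires random inputs through a cached attention instance. The [batch, seq_len, num_heads, head_dim] layout, the float16 dtype, and the GQA shapes (8 KV heads against 32 query heads) are assumptions made for illustration, not documented requirements; match them to your own model, and note that backend="gpu" with platform="triton" needs a CUDA GPU with Triton installed.

import jax
import jax.numpy as jnp

from jax_flash_attn2 import get_cached_flash_attention

# Example shapes (illustrative assumptions only).
batch, seq_len, head_dim = 2, 1024, 128
num_q_heads, num_kv_heads = 32, 8  # GQA: fewer KV heads than query heads

rng = jax.random.PRNGKey(0)
q_rng, k_rng, v_rng = jax.random.split(rng, 3)
query_states = jax.random.normal(q_rng, (batch, seq_len, num_q_heads, head_dim), dtype=jnp.float16)
key_states = jax.random.normal(k_rng, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)
value_states = jax.random.normal(v_rng, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=head_dim ** -0.5,
)

outputs = attention(query=query_states, key=key_states, value=value_states)
print(outputs.shape)  # expected: (batch, seq_len, num_q_heads, head_dim)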

Usage with JAX Sharding

with mesh:
	attention_outputs = get_cached_flash_attention(
		backend="gpu",
		platform="triton",
		blocksize_q=128,
		blocksize_k=128,
		softmax_scale=None,
	)(
		query=with_sharding_constraint(query_states, qps).astype(dtype),
		key=with_sharding_constraint(key_states, kps).astype(dtype),
		value=with_sharding_constraint(value_states, vps).astype(dtype),
		bias=with_sharding_constraint(bias, bps).astype(dtype),
	)
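
The snippet above assumes that a mesh and the partition specs qps, kps, vps, and bps already exist. One minimal way to set these up is sketched below; the axis names ("dp", "tp"), the [batch, seq_len, heads, head_dim] activation layout, the [batch, heads, q_len, kv_len] bias layout, and the choice to shard heads over the tensor-parallel axis are all illustrative assumptions, not requirements of the library.

import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import with_sharding_constraint  # one common choice for the constraint function
from jax.sharding import Mesh, PartitionSpec

dtype = jnp.float16

# Arrange all visible devices as 1 data-parallel x N tensor-parallel (illustrative layout).
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), axis_names=("dp", "tp"))

# Assuming [batch, seq_len, heads, head_dim] activations: shard batch over "dp", heads over "tp".
qps = PartitionSpec("dp", None, "tp", None)
kps = PartitionSpec("dp", None, "tp", None)
vps = PartitionSpec("dp", None, "tp", None)
# Assuming a [batch, heads, q_len, kv_len] bias: shard batch over "dp", heads over "tp".
bps = PartitionSpec("dp", "tp", None, None)

When passed bare PartitionSpecs, with_sharding_constraint is typically applied inside a jit-compiled function while the mesh context manager is active, as in the example above.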

Supported Configurations

Backends

  • gpu: CUDA-capable GPUs
  • tpu: Google Cloud TPUs
  • cpu: CPU fallback

Platforms

  • triton: Optimized for NVIDIA GPUs
  • pallas: Optimized for TPUs and supported on GPUs
  • jax: Universal fallback, supports all backends

Valid Backend-Platform Combinations

| Backend | Supported Platforms |
|---------|---------------------|
| GPU     | Triton, Pallas, JAX |
| TPU     | Pallas, JAX         |
| CPU     | JAX                 |
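
Backend-agnostic code can encode this table once and pick a sensible default. The helper below is purely illustrative and is not part of the jax-flash-attn2 API; it simply mirrors the table and the platform preferences listed under Performance Tips.

# Illustrative helper mirroring the table above; not part of jax-flash-attn2's API.
SUPPORTED_PLATFORMS = {
    "gpu": ("triton", "pallas", "jax"),
    "tpu": ("pallas", "jax"),
    "cpu": ("jax",),
}

DEFAULT_PLATFORM = {"gpu": "triton", "tpu": "pallas", "cpu": "jax"}

def pick_platform(backend, preferred=None):
    """Return `preferred` if it is valid for `backend`, otherwise the recommended default."""
    if preferred in SUPPORTED_PLATFORMS[backend]:
        return preferred
    return DEFAULT_PLATFORM[backend]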

Advanced Configuration

Custom Block Sizes

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,    # Customize query block size
    blocksize_k=128,    # Customize key block size
    softmax_scale=1.0,  # Custom softmax scaling
)

Environment Variables

  • FORCE_MHA: Set to "true", "1", or "on" to force the MHA implementation even for GQA cases (see the example below)
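
For example, to force the MHA path from Python, set the variable before creating the attention instance; setting it before importing the package is the safest assumption, since exactly when the flag is read is not documented here.

import os

# Force the MHA code path even for GQA-shaped inputs ("true", "1", or "on" all work).
os.environ["FORCE_MHA"] = "true"

from jax_flash_attn2 import get_cached_flash_attention  # import after setting the flag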

Performance Tips

  1. Block Sizes: Default block sizes (128) work well for most cases, but you might want to tune them for your specific hardware and model architecture.

  2. Platform Selection:

    • For NVIDIA GPUs: prefer triton
    • For TPUs: prefer pallas
    • For CPU or fallback: use jax
  3. Caching: The get_cached_flash_attention function automatically caches instances based on parameters. No need to manage caching manually.
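
A quick way to see the caching in action, assuming the cache keys on the exact parameter values passed in, is that identical calls return the same object:

from jax_flash_attn2 import get_cached_flash_attention

a = get_cached_flash_attention(backend="gpu", platform="triton", blocksize_q=128, blocksize_k=128, softmax_scale=None)
b = get_cached_flash_attention(backend="gpu", platform="triton", blocksize_q=128, blocksize_k=128, softmax_scale=None)
c = get_cached_flash_attention(backend="gpu", platform="triton", blocksize_q=64, blocksize_k=128, softmax_scale=None)

assert a is b      # identical parameters -> same cached instance
assert c is not a  # different block size -> a separate instance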

Requirements

  • JAX
  • einops
  • chex
  • jax.experimental.pallas (for TPU support)
  • triton (for the GPU-optimized implementation)

Limitations

  • The Triton platform is only available on NVIDIA GPUs.
  • Some platform-backend combinations are not supported (see the table above).
  • Custom attention masks are not yet supported; use an additive bias instead (see the sketch below).
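
Since masks cannot be passed directly, a boolean mask can be folded into the additive bias before the call. The sketch below assumes a [batch, num_heads, q_len, kv_len] bias layout and a mask where True means "attend"; adjust both to your model's conventions.

import jax.numpy as jnp

def mask_to_bias(mask, dtype=jnp.float16):
    """Convert a boolean mask (True = attend) into an additive attention bias."""
    # Use the dtype's most negative finite value instead of -inf to keep softmax numerically safe.
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Example: a causal mask broadcast over batch and heads (assumed layout: [batch, heads, q_len, kv_len]).
seq_len = 1024
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]
attention_bias = mask_to_bias(causal_mask)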

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Citation

If you use this implementation in your research, please cite:

@software{jax_flash_attn2,
    title = {JAX Flash Attention 2.0},
    year = {2024},
    url = {https://github.com/erfanzar/jax-flash-attn2}
}

Acknowledgments and References

  1. The MHA implementation is based on custom Triton kernels that use JAX-Triton.

  2. All kernels are copied from EasyDeL.
