
Ring attention implementation with flash attention.


Ring Flash Attention

This repo implements RingAttention using FlashAttention. The current implementation supports:

  • varlen (sample-packing) API, corresponding to flash_attn_varlen_func:
    • ring_flash_attn_varlen_func: a basic implementation of ring attention.
    • zigzag_ring_flash_attn_varlen_func: a more compute-balanced version of ring attention. More details in issue#2.
    • llama3_flash_attn_varlen_func: the context parallelism used in the Llama 3 tech report, with extra design for varlen inputs and low memory overhead. Although technically not ring attention, it is recommended for most varlen use cases, as it is less intrusive for training frameworks (fewer data manipulations) and has better arithmetic precision.
  • batch API, corresponding to flash_attn_func:
    • ring_flash_attn_func: basic ring attention.
    • zigzag_ring_flash_attn_func: a more compute-balanced version of ring attention, see issue#2.
    • stripe_flash_attn_func: a Stripe Attention version of ring_flash_attn_func. The block size is set to 1 so that the flash_attn API can be used; see https://arxiv.org/abs/2311.09431
  • Hugging Face model adapter. Here is an example of using the adapter:
# torchrun --nproc_per_node=2 this_script.py
import torch
from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params
from torch import distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

def main():
    # Initialize distributed training
    dist.init_process_group(backend="nccl")

    # Get rank and world size
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Set device
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-0.6B", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, device_map=device
    )
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    # Create group and substitute flash attention
    group = dist.new_group(ranks=range(world_size), backend="nccl")
    substitute_hf_flash_attn(group, heads_k_stride=1)

    # Get the ring attention rank
    ring_attn_rank = dist.get_rank(group=group)  # only one group for ring attention here: this should be the same as rank

    # Tokenize input and prepare position IDs
    input_ids = tokenizer(["Lorem ipsum dolor sit", "amet, consectetur adipiscing", "elit, sed do"]).input_ids
    lengths = [len(seq) for seq in input_ids]
    input_ids = torch.cat([torch.tensor(seq, device=device) for seq in input_ids]).unsqueeze(0)
    position_ids = torch.cat([torch.arange(length, device=device) for length in lengths]).unsqueeze(0)
    
    # Compute cu_seqlens and update parameters
    cu_seqlens = torch.cat([torch.tensor([0], device=device), torch.cumsum(torch.tensor(lengths, device=device), dim=0)]).to(torch.int32)
    update_ring_flash_attn_params(cu_seqlens, group)

    # Chunk input_ids and position_ids
    input_ids = torch.chunk(input_ids, world_size, dim=1)[ring_attn_rank]
    position_ids = torch.chunk(position_ids, world_size, dim=1)[ring_attn_rank]
    
    output = model(input_ids=input_ids, position_ids=position_ids)

    # Clean up
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
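To make the packing step in the example concrete, here is a minimal pure-Python sketch (no torch) of what the cu_seqlens computation and the per-rank chunking along dim=1 produce. The helper names cu_seqlens_from_lengths and rank_slice are hypothetical, and the sketch assumes the packed length divides evenly by world_size; it illustrates the bookkeeping only, not the library's internals.

```python
from itertools import accumulate


def cu_seqlens_from_lengths(lengths):
    """Cumulative sequence boundaries with a leading 0 (hypothetical helper).

    Mirrors torch.cat([tensor([0]), torch.cumsum(lengths, dim=0)]) from the example.
    """
    return [0] + list(accumulate(lengths))


def rank_slice(total_len, world_size, rank):
    """Half-open [start, end) token range kept by `rank` (hypothetical helper).

    Assumes total_len is divisible by world_size, so torch.chunk along dim=1
    gives every rank an equal, contiguous share.
    """
    if total_len % world_size != 0:
        raise ValueError("packed length must be a multiple of world_size")
    chunk = total_len // world_size
    return rank * chunk, (rank + 1) * chunk


lengths = [5, 4, 3]  # hypothetical token counts of three packed samples
cu_seqlens = cu_seqlens_from_lengths(lengths)
print(cu_seqlens)  # [0, 5, 9, 12]
for rank in range(2):  # world_size = 2, as in the torchrun command
    print(rank, rank_slice(cu_seqlens[-1], 2, rank))
# 0 (0, 6)
# 1 (6, 12)
```

With these lengths the second sample (tokens 5 to 8) straddles both ranks, which is presumably why the example passes the full, unchunked cu_seqlens to update_ring_flash_attn_params rather than a per-rank one.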

Notes:

  • Each function has *_func, *_kvpacked_func, and *_qkvpacked_func variants.
  • The varlen versions (except the llama3 version) accept only a single cu_seqlens, shared by queries and keys, rather than separate cu_seqlens_q and cu_seqlens_k.
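The "more compute-balanced" claim for the zigzag variants can be checked with a small count. The sketch below assumes the zigzag layout as we understand it from issue#2 and the tests: the sequence is cut into 2 * world_size chunks and rank r keeps chunks r and 2 * world_size - 1 - r. It counts causal attention blocks per rank at whole-chunk granularity, which simplifies the real per-step cost; all function names are hypothetical.

```python
def zigzag_chunks(world_size, rank):
    """Chunk ids (out of 2 * world_size) held by `rank` in the zigzag layout."""
    return [rank, 2 * world_size - 1 - rank]


def contiguous_chunks(world_size, rank):
    """Chunk ids held by `rank` if the same 2 * world_size chunks were split contiguously."""
    return [2 * rank, 2 * rank + 1]


def causal_blocks(chunk_ids):
    """Number of (query chunk, key chunk) pairs a rank computes under a causal mask.

    Query chunk i attends to key chunks 0..i, i.e. i + 1 blocks. The diagonal
    block is counted as a full block to keep the arithmetic simple.
    """
    return sum(i + 1 for i in chunk_ids)


world_size = 4
print([causal_blocks(contiguous_chunks(world_size, r)) for r in range(world_size)])
# [3, 7, 11, 15]
print([causal_blocks(zigzag_chunks(world_size, r)) for r in range(world_size)])
# [9, 9, 9, 9]
```

With a contiguous split the last rank does five times the causal work of the first, while pairing an early chunk with a late one gives every rank the same total.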

Performance Summary

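The following tables summarize the performance of the implemented APIs. Throughput is in iterations per second; each percentage is relative to the theoretic flash_attn value in the same row.

batch api            | GPU    | theoretic flash_attn | ring_attn    | zigzag_ring  | stripe_attn
---------------------+--------+----------------------+--------------+--------------+-------------
fwd only (iter/sec)  | 8xH800 | 591.5 / 8 = 73.9     | 38.5 (52.1%) | 63.0 (85.2%) | 55.0 (74.4%)
fwd + bwd (iter/sec) | 8xH800 | 154.7 / 8 = 19.3     | 10.4 (53.9%) | 17.4 (90.2%) | 16.0 (82.9%)
fwd only (iter/sec)  | 8xA100 | 373.4 / 8 = 46.7     | 24.0 (51.4%) | 38.2 (81.7%) | 32.5 (69.6%)
fwd + bwd (iter/sec) | 8xA100 | 94.7 / 8 = 11.8      | 6.2 (52.5%)  | 10.6 (89.8%) | 9.75 (82.6%)

varlen api           | GPU    | theoretic flash_attn | ring_attn    | zigzag_ring  | llama3_attn
---------------------+--------+----------------------+--------------+--------------+-------------
fwd only (iter/sec)  | 8xH800 | 852.4 / 8 = 106.6    | 52.4 (49.1%) | 74.8 (70.2%) | 60.8 (57.0%)
fwd + bwd (iter/sec) | 8xH800 | 225.4 / 8 = 28.2     | 14.4 (51.1%) | 21.4 (75.9%) | 16.4 (58.1%)
fwd only (iter/sec)  | 8xA100 | 532.3 / 8 = 66.5     | 33.1 (49.8%) | 47.9 (72.0%) | 34.3 (51.6%)
fwd + bwd (iter/sec) | 8xA100 | 133.8 / 8 = 16.7     | 8.7 (52.1%)  | 13.4 (80.2%) | 9.7 (58.0%)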
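The percentages in the tables can be reproduced from the raw numbers: divide the measured flash_attn throughput by 8 (for the reason given in the notes below) and express each ring variant's throughput as a fraction of that. A small sketch using the 8xH800 batch fwd-only row; the helper names are hypothetical. Other rows may differ in the last digit depending on whether the rounded or unrounded baseline is used.

```python
def theoretic_flash_attn(measured_iters, num_gpus=8):
    """Baseline scaled to the same work as ring attention.

    Per the benchmark notes, flash_attn on num_gpus GPUs does 1/num_gpus of
    ring attention's computation, so its throughput is divided by num_gpus.
    """
    return measured_iters / num_gpus


def efficiency(ring_iters, measured_flash_iters, num_gpus=8):
    """Ring-attention throughput as a fraction of the theoretic baseline."""
    return ring_iters / theoretic_flash_attn(measured_flash_iters, num_gpus)


# batch api, fwd only, 8xH800 row of the table
flash_iters = 591.5
for name, iters in [("ring_attn", 38.5), ("zigzag_ring", 63.0), ("stripe_attn", 55.0)]:
    print(f"{name}: {efficiency(iters, flash_iters):.1%}")
# ring_attn: 52.1%
# zigzag_ring: 85.2%
# stripe_attn: 74.4%
```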

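Notes:

  • The benchmark code is in the benchmark directory. Its configuration matches the Meta-Llama-3.1-8B setting, with a total sequence length of 8k tokens per GPU.
  • When the benchmark runs on 8 GPUs, the flash_attn baseline performs 1/8 of the computation of ring attention: flash_attn runs 8 sequences of length 1 (in units of the per-GPU sequence length), i.e. 8 * 1^2 = 8 units of attention work, while ring attention runs one sequence of length 8, i.e. 1 * 8^2 = 64 units. This is why the measured flash_attn throughput is divided by 8 in the theoretic flash_attn column.
  • NVLink between GPUs is required for high performance.
  • Remember to adapt the RoPE position offsets for each API, since each API distributes the tokens of a sequence across ranks differently.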
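Because each API places tokens on ranks differently, the position ids fed to RoPE should be the global positions of the tokens a rank holds, not 0..local_len-1. The sketch below computes those global positions for a single unpacked sequence under our reading of the three batch-API layouts (contiguous chunks for ring_attn, mirrored chunk pairs for zigzag_ring, one-token stripes for stripe_attn). The function names are hypothetical, and seq_len is assumed divisible by 2 * world_size.

```python
def ring_positions(seq_len, world_size, rank):
    """Contiguous layout: rank r holds tokens [r * n, (r + 1) * n)."""
    n = seq_len // world_size
    return list(range(rank * n, (rank + 1) * n))


def zigzag_positions(seq_len, world_size, rank):
    """Zigzag layout: 2 * world_size chunks, rank r holds chunks r and 2W - 1 - r."""
    n = seq_len // (2 * world_size)
    mirror = 2 * world_size - 1 - rank
    return list(range(rank * n, (rank + 1) * n)) + list(range(mirror * n, (mirror + 1) * n))


def stripe_positions(seq_len, world_size, rank):
    """Stripe layout with block size 1: rank r holds tokens r, r + W, r + 2W, ..."""
    return list(range(rank, seq_len, world_size))


seq_len, world_size = 8, 2
for rank in range(world_size):
    print(rank,
          ring_positions(seq_len, world_size, rank),
          zigzag_positions(seq_len, world_size, rank),
          stripe_positions(seq_len, world_size, rank))
# 0 [0, 1, 2, 3] [0, 1, 6, 7] [0, 2, 4, 6]
# 1 [4, 5, 6, 7] [2, 3, 4, 5] [1, 3, 5, 7]
```

In each layout the ranks together cover every position exactly once, so only the per-rank offsets change, not the set of positions.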

Installation

pip install ring-flash-attn

or build from source with the following commands:

git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .

TODOs

  • Implement ring_flash_attn_varlen_qkvpacked_func
  • Implement zigzag_ring_flash_attn_qkvpacked_func issue#2
  • Implement stripe_flash_attn_qkvpacked_func
  • Implement zigzag_ring_flash_attn_varlen_qkvpacked_func
  • Implement *_kvpacked_func and *_func variant for all APIs
  • Optimize *_varlen_func, superseded by: Implement llama3_flash_attn_varlen_func
  • Add an example to train llama, superseded by: Implement adapter for huggingface model
  • Implement zigzag_llama3_flash_attn_varlen_func

Test

torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py

Benchmark

torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py

Known Limitations

The current implementation has some arithmetic errors. The likely reason is that flash attention returns a bf16 value for each block, so the blocks are accumulated from these rounded values rather than from the original fp32 ones.

In addition, because an extra fp32 buffer must be kept during computation, memory usage is higher than the theoretical minimum.
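To see where the accumulation happens: each ring step yields a partial attention output for one key/value block together with its log-sum-exp (lse), and the partials are folded into a running result. The sketch below shows the standard lse-based merge for a single query in pure Python (float64) and checks that merging two blocks reproduces full attention. It illustrates the technique only and is not the library's kernel code, which may arrange the same update differently; the function names are hypothetical. In the real implementation each block output arrives as bf16, so every merge starts from rounded values, which is the error source described above.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query over one block of keys/values.

    Returns the normalized block output and the block's log-sum-exp of scores.
    The usual 1/sqrt(d) softmax scale is omitted for brevity.
    """
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results; each is reweighted by its share of exp(lse)."""
    hi = max(lse_a, lse_b)
    lse = hi + math.log(math.exp(lse_a - hi) + math.exp(lse_b - hi))
    wa, wb = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [wa * x + wb * y for x, y in zip(out_a, out_b)], lse


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -0.3], [0.5, 0.5], [-1.0, 2.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

full, full_lse = attend(q, keys, values)       # attention over all keys at once
out1, lse1 = attend(q, keys[:2], values[:2])   # block held by this rank
out2, lse2 = attend(q, keys[2:], values[2:])   # block received in the next ring step
merged, merged_lse = merge(out1, lse1, out2, lse2)
print(max(abs(x - y) for x, y in zip(full, merged)))
```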

Also,

  • dropout is not supported at the moment, because it's hard to save all the rng_states.
  • window_size is not supported, because it will be really tricky to implement a varlen version with window_size.



Download files


Source Distribution

ring_flash_attn-0.1.6.tar.gz (22.1 kB)

Uploaded Source

Built Distribution


ring_flash_attn-0.1.6-py3-none-any.whl (25.4 kB)

Uploaded Python 3

File details

Details for the file ring_flash_attn-0.1.6.tar.gz.

File metadata

  • Download URL: ring_flash_attn-0.1.6.tar.gz
  • Upload date:
  • Size: 22.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.3

File hashes

Hashes for ring_flash_attn-0.1.6.tar.gz
Algorithm Hash digest
SHA256 1d1a4b70eb6491a00a69a51b247cdb8a479d6bd54765ea4bfd423817b5b301a0
MD5 d22ae35e6504b3f31ffbb27bca8aae47
BLAKE2b-256 8e67485a1607a7bc658d3262ec753106820c6f59289d7f767b110781fac5d074


File details

Details for the file ring_flash_attn-0.1.6-py3-none-any.whl.

File metadata

File hashes

Hashes for ring_flash_attn-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 c260cf87c27a736a21fb4a49a992edcd7d5320f882bd93188171f48ca52f6aa9
MD5 0fb0b0c6ae46f90a6c8593bfb0a6fa7e
BLAKE2b-256 c778231176338475e178e3ea4a2c30f92c2fe962fb18ce48ab1d1d21581466d8

