Ring attention implementation with flash attention.

Ring Flash Attention

This repo implements RingAttention with FlashAttention. Currently, it provides:

  • varlen API, corresponding to flash_attn_varlen_func:
    • ring_flash_attn_varlen_func: naive ring attention.
    • zigzag_ring_flash_attn_varlen_func: a more compute-balanced version of ring attention, see issue#2.
    • llama3_flash_attn_varlen_func: the context parallelism used in the llama3 tech report, with extra design for varlen inputs and low memory overhead.
  • batch API, corresponding to flash_attn_func (a minimal usage sketch follows this list):
    • ring_flash_attn_func: naive ring attention.
    • zigzag_ring_flash_attn_func: a more compute-balanced version of ring attention, see issue#2.
    • stripe_flash_attn_func: a striped-attention version of ring_flash_attn_func; the block size is set to 1 so the flash_attn API can be used, see: https://arxiv.org/abs/2311.09431
  • huggingface model adapter. Here is an example of using the adapter: OpenRLHF/OpenRLHF/pull#439.
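As a rough illustration of how the batch API is meant to slot in for flash_attn_func, here is a minimal sketch. It assumes the usual flash-attention tensor layout of (batch, local_seqlen, nheads, headdim) and that ring_flash_attn_func accepts a process group; the group= keyword and the exact shapes here are assumptions rather than confirmed signatures.

# Minimal sketch (assumptions noted above): each rank holds its shard of the
# sequence, and ring_flash_attn_func circulates K/V blocks around the ring so
# the result matches flash_attn_func run on the full, unsharded sequence.
# Launch with: torchrun --nproc_per_node 8 this_script.py
import os
import torch
import torch.distributed as dist
from ring_flash_attn import ring_flash_attn_func

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

batch, local_seqlen, nheads, headdim = 1, 8192, 32, 128  # per-rank shard
q = torch.randn(batch, local_seqlen, nheads, headdim,
                device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same calling convention as flash_attn_func, plus a process group argument
# (the group= keyword is an assumption here).
out = ring_flash_attn_func(q, k, v, causal=True, group=dist.group.WORLD)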

Note that

  • all functions have the *_func, *_kvpacked_func, and *_qkvpacked_func variants implemented.
  • the varlen versions only support passing one cu_seqlens.

The current performance is:

batch API            | GPU    | theoretic flash_attn | ring_attn    | zigzag_ring  | stripe_attn
fwd only (iter/sec)  | 8xH800 | 591.5 / 8 = 73.9     | 38.5 (52.1%) | 63.0 (85.2%) | 55.0 (74.4%)
fwd + bwd (iter/sec) | 8xH800 | 154.7 / 8 = 19.3     | 10.4 (53.9%) | 17.4 (90.2%) | 16.0 (82.9%)
fwd only (iter/sec)  | 8xA100 | 373.4 / 8 = 46.7     | 24.0 (51.4%) | 38.2 (81.7%) | 32.5 (69.6%)
fwd + bwd (iter/sec) | 8xA100 | 94.7 / 8 = 11.8      | 6.2 (52.5%)  | 10.6 (89.8%) | 9.75 (82.6%)

varlen API           | GPU    | theoretic flash_attn | ring_attn    | zigzag_ring  | llama3_attn
fwd only (iter/sec)  | 8xH800 | 852.4 / 8 = 106.6    | 52.4 (49.1%) | 74.8 (70.2%) | 60.8 (57.0%)
fwd + bwd (iter/sec) | 8xH800 | 225.4 / 8 = 28.2     | 14.4 (51.1%) | 21.4 (75.9%) | 16.4 (58.1%)
fwd only (iter/sec)  | 8xA100 | 532.3 / 8 = 66.5     | 33.1 (49.8%) | 47.9 (72.0%) | 34.3 (51.6%)
fwd + bwd (iter/sec) | 8xA100 | 133.8 / 8 = 16.7     | 8.7 (52.1%)  | 13.4 (80.2%) | 9.7 (58.0%)

(percentages are throughput relative to the theoretic flash_attn column)

Note that

  • The benchmark code is in benchmark/; the attention config is set to match Meta-Llama-3.1-8B, and each GPU runs a total sequence of length 8k.
  • When running the benchmark with 8 GPUs, the flash attn code performs 1/8 of the computation of ring attention: attention cost is quadratic in sequence length, so the flash attn code runs $8 \times 1^2$ units of work while the ring attn code runs $1 \times 8^2$ (a small numeric check follows this list).
  • NVLink between GPUs is required for high performance.
  • Remember to adapt the RoPE offset for the different APIs.
  • Technically, the llama3 series of APIs is not ring attention and incurs extra memory overhead, but its communication pattern is friendlier to GPU clusters and its arithmetic error is lower.
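
As a quick numeric check of the theoretic flash_attn column and the percentages above, using the fwd-only 8xH800 row of the batch API table:

# Attention cost is quadratic in sequence length: 8 ranks each running an 8k
# sequence do 8 * 1^2 units of work, while ring attention over the full 64k
# sequence does 1 * 8^2 units -- an 8x difference, hence the division by 8.
flash_attn_iters = 591.5          # single-GPU flash_attn, fwd only, H800
theoretic = flash_attn_iters / 8  # 73.9 iter/sec
print(round(theoretic, 1))        # 73.9
print(round(38.5 / theoretic, 3)) # 0.521 -> the 52.1% ring_attn efficiency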

Installation

pip install ring-flash-attn

or use the following commands to build from source:

git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .

Limits

There are some arithmetic errors with the current implementation. The likely reason is that flash attention returns bf16 values for each block, so we cannot accumulate the values against the original fp32 ones.

Also, because we need to keep an extra fp32 buffer during computation, the memory usage is higher than the theoretical limit.
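
A small standalone illustration of the rounding issue described above (not the library's code): summing per-block outputs that were already rounded to bf16 drifts away from keeping everything in fp32.

import torch

torch.manual_seed(0)
blocks = [torch.randn(4096) for _ in range(8)]  # stand-ins for per-block outputs

exact = sum(blocks)                                        # accumulate in fp32
lossy = sum(b.to(torch.bfloat16).float() for b in blocks)  # each block rounded to bf16 first

print((exact - lossy).abs().max())  # small but nonzero arithmetic error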

TODOs

  • Implement ring_flash_attn_varlen_qkvpacked_func
  • Implement zigzag_ring_flash_attn_qkvpacked_func issue#2
  • Implement stripe_flash_attn_qkvpacked_func
  • Implement zigzag_ring_flash_attn_varlen_qkvpacked_func
  • Implement *_kvpacked_func and *_func variants for all APIs
  • Implement llama3_flash_attn_varlen_func (supersedes the earlier item to optimize *_varlen_func)
  • Implement adapter for huggingface model (supersedes the earlier item to add an example to train llama)
  • Implement zigzag_llama3_flash_attn_varlen_func

Test

torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py

Benchmark

torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py

Known Limits

  • dropout is not supported at the moment, because it's hard to save all the rng_states.
  • window_size is not supported, because it would be really tricky to implement a varlen version with window_size.

