
A package for long context attention


YunChang: A Unified Sequence Parallel (USP) Attention for Long Context LLM Model Training and Inference.

[Tech Report] USP: A Unified Sequence Parallelism Approach for Long Context Generative AI

This repo provides a sequence parallel approach that combines the strengths of two popular distributed attention methods, DeepSpeed-Ulysses-Attention and Ring-Attention, delivering greater generality and better performance. The project is built on zhuzilin/ring-flash-attention and draws on DeepSpeed-Ulysses.

USP has been adopted in NVIDIA/TransformerEngine as AttnFuncWithCPAndKVP2P, which you can use through the attn_forward_func_with_cp API.

Why not apply Ulysses and Ring Attention Individually?

  • Ulysses is sensitive to the number of attention heads: its parallel degree cannot exceed the number of heads. Consequently, it is ill-suited to GQA (Grouped Query Attention) and MQA (Multi-Query Attention), which have few key/value heads; for instance, Ulysses cannot operate effectively with a single head. In addition, since Tensor Parallelism also partitions along the head dimension, combining Ulysses with TP can be challenging.

  • Ring-Attention is less efficient than Ulysses in both computation and communication. Ring-Attention splits the Query, Key, and Value (QKV) tensors into smaller blocks, which can reduce FlashAttention's efficiency. Even with communication and computation fully overlapped, its total execution time lags behind that of Ulysses. Furthermore, Ring-Attention uses asynchronous peer-to-peer communication, which has lower bandwidth utilization than collective communication and risks communication deadlocks in large-scale deployments.
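To make the head constraint concrete, here is a minimal shape-bookkeeping sketch of the Ulysses All-to-All. Before the exchange, each rank holds a sequence shard with all heads; afterwards, it holds the full sequence for a subset of heads. The function name is hypothetical and not part of yunchang, and the sketch assumes heads must divide evenly across ranks (which implies the Ulysses degree cannot exceed the head count).

```python
def ulysses_all_to_all_shapes(seq_len, num_heads, head_dim, ulysses_degree):
    """Per-rank (seq, heads, head_dim) shapes before and after the Ulysses All-to-All."""
    if seq_len % ulysses_degree != 0:
        raise ValueError("seq_len must be divisible by ulysses_degree")
    if num_heads % ulysses_degree != 0:
        # e.g. MQA has a single K/V head, which cannot be split across ranks
        raise ValueError(
            f"num_heads={num_heads} is not divisible by ulysses_degree={ulysses_degree}"
        )
    before = (seq_len // ulysses_degree, num_heads, head_dim)  # sequence shard, all heads
    after = (seq_len, num_heads // ulysses_degree, head_dim)   # full sequence, head shard
    return before, after


if __name__ == "__main__":
    print(ulysses_all_to_all_shapes(8192, 32, 128, 8))  # ((1024, 32, 128), (8192, 4, 128))
    try:
        ulysses_all_to_all_shapes(8192, 1, 128, 8)  # a single head, as in MQA K/V
    except ValueError as err:
        print(err)
```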

LongContextAttention, also known as Unified Sequence Parallelism and Hybrid Sequence Parallelism

LongContextAttention is a unified sequence parallel approach, also known as hybrid sequence parallelism, that combines DeepSpeed-Ulysses-Attention and Ring-Attention, thereby addressing the limitations of both methods.
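As a rough illustration of how the hybrid scheme arranges GPUs, the sketch below splits one sequence-parallel group of ulysses_degree * ring_degree ranks into Ulysses groups (which run All-to-All) and Ring groups (which pass K/V blocks peer to peer). The helper is hypothetical, and the layout, with each Ulysses group on adjacent ranks, is an assumption made for illustration. Such a layout keeps the bandwidth-hungry All-to-All on neighboring GPUs (typically inside one NVLink node), while ring traffic crosses the slower links.

```python
def usp_groups(world_size, ulysses_degree, ring_degree):
    """Hypothetical split of ranks 0..world_size-1 into Ulysses and Ring groups.

    Assumes the whole world forms one sequence-parallel group, each Ulysses
    group is a run of adjacent ranks, and each Ring group takes one rank
    from every Ulysses group.
    """
    if ulysses_degree * ring_degree != world_size:
        raise ValueError("ulysses_degree * ring_degree must equal world_size")
    ulysses_groups = [
        list(range(r * ulysses_degree, (r + 1) * ulysses_degree))
        for r in range(ring_degree)
    ]
    ring_groups = [
        list(range(u, world_size, ulysses_degree)) for u in range(ulysses_degree)
    ]
    return ulysses_groups, ring_groups


if __name__ == "__main__":
    uly, ring = usp_groups(8, 2, 4)
    print(uly)   # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(ring)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```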

1. Usage

Please refer to test/test_hybrid_qkvpacked_attn.py and test/test_hybrid_attn.py for usage.

In short, we take the zigzag ring attention implementation as an example:

  1. Call set_seq_parallel_pg to set up the process groups.
  2. Extract local tensors with zigzag_extract_local. The input tokens or input tensors must be reordered for load-balanced ring attention.
  3. Apply LongContextAttention(ring_impl_type="zigzag") as a drop-in replacement for your attention implementation.
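import torch.distributed as dist

from yunchang import (
    AsyncLongContextAttention,
    LongContextAttention,
    set_seq_parallel_pg,
    EXTRACT_FUNC_DICT
)
from yunchang.kernels import FlashAttentionImpl

# launched with torchrun, e.g. torchrun --nproc_per_node 8 your_script.py
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

sp_ulysses_degree = 2
sp_ring_degree = 4

# with these degrees, the sequence-parallel group spans 2 * 4 = 8 GPUs (world_size = 8)
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)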

# attn_type could be FA, FA3, TORCH.
longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.FA)

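# Q, K, V: global (full-sequence) tensors in (batch, seqlen, nheads, head_dim) layout,
# identical on every rank.
# Extract this rank's local shard of the global Q (and likewise K and V).
local_q = EXTRACT_FUNC_DICT["zigzag"](
        Q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
    ).detach().clone()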
...


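# example settings; adjust to your model
dropout_p = 0.0
window_size = (-1, -1)  # (-1, -1) means no sliding window
alibi_slopes = None
deterministic = False

local_out = longctx_attn(
        local_q,
        local_k,
        local_v,
        dropout_p=dropout_p,
        causal=True, # zigzag and stripe are load-balancing strategies for causal=True
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )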
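For intuition about step 2, the following pure-Python sketch shows the zigzag layout used by ring-only zigzag attention (as in zhuzilin/ring-flash-attention): the sequence is cut into 2 * ring_degree equal chunks, and rank r keeps chunk r plus its mirror chunk 2 * ring_degree - 1 - r. It also shows the inverse mapping you would need to put gathered outputs back in global order. The function names are hypothetical; yunchang's EXTRACT_FUNC_DICT["zigzag"] operates on tensors and also accounts for the Ulysses degree, which this sketch does not model.

```python
def zigzag_shard(tokens, ring_degree, rank):
    """Tokens held by `rank` under a ring-only zigzag layout (illustrative)."""
    n_chunks = 2 * ring_degree
    if len(tokens) % n_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * ring_degree")
    size = len(tokens) // n_chunks
    chunks = [tokens[i * size:(i + 1) * size] for i in range(n_chunks)]
    # rank r pairs an early chunk with its mirror late chunk
    return chunks[rank] + chunks[n_chunks - 1 - rank]


def zigzag_unshard(local_outputs, ring_degree):
    """Inverse of zigzag_shard: rebuild global order from every rank's shard."""
    n_chunks = 2 * ring_degree
    chunks = [None] * n_chunks
    for rank, local in enumerate(local_outputs):
        half = len(local) // 2
        chunks[rank] = local[:half]
        chunks[n_chunks - 1 - rank] = local[half:]
    return [t for chunk in chunks for t in chunk]


if __name__ == "__main__":
    seq = list(range(16))
    shards = [zigzag_shard(seq, 4, r) for r in range(4)]
    print(shards[0])  # [0, 1, 14, 15]
    print(shards[3])  # [6, 7, 8, 9]
    assert zigzag_unshard(shards, 4) == seq
```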

2. Installation

Option 1: pip install from PyPI.

For flash_attn >= 2.6.0:

pip install yunchang

For flash_attn < 2.6.0:

pip install yunchang==0.2
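If you are unsure which pin applies, a small helper like the one below can read the installed flash_attn version and map it to the rule above. It is a hypothetical convenience script, not part of yunchang, and it only compares the major and minor version numbers.

```python
from importlib.metadata import PackageNotFoundError, version


def _major_minor(ver):
    # "2.6.3+cu121" -> (2, 6); patch numbers and local suffixes are ignored
    parts = ver.split("+")[0].split(".")
    return int(parts[0]), int(parts[1])


def yunchang_requirement(flash_attn_version):
    """Map a flash_attn version string to the yunchang install spec above."""
    if _major_minor(flash_attn_version) >= (2, 6):
        return "yunchang"
    return "yunchang==0.2"


if __name__ == "__main__":
    try:
        fa = version("flash_attn")
        print(f"flash_attn {fa}: pip install {yunchang_requirement(fa)}")
    except PackageNotFoundError:
        print("flash_attn is not installed")
```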

To use FlashAttention V3: since FA V3 is in beta, you need to install it from source.

Follow the FlashAttention beta-release to install V3 for NVIDIA Hopper GPUs.

We used the Nov 10, 2024 commit b443207c1fc4c98e4532aad4e88cfee1d590d996.

Option 2: build from a local checkout.

pip install .

For AMD GPUs, see install_amd.md.

3. Test

torchrun --nproc_per_node 8 test/test_hybrid_qkvpacked_attn.py

4. Verified in Megatron-LM

The loss curves for Data Parallel (DP) and Unified Sequence Parallel (ulysses=2 + ring=2) are closely aligned, as illustrated in the figure. This alignment confirms the correctness of unified sequence parallelism.

When using load-balanced Ring Attention with a causal mask, you must reorder the input tensors (Query, Key, and Value) using the functions in EXTRACT_FUNC_DICT.

In Megatron-LM, you can reorder the input tokens before feeding them into the model and apply the same reordering to the RoPE parameters. For detailed instructions, please refer to our paper.
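Why the reordering matters can be seen with a toy count of causal-attention work. Under a causal mask, the query at global position p attends to p + 1 keys, so a contiguous split gives the last rank far more work than the first, while a zigzag split evens it out. The sketch is illustrative only and its helper names are hypothetical. The per-rank position lists it builds are also the original positions each rank's tokens keep, which is why the same reordering has to be applied to RoPE.

```python
def causal_work(positions):
    """Unmasked (query, key) pairs for queries at these global positions."""
    return sum(p + 1 for p in positions)


def contiguous_split(seq_len, world):
    """Rank r gets one contiguous block of positions (no load balancing)."""
    n = seq_len // world
    return [list(range(r * n, (r + 1) * n)) for r in range(world)]


def zigzag_split(seq_len, world):
    """Rank r gets chunk r and its mirror chunk 2 * world - 1 - r.

    The returned positions double as each rank's RoPE position ids.
    """
    n = seq_len // (2 * world)
    return [
        list(range(r * n, (r + 1) * n))
        + list(range((2 * world - 1 - r) * n, (2 * world - r) * n))
        for r in range(world)
    ]


if __name__ == "__main__":
    print([causal_work(p) for p in contiguous_split(16, 4)])  # [10, 26, 42, 58]
    print([causal_work(p) for p in zigzag_split(16, 4)])      # [34, 34, 34, 34]
```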

For an example implementation, see this PR, which integrates USP into BAAI's Megatron-LM framework.

6. Benchmark

bash ./scripts/run_qkvpack_compare.sh

On an 8xA100 NVLink machine, the benchmark results are as follows:

On an 8xL20 PCIe machine and a 4xA100 PCIe machine, the benchmark results are as follows:

Some Conclusions:

  1. When there are enough attention heads, Ulysses outperforms Ring-Attention. Ulysses's All-to-All communication is highly efficient within a single machine, with a very low overhead ratio. In contrast, Ring-Attention splits the computation into smaller blocks, which increases total computation time; even with communication fully overlapped, it is slower than Ulysses.

  2. The QKV-packed version (LongContextAttentionQKVPacked) outperforms the unpacked version (LongContextAttention), and the gap widens as the sequence length decreases. MQA and GQA can only use the unpacked version, since their query and key/value tensors have different numbers of heads.

  3. Among the variants of the Ring-Attention implementation, zigzag and stripe perform better than basic. Typically, zigzag is slightly better than stripe, but as the sequence length increases, the difference between zigzag and stripe becomes less noticeable. It is worth noting that both zigzag and stripe have specific layout requirements for the sequence dimension.

  4. Hybrid parallelism works well on heterogeneous networks. For example, on an 8-GPU L20 setup, the best performance is achieved with ulysses_degree=2 and ring_degree=4.
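The efficiency gap in conclusion 1 can be partly explained with a back-of-the-envelope count of elements each rank sends in one forward attention call. This is a simplified model of our own, not a measurement: Ulysses performs four All-to-Alls (Q, K, V in and the output back), each rank keeping 1/P of its shard, while Ring-Attention forwards one K block and one V block in each of P - 1 steps. It ignores the backward pass, bandwidth differences between collective and P2P operations, and compute/communication overlap. The optional kv_hidden argument (hypothetical) shows how GQA shrinks the K/V traffic.

```python
def ulysses_send_elems(seq_len, hidden, p, kv_hidden=None):
    """Elements one rank sends per forward attention call under Ulysses (degree p)."""
    kv_hidden = hidden if kv_hidden is None else kv_hidden
    shard = seq_len // p  # tokens held by each rank
    # Q and the output use `hidden` columns; K and V use `kv_hidden` columns.
    local = shard * (2 * hidden + 2 * kv_hidden)
    # In an All-to-All each rank keeps 1/p of its data and sends the rest.
    return local * (p - 1) // p


def ring_send_elems(seq_len, hidden, p, kv_hidden=None):
    """Elements one rank sends per forward attention call under Ring-Attention (degree p)."""
    kv_hidden = hidden if kv_hidden is None else kv_hidden
    shard = seq_len // p
    # Each of the p - 1 ring steps forwards one K block and one V block.
    return 2 * shard * kv_hidden * (p - 1)


if __name__ == "__main__":
    n, d, p = 32768, 4096, 8
    u, r = ulysses_send_elems(n, d, p), ring_send_elems(n, d, p)
    print(u, r, r / u)  # 58720256 234881024 4.0
    # GQA with 8x fewer K/V heads (hypothetical kv_hidden = d // 8)
    print(ulysses_send_elems(n, d, p, d // 8), ring_send_elems(n, d, p, d // 8))
```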

7. Best Practices for 4D Parallelism

We analyze the impact of introducing Sequence Parallelism alongside Data/ZeRO/Tensor/Pipeline Parallelism in a technical report, available here.

Some best practices are listed here:

  1. We suggest using Unified-SP in place of SP-Ring and SP-Ulysses, as it encompasses the capabilities of both while offering additional benefits.

  2. DP (data parallelism) vs. SP: we suggest prioritizing DP over SP where possible. Consider SP only when the batch size is too small to partition further.

  3. When using SP, always combine it with ZeRO-1/2.

  4. Unified-SP has lower communication cost than Tensor Parallelism with Megatron-LM sequence parallelism (TP-sp), so you can use Unified-SP in place of TP for better speed. However, switching from TP to SP+ZeRO2 currently cannot increase the trainable sequence length; SP+ZeRO3 can train sequence lengths similar to TP-sp. SP may also have a communication-cost advantage over TP when GQA is used, since GQA reduces the communication cost of SP but not that of TP.

  5. SP can scale to a higher parallel degree than TP, allowing long sequences to be trained across more devices; when the number of heads is limited, this may require a large ring degree. TP, by contrast, cannot be scaled to a high parallel degree.
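Practices 2 and 3 trade activation memory (split by SP) against model-state memory (split by ZeRO). A rough per-GPU estimate of model states can be sketched with the ZeRO paper's accounting for mixed-precision Adam: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and 12 bytes of fp32 optimizer state per parameter. Activations, buffers, and fragmentation are ignored. As an illustrative assumption, because SP ranks hold identical weights, the sketch lets the ZeRO sharding group span DP × SP ranks; check how your framework actually forms that group. The function name is hypothetical.

```python
def model_state_bytes_per_gpu(num_params, zero_stage, shard_group_size):
    """Approximate per-GPU bytes for parameters, gradients and Adam states."""
    params = 2 * num_params  # fp16/bf16 weights
    grads = 2 * num_params   # fp16/bf16 gradients
    optim = 12 * num_params  # fp32 master weights, momentum, variance
    n = shard_group_size     # assumed here to be dp_degree * sp_degree
    if zero_stage == 0:
        return params + grads + optim
    if zero_stage == 1:  # shard optimizer states
        return params + grads + optim / n
    if zero_stage == 2:  # also shard gradients
        return params + (grads + optim) / n
    if zero_stage == 3:  # also shard parameters
        return (params + grads + optim) / n
    raise ValueError("zero_stage must be 0, 1, 2 or 3")


if __name__ == "__main__":
    dp_degree, sp_degree = 2, 4  # hypothetical layout on 8 GPUs
    for stage in range(4):
        b = model_state_bytes_per_gpu(7_000_000_000, stage, dp_degree * sp_degree)
        print(f"ZeRO-{stage}: {b / 1e9:.2f} GB")
```

For a 7B-parameter model on 2 × 4 GPUs, this prints 112.00, 38.50, 26.25 and 14.00 GB for ZeRO-0 through ZeRO-3.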

8. Projects Using USP

I am honored that this repository has contributed to the following projects:

  1. xdit-project/xDiT
  2. NVlabs/VILA
  3. feifeibear/Odysseus-Transformer
  4. Ascend/AscendSpeed
  5. jzhang38/EasyContext
  6. FlagOpen/FlagScale
  7. zhiyuanhubj/LongRecipe
  8. NVIDIA/TransformerEngine
  9. xdit-project/mochi-xdit

9. Cite Us

@article{fang2024unified,
  title={USP: A Unified Sequence Parallelism Approach for Long Context Generative AI},
  author={Fang, Jiarui and Zhao, Shangchun},
  journal={arXiv preprint arXiv:2405.07719},
  year={2024}
}
