
Fast Triton-based implementations for RWKV


RWKV-FLA


This repo aims to provide Triton kernels for RWKV models. RWKV is a network architecture that combines the advantages of Transformers and RNNs and can be used for a variety of natural language processing tasks; it is also a state-of-the-art RNN model.

This project implements multi-level state-chain differentiation for RWKV6 and efficient differentiation with respect to all input parameters, while maintaining high numerical precision (in both bf16 and fp32). Pure fp16 variants such as RWKV x060c are currently not covered.

Some benchmarks (chunk_rwkv6 from fla vs. the CUDA kernel)

Since the project is under active development, the measured times may differ on your setup.

fused_recurrent_rwkv6 will be much slower!

| Test Case | Implementation | Forward Time | Backward Time |
| :--- | :--- | :--- | :--- |
| Test Case 1: B=8, T=4096, C=4096, HEAD_SIZE=64 | CUDA BF16 | 9.69 ms | 46.41 ms |
| | FLA BF16 | 13.06 ms | 40.79 ms |
| Test Case 2: B=32, T=4096, C=4096, HEAD_SIZE=64 | CUDA BF16 | 32.80 ms | 148.05 ms |
| | FLA BF16 | 50.17 ms | 162.42 ms |
| Test Case 3: B=8, T=4096, C=4096, HEAD_SIZE=128 | CUDA BF16 | 12.01 ms | 65.68 ms |
| | FLA BF16 | 14.18 ms | 51.36 ms |
| Test Case 4: B=8, T=4096, C=4096, HEAD_SIZE=256 | CUDA BF16 | 40.82 ms | 225.59 ms |
| | FLA BF16 | 19.34 ms | 72.03 ms |
| Test Case 5: B=16, T=4096, C=4096, HEAD_SIZE=128 | CUDA BF16 | 20.56 ms | 109.76 ms |
| | FLA BF16 | 27.72 ms | 102.35 ms |
| Test Case 6: B=16, T=4096, C=4096, HEAD_SIZE=256 | CUDA BF16 | 61.54 ms | 344.85 ms |
| | FLA BF16 | 38.24 ms | 144.12 ms |
import torch
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6, native_recurrent_rwkv6

@torch.compile
def run_fla_kernel(B, T, C, H, r, k, v, w, u, s):
    # reshape (B, T, C) inputs to (B, H, T, head_size)
    r = r.view(B, T, H, -1).transpose(1, 2)
    k = k.view(B, T, H, -1).transpose(1, 2)
    v = v.view(B, T, H, -1).transpose(1, 2)
    # u (the bonus term) can be 3D (B, H, head_size) or just 2D (H, head_size) to save VRAM
    # pass the decay as the log-decay -exp(w) consumed by the kernel
    w = -torch.exp(w.view(B, T, H, -1).transpose(1, 2))
    o, final_state = chunk_rwkv6(r, k, v, w, u=u, scale=1.0, initial_state=s, output_final_state=True)
    # back to (B, T, C); final_state can be fed into the next segment
    return o.transpose(1, 2).reshape(B, T, C), final_state
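
For reference, here is a minimal smoke-test sketch of calling run_fla_kernel; the small shapes are hypothetical (not the benchmark settings above), and it chains two segments through the returned state to exercise the state-chain differentiation described earlier.

import torch

# Hypothetical shapes for a quick smoke test, not the benchmark settings.
B, T, C, H = 2, 128, 512, 8          # batch, sequence length, channels, heads
D = C // H                           # head size
device, dtype = 'cuda', torch.bfloat16

def rand(*shape):
    return torch.randn(*shape, device=device, dtype=dtype, requires_grad=True)

# the same random tensors are reused for both segments just to keep the sketch short
r, k, v, w = rand(B, T, C), rand(B, T, C), rand(B, T, C), rand(B, T, C)
u = rand(H, D)        # bonus term; the 2D (H, head_size) layout saves VRAM
s0 = None             # no initial state for the first segment

# Run two consecutive segments, feeding the final state of the first
# into the second, then backpropagate through the chained states.
o1, s1 = run_fla_kernel(B, T, C, H, r, k, v, w, u, s0)
o2, s2 = run_fla_kernel(B, T, C, H, r, k, v, w, u, s1)
(o1.float().mean() + o2.float().mean()).backward()
print(o2.shape, s2.shape, w.grad is not None)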

This repo aims to provide a collection of efficient Triton-based implementations for state-of-the-art linear attention models. Pull requests are welcome!


Installation

The following requirements should be satisfied: a working PyTorch installation and a recent Triton (installed/upgraded by the command below).

As fla is under active development, you should always check for the latest version:

pip install --upgrade rwkv-fla triton

Or you can install it with pip install rwkv-fla[cuda], pip install rwkv-fla[xpu], or pip install rwkv-fla[rocm], depending on your hardware.

If you need the latest fla ops/modules and plan further exploration, an alternative is to install the package from source:

pip install -U git+https://github.com/TorchRWKV/flash-linear-attention

or

pip install -U git+https://gitee.com/uniartisan2018/flash-linear-attention

or manage fla with submodules

git submodule add https://github.com/TorchRWKV/flash-linear-attention.git 3rdparty/rwkv-fla
ln -s 3rdparty/rwkv-fla/fla fla

[!CAUTION] If you're not working with Triton v2.2 or its nightly release, it's important to be aware of potential issues with the FusedChunk implementation, detailed in this issue. You can run the test python tests/test_fused_chunk.py to check if your version is affected by similar compiler problems. While we offer some fixes for Triton<=2.1, be aware that these may result in reduced performance.

For Triton 2.2 as well as earlier versions (up to 2.1), you can reliably use the Chunk version (with hidden states materialized into HBM). After careful optimization, this version generally delivers high performance in most scenarios.
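
A quick way to check which PyTorch and Triton builds your environment will actually use (a plain version check, not part of the repo's test suite):

import torch
import triton

# Print the versions fla will pick up; compare the Triton version against
# the caveats above (Triton >= 2.2 recommended, fixes exist for <= 2.1).
print('torch :', torch.__version__)
print('triton:', triton.__version__)
print('cuda available:', torch.cuda.is_available())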

Acknowledgments

The rwkv-fla project is a fork of the fla project. We extend our sincere gratitude to the original maintainers for their tremendous efforts and contributions. This project builds upon the work described in:

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}

Their innovative work and expertise laid the foundation for the development of rwkv-fla.

Models

| Date | Model | Title | Paper | Code | FLA impl |
| :--- | :--- | :--- | :--- | :--- | :--- |
| 2023-07 | RetNet (@MSRA@THU) | Retentive network: a successor to transformer for large language models | [arxiv] | [official] [RetNet] | code |
| 2023-12 | GLA (@MIT@IBM) | Gated Linear Attention Transformers with Hardware-Efficient Training | [arxiv] | [official] | code |
| 2023-12 | Based (@Stanford@Hazyresearch) | An Educational and Effective Sequence Mixer | [blog] | [official] | code |
| 2024-01 | Rebased | Linear Transformers with Learnable Kernel Functions are Better In-Context Models | [arxiv] | [official] | code |
| 2021-02 | Delta Net | Linear Transformers Are Secretly Fast Weight Programmers | [arxiv] | [official] | code |
| 2023-09 | Hedgehog (@HazyResearch) | The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry | openreview | | code |
| 2023-10 | PolySketchFormer (@CMU@Google) | Fast Transformers via Sketching Polynomial Kernels | arxiv | | TODO |
| 2023-07 | TransnormerLLM (@Shanghai AI Lab) | A Faster and Better Large Language Model with Improved TransNormer | openreview arxiv | [official] [Lightning2] | TODO |
| 2023-05 | RWKV-v4 (@BlinkDL) | Reinventing RNNs for the Transformer Era | arxiv | [official] | TODO |
| 2023-10 | GateLoop | Fully Data-Controlled Linear Recurrence for Sequence Modeling | openreview arxiv | [official] [jax] | TODO |
| 2021-10 | ABC (@UW) | Attention with Bounded-memory Control | arxiv | | code |
| 2023-09 | VQ-transformer | Linear-Time Transformers via Vector Quantization | arxiv | [official] | TODO |
| 2023-09 | HGRN | Hierarchically Gated Recurrent Neural Network for Sequence Modeling | openreview | [official] | code |
| 2024-04 | HGRN2 | HGRN2: Gated Linear RNNs with State Expansion | arxiv | [official] | code |
| 2024-04 | RWKV6 | Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence | arxiv | [official] | code |
| 2024-06 | Samba | Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling | arxiv | [official] | code |
| 2024-05 | Mamba2 | Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality | arxiv | [official] | code |

Usage

Token Mixing

We provide "token mixing" linear attention layers in fla.layers for you to use. You can replace the standard multihead attention layer in your model with other linear attention layers. Example usage is as follows:

>>> import torch
>>> from fla.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])
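
As a sketch of what "replacing the standard multi-head attention layer" looks like in practice, the block below wires MultiScaleRetention into a generic pre-norm Transformer block; the block structure itself is illustrative and not part of fla.

import torch
import torch.nn as nn
from fla.layers import MultiScaleRetention

class RetentionBlock(nn.Module):
    """A generic pre-norm block with the attention sub-layer swapped out."""
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads)
        self.mlp_norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x):
        y, *_ = self.attn(self.attn_norm(x))  # the layer returns (output, *extras)
        x = x + y
        return x + self.mlp(self.mlp_norm(x))

block = RetentionBlock(hidden_size=1024, num_heads=4).to('cuda', torch.bfloat16)
x = torch.randn(2, 2048, 1024, device='cuda', dtype=torch.bfloat16)
print(block(x).shape)  # torch.Size([2, 2048, 1024])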

We provide implementations of models that are compatible with the 🤗 Transformers library. Here's an example of how to initialize a GLA model from the default configs in fla:

>>> from fla.models import GLAConfig
>>> from transformers import AutoModel
>>> config = GLAConfig()
>>> config
GLAConfig {
  "attn_mode": "fused_chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "max_position_embeddings": 2048,
  "model_type": "gla",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "rms_norm_eps": 1e-06,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.39.1",
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_short_conv": false,
  "vocab_size": 32000
}

>>> AutoModel.from_config(config)
GLAModel(
  (embed_tokens): Embedding(32000, 2048)
  (layers): ModuleList(
    (0-23): 24 x GLABlock(
      (attn_norm): RMSNorm()
      (attn): GatedLinearAttention(
        (gate_fn): SiLU()
        (q_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (g_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (gk_proj): Sequential(
          (0): Linear(in_features=2048, out_features=16, bias=False)
          (1): Linear(in_features=16, out_features=1024, bias=True)
        )
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (g_norm_swish_gate): FusedRMSNormSwishGate()
      )
      (mlp_norm): RMSNorm()
      (mlp): GLAMLP(
        (gate_proj): Linear(in_features=2048, out_features=11264, bias=False)
        (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
    )
  )
  (norm): RMSNorm()
)
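
The config fields shown above can be overridden when constructing GLAConfig, which is handy for building smaller variants; the particular values below are purely illustrative.

from transformers import AutoModel
from fla.models import GLAConfig

# Override a few of the fields printed above; values are illustrative.
config = GLAConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_heads=4,
    vocab_size=32000,
)
model = AutoModel.from_config(config)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')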

Generation

Once a model has been pretrained, it can generate text through the 🤗 text generation APIs. Here is a generation example:

>>> import fla
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> name = 'fla-hub/gla-1.3B-100B'
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> model = AutoModelForCausalLM.from_pretrained(name).cuda()
>>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
>>> outputs = model.generate(input_ids, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

We also provide a simple script here for benchmarking the generation speed. Simply run it by:

$ python -m benchmarks.benchmark_generation \
  --path 'fla-hub/gla-1.3B-100B' \
  --repetition_penalty 2. \
  --prompt="Hello everyone, I'm Songlin Yang"

Prompt:
Hello everyone, I'm Songlin Yang
Generated:
Hello everyone, I'm Songlin Yang.
I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have

Prompt length: 10, generation length: 64
Total prompt processing + decoding time: 4593ms

All of the pretrained models currently available can be found in fla-hub.

>>> from huggingface_hub import list_models
>>> for model in list_models(author='fla-hub'): print(model.id)

Evaluations

The lm-evaluation-harness library allows you to easily perform (zero-shot) model evaluations. Follow the steps below to use this library:

  1. Install lm_eval following their instructions.

  2. Run evaluation with:

$ MODEL='fla-hub/gla-1.3B-100B'
$ python -m evals.harness --model hf \
    --model_args pretrained=$MODEL,dtype=bfloat16 \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config

We've made fla compatible with hf-style evaluations; you can call evals.harness to run them. The command above reproduces the task results reported in the GLA paper.

[!Tip] If you are using lm-evaluation-harness as an external library and can't find (almost) any available tasks, simply run the following before calling lm_eval.evaluate() or lm_eval.simple_evaluate() to load the library's stock tasks:

>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()
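
For programmatic use, the sketch below calls the harness as a library rather than through evals.harness. The HFLM and simple_evaluate entry points and their arguments are assumptions about lm-evaluation-harness and may differ across versions, so adjust to your installation.

import fla  # registers fla models with 🤗 Transformers
import lm_eval
from lm_eval.models.huggingface import HFLM

# Library-style evaluation; entry points assumed from lm-evaluation-harness,
# check against your installed version.
lm = HFLM(pretrained='fla-hub/gla-1.3B-100B', dtype='bfloat16', batch_size=64)
results = lm_eval.simple_evaluate(model=lm, tasks=['piqa', 'hellaswag'], num_fewshot=0)
print(results['results'])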

Benchmarks

We compared our Triton-based RetNet implementation with CUDA-based FlashAttention2, using a batch size of 8, 32 heads, and a head dimension of 128, across different sequence lengths. These tests were conducted on a single A100 80GB GPU; the results are shown below.

# you might have to first install `fla` to enable its import via `pip install -e .`
$ python benchmark_retention.py
Performance:
   seq_len  fused_chunk_fwd  chunk_fwd  parallel_fwd  fused_chunk_fwdbwd  chunk_fwdbwd  parallel_fwdbwd  flash_fwd  flash_fwdbwd
0    128.0         0.093184   0.185344      0.067584            1.009664      1.591296         1.044480   0.041984      0.282624
1    256.0         0.165888   0.219136      0.126976            1.024000      1.596928         1.073152   0.074752      0.413696
2    512.0         0.308224   0.397312      0.265216            1.550336      1.603584         1.301504   0.156672      0.883712
3   1024.0         0.603136   0.747520      0.706560            3.044864      3.089408         3.529728   0.467968      2.342912
4   2048.0         1.191424   1.403904      2.141184            6.010880      6.059008        11.009024   1.612800      7.135232
5   4096.0         2.377728   2.755072      7.392256           11.932672     11.938816        37.792770   5.997568     24.435200
6   8192.0         4.750336   5.491712     26.402817           23.759359     23.952385       141.014023  22.682114     90.619904
7  16384.0         9.591296  10.870784    101.262337           47.666176     48.745472       539.853821  91.346947    346.318848


Different forms of linear attention

Please refer to Section 2.3 of the GLA paper for hardware considerations behind the different forms of linear attention.

  • Parallel: Self-attention-styled computation in $O(L^2)$ time with sequence parallelism.
  • FusedRecurrent: Recurrent computation in $O(L)$ time. Hidden states are computed on the fly in shared memory without any materialization to global memory (see Algorithm 1 of this paper for more details!). This saves a lot of I/O cost and should be a strong baseline for speed comparison.
  • FusedChunk: Chunkwise computation in $O(LC)$ time, where $C$ is the chunk size (see the naive sketch after this list). Hidden states are computed on the fly without any materialization to global memory, like FusedRecurrent. This version is usually better than FusedRecurrent because tensor cores can be used for the sequence-level "reduction", whereas FusedRecurrent cannot use tensor cores at all. Note that there is no sequence-level parallelism in this implementation, so it is not suitable for very small batch sizes. It should be more memory efficient than ParallelChunk.
  • ParallelChunk: Chunkwise computation with sequence parallelism. Hidden states need to be materialized to global memory for each chunk. $C$ needs to be set properly to achieve good performance: when $C$ is small, there are too many hidden states to load/store to global memory; when $C$ is too large, the FLOPs are high. Recommended values for $C$ are 64, 128, or 256.
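
To make the chunkwise idea concrete, below is a naive PyTorch reference for the chunkwise form of plain (ungated, undecayed) linear attention. It is for intuition only and is not one of the Triton kernels in this repo.

import torch

def naive_chunk_linear_attn(q, k, v, chunk_size=64):
    # q, k, v: (B, H, T, D); T must be divisible by chunk_size.
    # Unnormalized linear attention: o_t = q_t @ sum_{s<=t} k_s^T v_s,
    # computed chunk by chunk in O(L*C) time.
    B, H, T, D = q.shape
    C = chunk_size
    q, k, v = (x.reshape(B, H, T // C, C, D) for x in (q, k, v))
    causal = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device))
    S = torch.zeros(B, H, D, D, device=q.device, dtype=q.dtype)  # running sum of k^T v
    out = []
    for i in range(T // C):
        qi, ki, vi = q[:, :, i], k[:, :, i], v[:, :, i]
        o_inter = qi @ S                                     # contribution of past chunks
        attn = (qi @ ki.transpose(-1, -2)).masked_fill(~causal, 0)
        o_intra = attn @ vi                                  # causal intra-chunk part
        out.append(o_inter + o_intra)
        S = S + ki.transpose(-1, -2) @ vi                    # fold this chunk into the state
    return torch.cat(out, dim=2)                             # back to (B, H, T, D)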

Citation

If you find this repo useful, please consider citing our works:

@article{yang2024delta,
  title   = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length}, 
  author  = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
  journal = {arXiv preprint arXiv:2406.06484},
  year    = {2024},
}

@article{yang2023gated,
  title   = {Gated Linear Attention Transformers with Hardware-Efficient Training},
  author  = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
  journal = {arXiv preprint arXiv:2312.06635},
  year    = {2023}
}

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}

@article{zhang2024gsa,
  title   = {Gated Slot Attention for Efficient Linear-Time Sequence Modeling}, 
  author  = {Yu Zhang and Songlin Yang and Ruijie Zhu and Yue Zhang and Leyang Cui and Yiqiao Wang and Bolun Wang and Freda Shi and Bailin Wang and Wei Bi and Peng Zhou and Guohong Fu},
  journal = {arXiv preprint arXiv:2409.07146},
  year    = {2024},
}
