FlashRNN: Optimizing Traditional RNNs on Modern Hardware

Korbinian Pöppel, Maximilian Beck, Sepp Hochreiter

Intro

FlashRNN implements traditional RNNs such as LSTMs, GRUs and Elman networks, as well as the recent sLSTM architecture, in CUDA and Triton. In contrast to common modern sequence models, these RNNs have state-tracking capabilities (Merrill et al., 2024). All of them share the same basic recurrent structure, with inputs $\mathbf{x}^{(n)}_t$, biases $\mathbf{b}^{(n)}$, recurrent matrices $\mathbf{R}^{(n)}$, gate pre-activations $\mathbf{g}^{(n)}_t$, states $\mathbf{s}^{(m)}_t$ and a pointwise function $\mathcal{P}^{(m)}$:

$$
\begin{aligned}
\mathbf{g}^{(n)}_{t} &= \mathbf{R}^{(n)} \, \mathbf{s}^{(0)}_{t-1} + \mathbf{x}^{(n)}_{t} + \mathbf{b}^{(n)} \\
\mathbf{s}^{(m)}_t &= \mathcal{P}^{(m)}\left( \left( \mathbf{s}^{(m')}_{t-1} \right)_{m'} , \left( \mathbf{g}^{(n)}_{t} \right)_{n} \right)
\end{aligned}
$$
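To make the general form concrete, the standard LSTM fits it with four gate pre-activations (named i, f, z, o here, an assumed naming) and two states, $\mathbf{s}^{(0)} = \mathbf{h}$ and $\mathbf{s}^{(1)} = \mathbf{c}$; only $\mathbf{s}^{(0)}$ enters the recurrent product. The following pure-Python sketch of one step for a single head assumes the textbook LSTM cell without peephole connections. It is not FlashRNN's kernel, the gate order FlashRNN uses internally is not specified here, and the helper names (lstm_step, matvec) are hypothetical.

```python
import math

GATES = ("i", "f", "z", "o")  # assumed gate names: input, forget, cell input, output


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(M, v):
    # Dense matrix-vector product for small list-of-lists matrices.
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def lstm_step(R, x_t, b, states):
    """One step: g^(n) = R^(n) s^(0)_{t-1} + x^(n)_t + b^(n), then s^(m)_t = P^(m)(...).

    R, x_t and b are dicts keyed by gate name; states is (h, c) = (s^(0), s^(1)).
    """
    h_prev, c_prev = states
    # Gate pre-activations: only s^(0) = h is multiplied by the recurrent matrices.
    g = {
        n: [r + x + bb for r, x, bb in zip(matvec(R[n], h_prev), x_t[n], b[n])]
        for n in GATES
    }
    # Pointwise function P: the standard LSTM cell and hidden-state update (assumed).
    c = [
        sigmoid(f) * cp + sigmoid(i) * math.tanh(z)
        for i, f, z, cp in zip(g["i"], g["f"], g["z"], c_prev)
    ]
    h = [sigmoid(o) * math.tanh(cv) for o, cv in zip(g["o"], c)]
    return h, c
```

Iterating lstm_step over t, feeding each returned (h, c) back in, gives the sequential loop that FlashRNN's backends accelerate.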

Typically, the inputs are first transformed by a linear layer, which is omitted here for flexibility (it would read $\mathbf{x}^{(n)}_t = \mathbf{W}^{(n)} \mathbf{x}'_t$). Unlike the recurrent part, this projection can be parallelized along the sequence dimension.

FlashRNN employs a multi-head structure, which is equivalent to having a block-diagonal recurrent matrix: the hidden state and gate vectors of hidden dimension $d$ are split into heads of head dimension $d_{head}$.
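The equivalence between per-head recurrent matrices and one block-diagonal matrix can be checked numerically. This plain-Python sketch (helper names are hypothetical) applies N small $d_{head} \times d_{head}$ matrices to the matching slices of a state vector and compares the result with a single multiplication by the assembled block-diagonal matrix.

```python
def matvec(M, v):
    # Dense matrix-vector product for list-of-lists matrices.
    return [sum(a * x for a, x in zip(row, v)) for row in M]


def block_diag(blocks):
    """Assemble a block-diagonal matrix from square per-head blocks."""
    size = sum(len(blk) for blk in blocks)
    out = [[0.0] * size for _ in range(size)]
    offset = 0
    for blk in blocks:
        k = len(blk)
        for i in range(k):
            for j in range(k):
                out[offset + i][offset + j] = blk[i][j]
        offset += k
    return out


def multihead_matvec(blocks, s):
    """Apply each head's d_head x d_head matrix to its own slice of s."""
    out, offset = [], 0
    for blk in blocks:
        k = len(blk)
        out.extend(matvec(blk, s[offset:offset + k]))
        offset += k
    return out


heads = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.0], [0.0, -1.0]]]  # N = 2, d_head = 2
s = [1.0, 1.0, 2.0, 3.0]                                      # d = N * d_head = 4
print(multihead_matvec(heads, s))    # [3.0, 7.0, 1.0, -3.0]
print(matvec(block_diag(heads), s))  # [3.0, 7.0, 1.0, -3.0]
```

The block-diagonal view also makes the parameter saving explicit: N heads need $N d_{head}^2$ recurrent weights per gate instead of $d^2 = (N d_{head})^2$, i.e. 8 instead of 16 in this toy case.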

The fused Triton backend is limited to small head dimensions, $d_{head} \leq 64$. The CUDA backend comes in two versions. The basic cuda version alternates between the recurrent matrix multiplication and the application of the pointwise non-linearity $\mathcal{P}$; it places no limit on the head dimension $d_{head}$. The second, cuda_fused, fuses the matrix multiplication and the pointwise non-linearity into a single CUDA kernel, using wmma (warp matrix multiply-accumulate) instructions and custom caching in SRAM and registers (similar to FlashAttention (Dao et al., 2022), but with a different focus). Since the recurrent matrix $\mathbf{R}$ and the biases $\mathbf{b}$ are reused at every time step, they are cached in registers and SRAM, which yields a $2\times$ to $5\times$ speedup over the alternating version.
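A back-of-the-envelope calculation shows why keeping $\mathbf{R}$ on chip pays off: without caching, $\mathbf{R}$ must be loaded again at every time step. This sketch assumes bf16 storage (2 bytes per element) and counts only $\mathbf{R}$, ignoring biases, states and how the kernels actually tile the matrix; the function names are hypothetical.

```python
def recurrent_bytes_per_head(num_gates, d_head, bytes_per_elem=2):
    """Size of one head's recurrent weights over all gates (bf16 = 2 bytes per element)."""
    return num_gates * d_head * d_head * bytes_per_elem


def reread_bytes(num_gates, d_head, num_heads, seq_len, bytes_per_elem=2):
    """Bytes of R loaded if R is re-read once at every one of seq_len time steps."""
    return seq_len * num_heads * recurrent_bytes_per_head(num_gates, d_head, bytes_per_elem)


# LSTM (4 gates), d_head = 64, bf16
print(recurrent_bytes_per_head(4, 64))                   # 32768 bytes = 32 KiB per head
print(reread_bytes(4, 64, num_heads=1, seq_len=1024))    # 33554432 bytes = 32 MiB
```

The footprint grows with $d_{head}^2$: at $d_{head} = 256$, one LSTM head's $\mathbf{R}$ already takes 512 KiB, which illustrates why head size matters so much for on-chip caching.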

Speed comparison

[Figure: speed comparison]

Installation

To install FlashRNN, simply use:

```bash
pip install flashrnn
```

Your GPU must support CUDA Compute Capability 8.0 or later. Make sure you have an up-to-date g++ compiler installed. We recommend using conda with an environment derived from the provided environment_pt240cu124.yaml:

```bash
conda env create -n flashrnn -f environment_pt240cu124.yaml
```
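One way to check the Compute Capability requirement before installing is to query the GPU through PyTorch. torch.cuda.is_available and torch.cuda.get_device_capability are standard PyTorch calls; supports_flashrnn is a hypothetical helper for this sketch, and it only checks the 8.0 threshold stated above, not the compiler or CUDA toolkit versions.

```python
def supports_flashrnn(capability):
    """Return True if a (major, minor) compute capability is 8.0 or later."""
    return tuple(capability) >= (8, 0)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        if torch.cuda.is_available():
            cap = torch.cuda.get_device_capability()  # e.g. (8, 0) on an A100
            print(f"compute capability {cap}: supported = {supports_flashrnn(cap)}")
        else:
            print("No CUDA device visible to PyTorch")
```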

Using FlashRNN

FlashRNN has a functional interface: no parameters are stored in or tied to the flashrnn function, so all of them are passed in explicitly. To apply it, simply use:

```python
import torch
from flashrnn import flashrnn

device = torch.device('cuda')
dtype = torch.bfloat16
B = 8        # batch size
T = 1024     # sequence length
N = 3        # number of heads
D = 256      # head dimension
G = 4        # number of gates / pre-activations (4 for the LSTM example)
S = 2        # number of states (h and c for the LSTM example)

Wx = torch.randn([B, T, G, N, D], device=device, dtype=dtype, requires_grad=True)  # projected inputs
R = torch.randn([G, N, D, D], device=device, dtype=dtype, requires_grad=True)      # recurrent matrices
b = torch.randn([G, N, D], device=device, dtype=dtype, requires_grad=True)         # biases
states_initial = torch.randn([S, B, 1, N, D], device=device, dtype=dtype, requires_grad=True)

# available functions:
# lstm, gru, elman, slstm

# available backends:
# cuda_fused, cuda, triton and vanilla
# (the triton backend requires D <= 64)

states, last_states = flashrnn(Wx, R, b, states=states_initial, function="lstm", backend="cuda_fused")

# for the LSTM, the hidden state h is the first of the states [h, c]
hidden_state = states[0]
```
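The usage example passes a random Wx directly. In a model, Wx would typically come from the input projection $\mathbf{x}^{(n)}_t = \mathbf{W}^{(n)} \mathbf{x}'_t$, which FlashRNN leaves to the user. The following sketch assumes the projected features are laid out gate-major as [G, N, D] to match the [B, T, G, N, D] shape above; which gate order the lstm function expects along the G axis is not stated here, and D_in is a hypothetical input size. It computes the projection in two equivalent ways, with torch.nn.Linear and with torch.einsum, in both cases over all time steps at once.

```python
import torch

torch.manual_seed(0)
B, T, N, D, G = 2, 16, 3, 32, 4  # small sizes for illustration
D_in = 48                         # hypothetical size of the raw input features x'_t

x = torch.randn(B, T, D_in)
# bias=False: the bias b is passed to flashrnn as a separate argument
proj = torch.nn.Linear(D_in, G * N * D, bias=False)

# Option 1: a standard Linear layer, with its output features reshaped gate-major.
Wx = proj(x).view(B, T, G, N, D)

# Option 2: the same contraction written per gate n and head, x^(n)_t = W^(n) x'_t.
W = proj.weight.view(G, N, D, D_in)
Wx_einsum = torch.einsum("btk,gndk->btgnd", x, W)

print(Wx.shape)  # torch.Size([2, 16, 4, 3, 32])
```

Because no time step depends on another, the whole sequence is projected in one batched operation. Before calling flashrnn, the result would be moved to the GPU and the chosen precision, e.g. with Wx.to(device=device, dtype=dtype).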

Acknowledgement

We thank Thomas Schmied and Pieter-Jan Hoedt for valuable feedback.

Cite as

@misc{pöppel2024flashrnnoptimizingtraditionalrnns,
      title={FlashRNN: Optimizing Traditional RNNs on Modern Hardware}, 
      author={Korbinian Pöppel and Maximilian Beck and Sepp Hochreiter},
      year={2024},
      eprint={2412.07752},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2412.07752}, 
}

License

NXAI Community License (see LICENSE file)

Citations
