CUDA and Triton implementations of Flash Attention with SoftmaxN.
Flash-Attention-Softmax-N
Flash attention with softmaxN. The *Attention Is Off By One* post hypothesized that using softmax1 in the attention mechanism would reduce the number of outliers in a transformer model's activations and weights.
🎯 Efficient, Numerically-Stable Implementation of SoftmaxN: No more worrying about the non-trivial implementation of softmaxN (a minimal reference sketch appears after the list of implementations below). $$\text{softmax}_n(x_i) = \frac{\exp(x_i)}{n + \sum_j \exp(x_j)}$$
🚀 Multiple Attention Implementations, your choice: Whatever you're aiming for, we've got you covered with three attention implementations. In the spirit of the flash attention paper, further gains can be made by considering the whole attention function rather than just the softmaxN subfunction.
- `flash_attention_n`: recommended for integer values of n; uses CUDA on the backend if a GPU is available
- `flash_attention_n_triton`: recommended for non-integer values of n when a GPU is available; uses Triton
- `slow_attention_n`: flexible, torch-based implementation
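For reference, here is a minimal torch sketch of the formula above. It is not the packaged implementation, just an illustration of the shifted-exponential trick that keeps softmaxN numerically stable; the function name is illustrative.

import torch

def softmax_n_reference(x: torch.Tensor, n: float = 1.0, dim: int = -1) -> torch.Tensor:
    # exp(x_i) / (n + sum_j exp(x_j)), computed with a non-negative shift m
    # so that neither exp(x - m) nor n * exp(-m) can overflow.
    m = x.amax(dim=dim, keepdim=True).clamp(min=0.0)
    e = torch.exp(x - m)
    return e / (n * torch.exp(-m) + e.sum(dim=dim, keepdim=True))

In practice, use the library's `softmax_n` (shown at the end of the Usage section) rather than this sketch.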
Install
Simple installation
$ pip install flash-attention-softmax-n
Optionally install the Triton implementation
$ pip install flash-attention-softmax-n[triton]
$ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
Usage
| Feature / Function | flash_attention_n | flash_attention_n_triton | slow_attention_n |
|---|---|---|---|
| CPU-compatible? | Yes | No | Yes |
| Real- or integer-valued $n$ | Integer | Real | Real |
| Datatypes natively supported on GPU | fp32, fp16, bf16 | fp16 (*see below) | fp32, fp16, bf16 |
| Datatypes natively supported on CPU | fp32, bf16 | n/a | fp32, bf16 |
| Dropout? | Yes | No | Yes |
| Causal mask? | Yes | only tested for $n \leq 10^{-3}$ | Yes |
| Attention bias (ALiBi)? | Yes | No | No |
| Attention mask? | Yes | No | Yes |
| Supports query.ndim < 4? | No | No | Yes |
| Supports key.ndim < 4 and value.ndim < 4? | Yes | No | Yes |
| Requires key.shape[-1] == value.shape[-1]? | No | Yes | No |
CUDA
The recommended function to use for integer values of n, with or without a GPU.
You'll probably need an A100 to reap the full benefit though.
This implementation was inspired by x-transformers.
It uses `torch.nn.functional.scaled_dot_product_attention` on the backend, which requires `torch>=2.0.0`.
import torch
from flash_attention_softmax_n import flash_attention_n

softmax_n_param = 1  # integer n; this backend is the recommended choice for integer values

query = torch.randn((6, 1, 1024, 64))  # 4-D query
key = torch.randn((6, 1152, 64))       # key and value may be 3-D here (see the table above)
value = torch.randn((6, 1152, 32))     # key.shape[-1] == value.shape[-1] is not required

attn = flash_attention_n(
    query=query,
    key=key,
    value=value,
    softmax_n_param=softmax_n_param,
    scale=None,
    dropout_p=0.,
    attn_mask=None,
    attn_bias=None,
    is_causal=False
)
Triton
The recommended function to use when you want GPU acceleration and have a non-integer-valued n.
Note the Triton implementation has a more limited set of features compared to the CUDA version, see the above comparison table.
*To use datatypes other than fp16, first convert your input to fp16 and then convert the attention output back to your original datatype.
This is a generalization of OpenAI's Triton fused attention implementation.
Requires `torch>=2.0.0` and `triton>=2.0.0`.
import torch
from flash_attention_softmax_n import flash_attention_n_triton

softmax_n_param = 1.  # real-valued n; the Triton kernel is the recommended GPU path for non-integer n

# The Triton kernel is GPU-only and natively fp16 (see the table above),
# so the inputs are created on the GPU in half precision.
query = torch.randn((6, 1, 1024, 64), device='cuda', dtype=torch.float16)
key = torch.randn((6, 1, 1152, 64), device='cuda', dtype=torch.float16)
value = torch.randn((6, 1, 1152, 64), device='cuda', dtype=torch.float16)

attn = flash_attention_n_triton(
    query=query,
    key=key,
    value=value,
    softmax_n_param=softmax_n_param,
    scale=None,
    is_causal=False
)
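If your model runs in fp32 or bf16, the conversion described above is just a cast on the way in and a cast back on the way out. A sketch (the `query32`/`key32`/`value32` names are illustrative, not part of the package):

query32 = torch.randn((6, 1, 1024, 64), device='cuda')  # fp32 inputs
key32 = torch.randn((6, 1, 1152, 64), device='cuda')
value32 = torch.randn((6, 1, 1152, 64), device='cuda')

attn32 = flash_attention_n_triton(
    query=query32.half(),   # cast the inputs to fp16 for the Triton kernel
    key=key32.half(),
    value=value32.half(),
    softmax_n_param=softmax_n_param,
    scale=None,
    is_causal=False
).to(query32.dtype)         # cast the attention output back to the original dtype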
Slow Attention
Written in torch. Use this version when you have a real-valued n, and the Triton version is unavailable or doesn't have the feature(s) you need.
import torch
from flash_attention_softmax_n import slow_attention_n

softmax_n_param = 1.  # real-valued n

query = torch.randn((6, 1024, 64))  # 3-D (no head dimension) inputs are supported here
key = torch.randn((6, 1152, 64))
value = torch.randn((6, 1152, 32))  # key.shape[-1] == value.shape[-1] is not required

attn = slow_attention_n(
    query=query,
    key=key,
    value=value,
    softmax_n_param=softmax_n_param,
    scale=None,
    dropout_p=0.,
    attn_mask=None,
    is_causal=False,
    softmax_dtype=None
)
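As a sketch of how this might slot into a model: the module below is hypothetical (not part of the package) and simply routes a single-head self-attention block through `slow_attention_n` with the same arguments as the example above, treating n as a hyperparameter.

import torch
import torch.nn as nn
from flash_attention_softmax_n import slow_attention_n

class TinySelfAttention(nn.Module):
    # Hypothetical single-head self-attention layer using slow_attention_n.
    def __init__(self, dim: int, softmax_n_param: float = 1.):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.softmax_n_param = softmax_n_param

    def forward(self, x):  # x: (batch, seq, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = slow_attention_n(
            query=q, key=k, value=v,
            softmax_n_param=self.softmax_n_param,
            scale=None, dropout_p=0., attn_mask=None,
            is_causal=False, softmax_dtype=None,
        )
        return self.out(attn)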
We also provide a torch implementation of softmaxN that can be used as a drop-in replacement for softmax.
import torch
from flash_attention_softmax_n import softmax_n
x = torch.rand((100, 100))
# y = torch.nn.functional.softmax(x, dim=-1, dtype=torch.float32)
y = softmax_n(x, dim=-1, dtype=torch.float32)
y1 = softmax_n(x, n=1.)
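One consequence of the extra $n$ in the denominator: for $n > 0$ each row of the output sums to strictly less than 1, with the remaining mass absorbed by the $n$ term. Continuing the example above:

print(y1.sum(dim=-1))  # every entry is strictly below 1.0 for n = 1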
Download files
Hashes for flash-attention-softmax-n-0.1.3.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 61db521553f9f4696a14294608a50665d22e28605e88c89ada3818c71bdb439b |
| MD5 | 773b6201834114c2320926334836fbc8 |
| BLAKE2b-256 | ea22ccad6b7ffa8e981f73d732fdead7c2f94db7e1c94da9be59661947ee3c72 |
Hashes for flash_attention_softmax_n-0.1.3-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 41e31dbc2e51d237570ae80c840ae9e777d7d9d1ef65e288ad828662d86f21cf |
| MD5 | f381f89b9930e5f9f5b2a5ef70fa1c98 |
| BLAKE2b-256 | 0e2b5d9a20a0cf2ae141033fcaefa6c69e1f0864176228cb45f03ea60ea3f458 |