Efficient computation library for linear attention.
An efficient Linear Attention Decoding package
1. installation
conda create -n leetDecoding python=3.9
conda activate leetDecoding
pip install leetDecoding
The code has been tested in the following environment:
triton>=2.1.0
torch>=2.1.0
pycuda
pynvml
numpy<2
You can install these dependencies with the following commands:
pip install triton==2.1.0
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install pycuda
pip install pynvml
pip install "numpy<2"
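To confirm that an existing environment matches the tested versions listed above, a small check along these lines may help. It is only a sketch: `REQUIREMENTS`, `parse_version`, and `check_environment` are illustrative names and not part of leetDecoding, and the version parsing is deliberately simplistic (it keeps the leading numeric components and ignores suffixes such as `+cu118`).

```python
import re
from importlib import metadata

# (minimum inclusive, maximum exclusive) per package, copied from the tested
# environment above; None means the README states no bound.
REQUIREMENTS = {
    "triton": ((2, 1, 0), None),
    "torch": ((2, 1, 0), None),
    "pycuda": (None, None),
    "pynvml": (None, None),
    "numpy": (None, (2,)),
}


def parse_version(text):
    """Turn '2.1.0+cu118' into (2, 1, 0); suffixes such as 'rc1' are ignored."""
    numbers = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    # Pad so that '2.1' compares equal to '2.1.0'.
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def check_environment(requirements=REQUIREMENTS, get_version=metadata.version):
    """Map each package name to 'ok', 'missing', 'too old', or 'too new'."""
    report = {}
    for name, (minimum, maximum) in requirements.items():
        try:
            installed = parse_version(get_version(name))
        except metadata.PackageNotFoundError:
            report[name] = "missing"
        else:
            if minimum is not None and installed < minimum:
                report[name] = "too old"
            elif maximum is not None and installed >= maximum:
                report[name] = "too new"
            else:
                report[name] = "ok"
    return report


if __name__ == "__main__":
    for package, status in check_environment().items():
        print(f"{package}: {status}")
```

The `get_version` parameter exists only so the check can be exercised without the packages installed; in normal use the default `importlib.metadata.version` lookup is enough.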
2. usage
import torch
from leetDecoding.efficient_linear_decoding import causal_linear_decoder
# Create input tensors
Q = torch.randn(2,32,1024,128,device='cuda:0')
K = torch.randn(2,32,1024,128,device='cuda:0')
V = torch.randn(2,32,1024,128,device='cuda:0')
# Inference using causal_linear_decoder
output = causal_linear_decoder(Q,K,V)
# To apply a weighted mask whose values are exp(-gamma), set is_mask_weight=True and is_need_exp=True
gamma = torch.full((32,),0.5,device='cuda:0')
output = causal_linear_decoder(Q,K,V,is_mask_weight=True,gamma=gamma,is_need_exp=True)
# To apply a weighted mask without the exponential, set is_mask_weight=True and is_need_exp=False
gamma = torch.full((32,),0.5,device='cuda:0')
output = causal_linear_decoder(Q,K,V,is_mask_weight=True,gamma=gamma,is_need_exp=False)
# To use a specific method, such as FleetAttention, set attn_method='FleetAttention'
output = causal_linear_decoder(Q,K,V,is_mask_weight=False,attn_method='FleetAttention')
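This README does not spell out what `causal_linear_decoder` computes, so the following is a rough mental model rather than the library's implementation. It assumes the usual causal linear attention definition (no softmax), in which output row `i` is the sum over `j <= i` of `(Q[i] · K[j]) * V[j]`, optionally scaled by a per-step decay raised to the power `i - j`. How `gamma` maps onto that decay inside leetDecoding is an assumption here. The function names are hypothetical, and the code handles a single head with plain Python lists purely for readability.

```python
import math


def causal_linear_attention_quadratic(Q, K, V, decay=1.0):
    """O[i] = sum over j <= i of decay**(i - j) * (Q[i] . K[j]) * V[j]."""
    n, d_v = len(Q), len(V[0])
    out = [[0.0] * d_v for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only positions up to i contribute
            score = sum(q * k for q, k in zip(Q[i], K[j])) * decay ** (i - j)
            for c in range(d_v):
                out[i][c] += score * V[j][c]
    return out


def causal_linear_attention_recurrent(Q, K, V, decay=1.0):
    """Same result from a running state S = decay * S + outer(K[i], V[i])."""
    d_k, d_v = len(K[0]), len(V[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q, k, v in zip(Q, K, V):
        for a in range(d_k):
            for c in range(d_v):
                state[a][c] = decay * state[a][c] + k[a] * v[c]
        # O[i] = Q[i] @ S, using the state that already includes position i
        out.append([sum(q[a] * state[a][c] for a in range(d_k)) for c in range(d_v)])
    return out


if __name__ == "__main__":
    Q = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]
    K = [[0.3, -0.7], [1.0, 0.2], [-0.5, 0.5]]
    V = [[1.0, 0.0], [0.5, 1.5], [2.0, -2.0]]
    decay = math.exp(-0.5)  # illustrative value mirroring gamma = 0.5 above
    slow = causal_linear_attention_quadratic(Q, K, V, decay)
    fast = causal_linear_attention_recurrent(Q, K, V, decay)
    agree = all(
        math.isclose(x, y, abs_tol=1e-9)
        for row_a, row_b in zip(slow, fast)
        for x, y in zip(row_a, row_b)
    )
    print("quadratic and recurrent forms agree:", agree)
```

The recurrent form is what makes linear attention attractive for decoding: each new token updates a fixed-size state instead of revisiting every earlier position.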
3. acknowledgement
Method | Title | Paper | Code |
---|---|---|---|
causal_dot_product | Fast Transformers with Clustered Attention | arxiv | code |
Lightning Attention-2 | Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models | arxiv | code |