
Efficient computation library for linear attention.


An efficient Linear Attention Decoding package

1. Installation

conda create -n leetDecoding python=3.9
conda activate leetDecoding
pip install leetDecoding

The code has been tested in the following environment:

triton>=2.1.0
torch>=2.1.0
pycuda
pynvml
numpy<2

You can install these dependencies with the following commands:

pip install triton==2.1.0
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install pycuda
pip install pynvml
pip install "numpy<2"
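After installing, it can help to confirm that the installed versions match the tested environment listed above. The script below is a minimal sketch that uses only the standard library (`importlib.metadata`); the helpers `parse_release`, `satisfies`, `check_environment`, and `installed_versions` are hypothetical names, not part of leetDecoding. Its comparison only looks at the leading numeric release segment (so `2.1.0+cu118` is read as `2.1.0`), which simplifies PEP 440; the `packaging` library's `SpecifierSet` implements the full rules.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Tested requirements from the list above: (distribution, operator, version).
# None means "any version is fine, it just has to be installed".
REQUIREMENTS = [
    ("triton", ">=", "2.1.0"),
    ("torch", ">=", "2.1.0"),
    ("pycuda", None, None),
    ("pynvml", None, None),
    ("numpy", "<", "2"),
]


def parse_release(v: str) -> tuple:
    """Leading numeric release segment, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(x) for x in m.group(0).split(".")) if m else ()


def satisfies(installed: str, op, required) -> bool:
    """Simplified check for '>=' and '<' (not full PEP 440)."""
    if op is None:
        return True
    a, b = parse_release(installed), parse_release(required)
    # Pad with zeros so that '2' compares equal to '2.0.0'.
    n = max(len(a), len(b))
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    return a >= b if op == ">=" else a < b


def check_environment(installed: dict) -> list:
    """Return problems; `installed` maps distribution name -> version or None."""
    problems = []
    for name, op, req in REQUIREMENTS:
        found = installed.get(name)
        if found is None:
            problems.append(f"{name}: not installed")
        elif not satisfies(found, op, req):
            problems.append(f"{name}: found {found}, need {op}{req}")
    return problems


def installed_versions() -> dict:
    """Look up the installed version of each required distribution."""
    result = {}
    for name, _, _ in REQUIREMENTS:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    issues = check_environment(installed_versions())
    print("\n".join(issues) if issues else "environment matches the tested requirements")
```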

2. Usage

import torch
from leetDecoding.efficient_linear_decoding import causal_linear_decoder

# Create input tensors Q, K, V
Q = torch.randn(2,32,1024,128,device='cuda:0')
K = torch.randn(2,32,1024,128,device='cuda:0')
V = torch.randn(2,32,1024,128,device='cuda:0')

# Inference using causal_linear_decoder
output = causal_linear_decoder(Q,K,V)

# To apply a mask whose weight values are exp(-gamma), set is_mask_weight=True and is_need_exp=True
gamma = torch.full((32,),0.5,device='cuda:0')
output = causal_linear_decoder(Q,K,V,is_mask_weight=True,gamma=gamma,is_need_exp=True)

# To use gamma directly as the mask weight (no exponentiation), set is_mask_weight=True and is_need_exp=False
gamma = torch.full((32,),0.5,device='cuda:0')
output = causal_linear_decoder(Q,K,V,is_mask_weight=True,gamma=gamma,is_need_exp=False)

# To use a specific method, such as FleetAttention, set attn_method='FleetAttention'
output = causal_linear_decoder(Q,K,V,is_mask_weight=False,attn_method='FleetAttention')
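To sanity-check results on small inputs, a plain NumPy reference can be useful. The sketch below assumes that `causal_linear_decoder` computes unnormalized causal linear attention, O[t] = Σ_{s≤t} w^(t-s) (Q[t]·K[s]) V[s], as the acknowledged `causal_dot_product` kernel does for w = 1, and that `is_mask_weight=True` applies a per-head decay w equal to exp(-gamma) when `is_need_exp=True`, or to gamma itself when `is_need_exp=False`. These semantics are inferred from the parameter names rather than confirmed here, the layout (batch, heads, seq_len, dim) is assumed from the example shapes, and both function names are hypothetical. The block computes the same output two ways: the masked quadratic form that a parallel kernel evaluates, and the recurrent form with one dim × dim_v state per head that token-by-token decoding uses.

```python
import numpy as np


def causal_linear_attention_quadratic(Q, K, V, decay=None):
    """Masked (parallel) form: O[t] = sum_{s<=t} w**(t-s) * (Q[t] . K[s]) * V[s].

    Assumed layout: Q, K are (batch, heads, seq_len, dim), V is (batch, heads, seq_len, dim_v).
    decay: optional per-head weights w of shape (heads,); None means w = 1.
    Assumed reference semantics only, not leetDecoding's actual kernel.
    """
    n = Q.shape[2]
    scores = Q @ np.swapaxes(K, -1, -2)          # (b, h, n, n) query-key dot products
    t = np.arange(n)
    dist = t[:, None] - t[None, :]                # dist[t, s] = t - s
    causal = dist >= 0                            # keep only positions s <= t
    if decay is None:
        mask = causal.astype(scores.dtype)
    else:
        w = np.asarray(decay, dtype=scores.dtype)[None, :, None, None]
        # w**(t-s) on and below the diagonal, 0 above it
        mask = np.where(causal, w ** np.maximum(dist, 0), 0.0)
    return (scores * mask) @ V


def causal_linear_attention_recurrent(Q, K, V, decay=None):
    """Decoding form with a (dim x dim_v) state per head:
    S_t = w * S_{t-1} + outer(K[t], V[t]);  O[t] = Q[t] @ S_t.
    """
    b, h, n, d = Q.shape
    dv = V.shape[-1]
    w = np.ones(h) if decay is None else np.asarray(decay, dtype=float)
    S = np.zeros((b, h, d, dv))
    out = np.empty((b, h, n, dv))
    for step in range(n):
        kv = K[:, :, step, :, None] * V[:, :, step, None, :]  # outer product, (b, h, d, dv)
        S = w[None, :, None, None] * S + kv                     # decay old state, add new token
        out[:, :, step] = np.einsum("bhd,bhde->bhe", Q[:, :, step], S)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((2, 4, 16, 8)) for _ in range(3))
    gamma = np.full(4, 0.5)
    # no mask / assumed is_need_exp=True / assumed is_need_exp=False
    for decay in (None, np.exp(-gamma), gamma):
        a = causal_linear_attention_quadratic(Q, K, V, decay)
        r = causal_linear_attention_recurrent(Q, K, V, decay)
        print(np.allclose(a, r))
```

The recurrent form keeps memory constant in sequence length (one state matrix per head), which is why linear attention is attractive for decoding; the quadratic form is easier to read and to compare against.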

3. Acknowledgement

| Method | Title | Paper | Code |
|---|---|---|---|
| causal_dot_product | Fast Transformers with Clustered Attention | arxiv | code |
| Lightning Attention-2 | Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models | arxiv | code |



Download files


Source Distribution

leetDecoding-0.0.3.tar.gz (29.1 kB)

File details

Details for the file leetDecoding-0.0.3.tar.gz.

File metadata

  • Download URL: leetDecoding-0.0.3.tar.gz
  • Upload date:
  • Size: 29.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.4

File hashes

Hashes for leetDecoding-0.0.3.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0c380ecc61fe9444efbfcff0668b85963b8ff2ce54ed2dc17f42189aacd140a7 |
| MD5 | 764110a8d911a0fef286c7abd3c00c15 |
| BLAKE2b-256 | 67bbf25e4d311fc662de4eba1d2fd20c288df7ec48b18e7e28953c29975c9065 |
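To check a downloaded archive against the SHA256 digest for leetDecoding-0.0.3.tar.gz, Python's `hashlib` is enough. A minimal sketch; the function names are hypothetical, and the default filename assumes the archive sits in the current directory:

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "0c380ecc61fe9444efbfcff0668b85963b8ff2ce54ed2dc17f42189aacd140a7"


def sha256_of(path, chunk_size=1 << 20) -> str:
    """Hash the file in chunks so large archives are not read into memory at once."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


if __name__ == "__main__":
    archive = Path("leetDecoding-0.0.3.tar.gz")
    if archive.exists():
        print("OK" if verify(archive, EXPECTED_SHA256) else "MISMATCH")
    else:
        print(f"{archive} not found")
```

pip can also enforce this at install time in hash-checking mode: put `leetDecoding==0.0.3 --hash=sha256:<digest>` in a requirements file and run `pip install --require-hashes -r requirements.txt` (every dependency must then be pinned and hashed as well).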
