Attention with QK distance using KL divergence
Project description
kl-div-attention
Just another attention variant, in which the query-key distance is computed with KL divergence. Built for testing a GitLab workflow.
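To make the idea concrete, here is a minimal, dependency-free sketch of attention weights derived from a KL-divergence distance. It is an illustration under assumptions, not this package's implementation: it assumes each query and key vector is turned into a probability distribution with a softmax over its features, and that the attention logit is the negative of KL(query || key). The helper names `softmax`, `kl_divergence`, and `kl_attention_weights` are hypothetical.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    # KL(p || q) = sum_i p_i * (log p_i - log q_i), for strictly positive p and q.
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))

def kl_attention_weights(queries, keys, causal=False):
    # Assumption: features are mapped to distributions with a softmax.
    q_dists = [softmax(q) for q in queries]
    k_dists = [softmax(k) for k in keys]
    weights = []
    for i, qd in enumerate(q_dists):
        logits = []
        for j, kd in enumerate(k_dists):
            if causal and j > i:
                # Mask out future positions.
                logits.append(float("-inf"))
            else:
                # Smaller divergence means a larger logit.
                logits.append(-kl_divergence(qd, kd))
        weights.append(softmax(logits))
    return weights

queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0]]
weights = kl_attention_weights(queries, keys)
causal_weights = kl_attention_weights(queries, keys, causal=True)
```

In this sketch each query attends most strongly to the key whose distribution it matches, since the divergence between identical distributions is zero. Note that KL divergence is asymmetric, so the direction KL(query || key) is a choice made here for illustration.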
Install
pip install kl-div-attention
Usage
import torch
from kl_div_attention import KLDivAttention

attn = KLDivAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    causal = True,
    prenorm = True,
    fused_mode = 'flash'  # use fused triton flash attention
).cuda()

tokens = torch.randn(1, 1024, 512).cuda()  # (batch, seq_len, dim)

out = attn(tokens) + tokens  # residual connection
assert out.shape == tokens.shape

out.sum().backward()  # run a backward pass
Training
You can train a small transformer on Enwik8 with KL-divergence attention using the provided training script.
For example, to use the fused Triton kernel:
uv run train_enwik8.py --fused-mode flash
Download files
File details
Details for the file kl_div_attention-0.1.2.tar.gz.
File metadata
- Download URL: kl_div_attention-0.1.2.tar.gz
- Upload date:
- Size: 36.6 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b21af8bc92d44d4d3bd937a8c7f9dd0790aa043eb9b97984a36e6546f95626af |
| MD5 | 6dab0ac2fc0f422ffa133ae4b6f58404 |
| BLAKE2b-256 | c5d1342bf4e4f700f650e479e51f820c528bdcfb20cc815826771287f9470146 |
File details
Details for the file kl_div_attention-0.1.2-py3-none-any.whl.
File metadata
- Download URL: kl_div_attention-0.1.2-py3-none-any.whl
- Upload date:
- Size: 11.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 96598221e66dec82addb003c40730dd71ef8b6298b645b24c9cd2a3fe6ed14ae |
| MD5 | dfd913cf3239afef0730804acf20c6f9 |
| BLAKE2b-256 | dbb610196b780bf6e9793f334b48b94646c4ee16ca6f218915eb4bb48ed02527 |
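The SHA256 digests listed above can be checked against a downloaded file. The following is a minimal sketch using only `hashlib` from the Python standard library. The helper names `sha256_of_file` and `matches_expected` are hypothetical, and the file path is assumed to be wherever you saved the download.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 20):
    # Hash the file in chunks so large archives need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_expected(path, expected_hex):
    # Compare against a published digest, ignoring case and surrounding whitespace.
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, `matches_expected("kl_divergence_download.whl", digest)` would be called with the local path of the wheel and the SHA256 value from the table; the path shown is a placeholder.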