
Fast kernel for triangle self-attention.

Project description

Fused triangle self-attention kernel, written in Triton. It is essentially FlashAttention, but for triangle self-attention. The implementation is heavily inspired by FlagAttention and the Triton fused-attention tutorial.
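For reference, the operation being fused can be written in a few lines of NumPy. This is a minimal sketch assuming the AlphaFold2 "starting node" formulation: per-head projections `q`, `k`, `v` of shape (h, n, n, d) and a pair bias of shape (h, n, n) shared across rows i. The names `triangle_attention_naive` and `logits_bytes` and the shape conventions are hypothetical illustrations, not trifast's API; masking and output gating are omitted. The `logits` array has shape (h, n, n, n), which is where the cubic memory cost of a straightforward implementation comes from.

```python
import numpy as np


def triangle_attention_naive(q, k, v, bias):
    """Hypothetical reference triangle self-attention that materializes every logit.

    q, k, v: (h, n, n, d) per-head projections of the pair representation
    bias:    (h, n, n)    per-head pair bias b[h, j, k], shared by all rows i
    returns: (h, n, n, d)
    """
    d = q.shape[-1]
    # logits[h, i, j, k] = q[h, i, j] . k[h, i, k] / sqrt(d) + bias[h, j, k]
    # Shape (h, n, n, n): this tensor is the O(n^3) memory cost.
    logits = np.einsum("hijd,hikd->hijk", q, k) / np.sqrt(d)
    logits = logits + bias[:, None, :, :]
    # Numerically stable softmax over the key axis k
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # out[h, i, j] = sum_k weights[h, i, j, k] * v[h, i, k]
    return np.einsum("hijk,hikd->hijd", weights, v)


def logits_bytes(n, heads, itemsize=2):
    """Bytes needed to store the (heads, n, n, n) logits; itemsize=2 for bfloat16."""
    return heads * n**3 * itemsize
```

For example, `logits_bytes(512, 4)` is 4 × 512³ × 2 = 1,073,741,824 bytes, i.e. 1 GiB of bfloat16 logits for a single batch element, before counting anything saved for the backward pass.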

  • O(n^2) memory complexity (vs. O(n^3) for a pure PyTorch implementation).
  • ~2x faster backward pass than the next-fastest implementation I could find (the DS4S Evoformer kernel).
  • ~4x faster forward pass than the next-fastest implementation I could find (the DS4S Evoformer kernel).
  • As far as I can tell, faster than a naive implementation.
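The quadratic memory bound comes from the FlashAttention idea: stream over the keys in blocks and keep a running maximum and running softmax denominator per query (an "online softmax"), so the full (h, n, n, n) logits tensor is never stored. The NumPy sketch below illustrates only that idea, assuming per-head `q`, `k`, `v` of shape (h, n, n, d) and a pair bias of shape (h, n, n) shared across rows; the function name `triangle_attention_chunked` and the `block` parameter are hypothetical. It is not the Triton kernel, which also tiles over queries and keeps its tiles in on-chip memory.

```python
import numpy as np


def triangle_attention_chunked(q, k, v, bias, block=16):
    """Hypothetical online-softmax triangle attention; peak memory O(h * n^2 * (d + block))."""
    h, n, _, d = q.shape
    scale = 1.0 / np.sqrt(d)
    # Running statistics per query position (h, i, j)
    m = np.full((h, n, n), -np.inf)   # running max of logits seen so far
    l = np.zeros((h, n, n))           # running softmax denominator
    acc = np.zeros((h, n, n, d))      # running (unnormalized) weighted sum of values
    for start in range(0, n, block):
        stop = min(start + block, n)
        kb = k[:, :, start:stop]      # (h, n, B, d) keys for this block
        vb = v[:, :, start:stop]      # (h, n, B, d) values for this block
        # Logits for this key block only: (h, n, n, B) instead of (h, n, n, n)
        s = np.einsum("hijd,hikd->hijk", q, kb) * scale + bias[:, None, :, start:stop]
        m_new = np.maximum(m, s.max(axis=-1))
        # Rescale earlier partial sums to the new running max (exp(-inf) = 0 on the first block)
        alpha = np.exp(m - m_new)
        p = np.exp(s - m_new[..., None])
        l = l * alpha + p.sum(axis=-1)
        acc = acc * alpha[..., None] + np.einsum("hijk,hikd->hijd", p, vb)
        m = m_new
    # Normalize once at the end
    return acc / l[..., None]
```

With `block` fixed, the largest temporary is (h, n, n, block), so memory grows as n² rather than n³; a smaller `block` lowers peak memory at the cost of more loop iterations.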

Plots

All benchmarks were run on an RTX 3090 in bfloat16.

Forward
[Plots: TSA forward runtime; TSA forward memory]

Backward
[Plots: TSA backward runtime; TSA backward memory]

Todos:

  • [ ] Try to train a model with it.
  • [ ] Can we perform any of dq/db/dkv transposed?
  • [ ] Rewrite the autotuner.



Download files


Source Distribution

trifast-0.1.7.tar.gz (15.8 MB)

Uploaded Source

Built Distribution


trifast-0.1.7-py3-none-any.whl (23.6 kB)

Uploaded Python 3

File details

Details for the file trifast-0.1.7.tar.gz.

File metadata

  • Download URL: trifast-0.1.7.tar.gz
  • Upload date:
  • Size: 15.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.21

File hashes

Hashes for trifast-0.1.7.tar.gz

  • SHA256: f9013be6e478b497786e4217fafbb1deb6e9f839f8fe3d7b616a19ffdfc36544
  • MD5: 6b81da083544a370fb8e6d0efa5c862e
  • BLAKE2b-256: 81fc5fe2a913142a3d7fd7efa553696bd9b82f56b4f95dc08fb9af7335917c98


File details

Details for the file trifast-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: trifast-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 23.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.21

File hashes

Hashes for trifast-0.1.7-py3-none-any.whl

  • SHA256: c8b513e24826d04ea445a0f87670b9cb68db86d9c998d443823c23c72bb87682
  • MD5: cec4fbb5e6a0a8f1132304d85a396b90
  • BLAKE2b-256: f8c6eaf82d895a220608751f4b3ae4794f8406916968c5abf5ed54630dd1f097

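To check a downloaded file against the published SHA256 digests, a small standard-library script is enough. This is a minimal sketch; the helper names `sha256_of` and `verify` and the `EXPECTED_SHA256` table are hypothetical, and only the digest values are taken from this page.

```python
import hashlib
from pathlib import Path

# SHA256 digests published on this page for trifast 0.1.7
EXPECTED_SHA256 = {
    "trifast-0.1.7.tar.gz": "f9013be6e478b497786e4217fafbb1deb6e9f839f8fe3d7b616a19ffdfc36544",
    "trifast-0.1.7-py3-none-any.whl": "c8b513e24826d04ea445a0f87670b9cb68db86d9c998d443823c23c72bb87682",
}


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 of a file, read in chunks so large sdists need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=None):
    """True if the file's SHA256 matches `expected` (default: look up by file name)."""
    path = Path(path)
    if expected is None:
        expected = EXPECTED_SHA256[path.name]
    return sha256_of(path) == expected.lower()
```

pip can enforce the same check at install time: list `trifast==0.1.7` in a requirements file followed by one `--hash=sha256:<digest>` option per published file, then install with `pip install --require-hashes -r requirements.txt`.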
