Fast kernel for triangle self attention.

Project description

TriFast

License: MIT

A Triton-based implementation of fused Triangle Self Attention kernels. TriFast provides an optimized version of triangle self attention, built on the ideas of Flash Attention.

🚀 Features

  • Memory Efficient: Achieves O(n²) memory complexity (compared to O(n³) for pure PyTorch implementations)
  • High Performance:
    • ⚡ ~4x faster forward pass than the next fastest implementation (DS4S evoformer kernel)
    • ⚡ ~2x faster backward pass than the next fastest implementation (DS4S evoformer kernel)
  • Multiple Precision Support: Works with float32, bfloat16, and fp16 data types
  • GPU Accelerated: Benchmarked on NVIDIA GPUs with excellent scaling properties
  • Auto-tuning: Includes built-in kernel tuning capabilities to optimize for specific workloads
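To see where the O(n³) term comes from, consider a naive (non-fused) computation: triangle attention runs one attention per row of an n×n grid, so materializing the logits for every row produces an n×n×n tensor before the softmax. A minimal NumPy sketch (single head, no bias or mask, toy sizes, purely illustrative, not TriFast's layout):

```python
import numpy as np

n, d = 8, 4  # grid edge length and head dimension (toy sizes)
q = np.random.randn(n, n, d)  # one attention problem per row of the n x n grid
k = np.random.randn(n, n, d)
v = np.random.randn(n, n, d)

# Naive path: each row's logits are n x n, so the full tensor is n x n x n.
logits = np.einsum('rid,rjd->rij', q, k) / np.sqrt(d)
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
out = np.einsum('rij,rjd->rid', weights, v)

assert logits.shape == (n, n, n)  # the O(n^3) intermediate a fused kernel never materializes
assert out.shape == (n, n, d)
```

A fused kernel computes the same `out` while streaming over the logits in blocks, so only the O(n²) inputs and outputs ever live in global memory.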

📊 Benchmarks

All benchmarks were performed on an NVIDIA GeForce RTX 3090 GPU using BFloat16 precision.

Forward Pass Performance

Runtime: TSA forward runtime (plot)
Memory Usage: TSA forward memory (plot)

Backward Pass Performance

Runtime: TSA backward runtime (plot)
Memory Usage: TSA backward memory (plot)

🛠️ Installation

pip install trifast

📖 Usage

Basic usage of the triangle attention function:

import torch
from trifast import triangle_attention
from trifast.utils import gen_tensors

# Example parameters: sequence length, head dimension, number of heads
n, d, h = 128, 32, 4
device = torch.device("cuda")
dtype = torch.bfloat16
scale = 1.0

# Generate tensors (query, key, value, bias, mask)
q, k, v, bias, mask = gen_tensors(n, d, h, use_mask=True, device=device, dtype=dtype, std=scale)

# Apply triangle self attention
out = triangle_attention(q, k, v, bias, mask)

⚙️ Auto-tuning

TriFast modifies triton.autotune to cache the best config to disk. This means that, for a given input shape, the first run of the kernel will generally be slower than subsequent runs.
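The caching idea itself is simple: key the tuning result by the input shape, run the expensive search only on a cache miss, and reuse the stored config afterwards. A toy stdlib-only illustration of that pattern (not TriFast's actual code; `cached_best_config` and `tune_fn` are hypothetical names):

```python
import json
import os
import tempfile

def cached_best_config(shape, cache_dir, tune_fn):
    """Toy shape-keyed autotune cache: tune once per shape, reuse from disk after."""
    path = os.path.join(cache_dir, "x".join(map(str, shape)) + ".json")
    if os.path.exists(path):           # fast path: this shape was already tuned
        with open(path) as f:
            return json.load(f)
    config = tune_fn(shape)            # slow path: run the expensive search once
    with open(path, "w") as f:
        json.dump(config, f)
    return config

cache = tempfile.mkdtemp()
calls = []
tune = lambda s: (calls.append(s) or {"BLOCK": 64})  # records each tuning run

a = cached_best_config((128, 32, 4), cache, tune)
b = cached_best_config((128, 32, 4), cache, tune)
assert a == b == {"BLOCK": 64} and len(calls) == 1  # second call hit the cache
```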

Using the auto-tuner

The package provides a command-line script for auto-tuning:

# Basic usage
trifast-tune

# For a more extensive tuning process
TRIFAST_FORCE_TUNE=1 trifast-tune

# With custom parameters
trifast-tune --min-n 32 --max-n 2048 --dtype bfloat16 --h 4,8 --d 32,64

Auto-tuner Parameters

--min-n          Minimum sequence length to tune (default: 16)
--max-n          Maximum sequence length to tune (default: 1024)
--dtype          PyTorch datatypes to use (comma-separated).
                 Options: float32, bfloat16, float16, all (default: bfloat16)
--h              List of number of heads (comma-separated integers, e.g., "1,2,4") (default: 4)
--d              List of dimensions (comma-separated integers, e.g., "16,32,64") (default: 32)

After tuning, the best configurations are cached to disk using platformdirs (e.g. under ~/.config/trifast/<version> on Linux).

🧠 How It Works

TriFast implements triangle self attention as optimized GPU kernels written in Triton. It is essentially Flash Attention applied to triangle self attention, which yields significant performance and memory gains over naive PyTorch implementations.
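The Flash Attention trick that makes this possible is the online softmax: keys and values are streamed in blocks while a running maximum and normalizer are maintained, so the full logit row never needs to exist at once. A minimal NumPy sketch of that accumulation for a single query vector (illustrative only; the real work happens in Triton kernels):

```python
import numpy as np

def online_softmax_attention(q, K, V, block=4):
    """Attention for one query vector, streaming K/V in blocks (flash-style)."""
    d = q.shape[0]
    m = -np.inf            # running max of logits seen so far
    l = 0.0                # running softmax normalizer
    acc = np.zeros(V.shape[1])  # running weighted sum of values
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Kb @ q / np.sqrt(d)      # logits for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)    # rescale previous accumulators to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vb
        m = m_new
    return acc / l

# Agrees with the materialize-everything reference:
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=4), rng.normal(size=(16, 4)), rng.normal(size=(16, 4))
s = K @ q / np.sqrt(4)
ref = np.exp(s - s.max()) / np.exp(s - s.max()).sum() @ V
assert np.allclose(online_softmax_attention(q, K, V), ref)
```

The backward pass plays the same game in reverse, recomputing block logits instead of storing them, which is where the O(n²) memory figure above comes from.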

🔍 Implementation Details

This implementation draws inspiration from Flash Attention, FlagOpen's FlagAttention, and the official Triton tutorials (see Acknowledgements).

🗺️ Roadmap

  • Explore performance optimizations for dq/db/dkv transposed operations
  • Implement pipelined writes to global memory in backward kernels

👨‍💻 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgements

  • The FlagOpen team for their work on FlagAttention
  • The Triton team for providing excellent documentation and tutorials
  • The DS4S team for their evoformer kernel implementation used in benchmarking

Project details


Download files

Download the file for your platform.

Source Distribution

trifast-0.1.13.tar.gz (428.3 kB)

Uploaded Source

Built Distribution

trifast-0.1.13-py3-none-any.whl (16.2 kB)

Uploaded Python 3

File details

Details for the file trifast-0.1.13.tar.gz.

File metadata

  • Download URL: trifast-0.1.13.tar.gz
  • Upload date:
  • Size: 428.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.6.17

File hashes

Hashes for trifast-0.1.13.tar.gz
Algorithm Hash digest
SHA256 e3918dc81f4ae75c0fd2218bae5b2055e18242e4067ac34c21c64c2c957648ad
MD5 2d6d5db0555e284d7874338813efac52
BLAKE2b-256 6c1b6d384e72cad1512493b8cb6c074ee7860e4a9040d3a0c39336101b972a3b


File details

Details for the file trifast-0.1.13-py3-none-any.whl.

File metadata

  • Download URL: trifast-0.1.13-py3-none-any.whl
  • Upload date:
  • Size: 16.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.6.17

File hashes

Hashes for trifast-0.1.13-py3-none-any.whl
Algorithm Hash digest
SHA256 e96fc0224b85e98802b04c495380198387110e3e712ce2ca7d2e672a27e10496
MD5 f29c0bb6b0197b6eefd6e147b1233696
BLAKE2b-256 11b35b2961811e4e47fcad1e3eb2fab6cc96285ebaf03768304ef84980fb3dcf

