Fast kernel for triangle self attention.
TriFast
A Triton-based implementation of fused Triangle Self Attention kernels. TriFast provides an optimized version of triangle self attention, built on the ideas of Flash Attention.
🚀 Features
- Memory Efficient: Achieves O(n²) memory complexity, compared to O(n³) for pure PyTorch implementations (the reference sketch under How It Works shows where the cubic term comes from)
- High Performance:
- ⚡ ~4x faster forward pass than the next fastest implementation (DS4S evoformer kernel)
- ⚡ ~2x faster backward pass than the next fastest implementation (DS4S evoformer kernel)
- Multiple Precision Support: Works with float32, bfloat16, and float16 data types
- GPU Accelerated: Benchmarked on NVIDIA GPUs with excellent scaling properties
- Auto-tuning: Includes built-in kernel tuning capabilities to optimize for specific workloads
📊 Benchmarks
All benchmarks were performed on an NVIDIA GeForce RTX 3090 GPU using BFloat16 precision.
Forward Pass Performance

*(Runtime and memory usage plots not reproduced here; see the project repository for figures.)*

Backward Pass Performance

*(Runtime and memory usage plots not reproduced here; see the project repository for figures.)*
🛠️ Installation
```bash
pip install trifast
```
📖 Usage
Basic usage of the triangle attention function:
```python
import torch
from trifast import triangle_attention
from trifast.utils import gen_tensors

# Example values: sequence length, head dim, number of heads
n, d, h = 128, 32, 4
device = torch.device("cuda")
dtype = torch.bfloat16
scale = 1.0  # std of the generated tensors

# Generate tensors (query, key, value, bias, mask)
q, k, v, bias, mask = gen_tensors(n, d, h, use_mask=True, device=device, dtype=dtype, std=scale)

# Apply triangle self attention
out = triangle_attention(q, k, v, bias, mask)
```
⚙️ Auto-tuning
TriFast modifies triton.autotune to cache the best config to disk, so for a given input shape the first run is generally slower than subsequent runs.
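The caching pattern itself is simple. Below is a minimal sketch of the idea; the names, file layout, and JSON format here are assumptions for illustration, not TriFast's actual code:

```python
# Minimal sketch of a disk-backed autotune cache (illustrative only; the
# function names and file layout are assumptions, not TriFast internals).
import json
import pathlib

CACHE_PATH = pathlib.Path.home() / ".config" / "trifast" / "best_configs.json"

def best_config(shape_key: str, candidates: list[dict], benchmark) -> dict:
    cache = json.loads(CACHE_PATH.read_text()) if CACHE_PATH.exists() else {}
    if shape_key not in cache:
        # First run for this shape: time every candidate config (slow).
        cache[shape_key] = min(candidates, key=benchmark)
        CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
        CACHE_PATH.write_text(json.dumps(cache))
    # Subsequent runs: read the winner straight from disk (fast).
    return cache[shape_key]
```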
Using the auto-tuner
The package provides a command-line script for auto-tuning:
```bash
# Basic usage
trifast-tune

# For a more extensive tuning process
TRIFAST_FORCE_TUNE=1 trifast-tune

# With custom parameters
trifast-tune --min-n 32 --max-n 2048 --dtype bfloat16 --h 4,8 --d 32,64
```
Auto-tuner Parameters
| Parameter | Description | Default |
|---|---|---|
| `--min-n` | Minimum sequence length to tune | `16` |
| `--max-n` | Maximum sequence length to tune | `1024` |
| `--dtype` | PyTorch datatypes to use (comma-separated). Options: `float32`, `bfloat16`, `float16`, `all` | `bfloat16` |
| `--h` | Numbers of heads (comma-separated integers, e.g. `1,2,4`) | `4` |
| `--d` | Head dimensions (comma-separated integers, e.g. `16,32,64`) | `32` |
After tuning, the best configurations are cached to disk using platformdirs (e.g. under ~/.config/trifast/<version> on Linux).
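For reference, this is how platformdirs resolves that per-user config directory. Passing `"trifast"` as the app name is an assumption based on the path above; the version subdirectory is handled by the package itself:

```python
# Sketch: resolving the per-user config directory with platformdirs.
# The "trifast" app name is an assumption based on the
# ~/.config/trifast/<version> example above.
from platformdirs import user_config_dir

print(user_config_dir("trifast"))
# Linux:   ~/.config/trifast
# macOS:   ~/Library/Application Support/trifast
# Windows: C:\Users\<user>\AppData\Local\trifast\trifast
```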
🧠 How It Works
TriFast implements triangle self attention using Triton to generate optimized GPU kernels. It is essentially Flash Attention applied to triangle self attention, yielding significant performance gains over naive PyTorch implementations.
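For orientation, here is a minimal naive PyTorch reference for the operation the kernel fuses. The tensor layout is an assumption based on the usage example above (q/k/v of shape `(h, n, n, d)`, bias of shape `(h, n, n)`, boolean mask of shape `(n, n)`); TriFast's actual layout and mask convention may differ. Note the `(h, n, n, n)` score tensor: that intermediate is where the O(n³) memory of a naive implementation comes from, and it is exactly what the fused kernel avoids materializing:

```python
import math
import torch

def triangle_attention_ref(q, k, v, bias, mask):
    """Naive triangle self attention (reference sketch, not TriFast's kernel).

    Assumed shapes: q, k, v: (h, n, n, d); bias: (h, n, n); mask: (n, n),
    with True marking positions to exclude.
    """
    d = q.shape[-1]
    # Scores for every (row i, query j, key k): an (h, n, n, n) tensor --
    # the O(n^3) intermediate a fused kernel never materializes.
    scores = torch.einsum("hijd,hikd->hijk", q, k) / math.sqrt(d)
    scores = scores + bias[:, None]  # pair bias, shared across rows
    scores = scores.masked_fill(mask[None, :, None, :], float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hijk,hikd->hijd", probs, v)
```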
🔍 Implementation Details
This implementation draws inspiration from Flash Attention, FlagAttention, and the official Triton tutorials (see the Acknowledgements below).
🗺️ Roadmap
- Explore performance optimizations for dq/db/dkv transposed operations
- Implement pipelined writes to global memory in backward kernels
👨‍💻 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgements
- The FlagOpen team for their work on FlagAttention
- The Triton team for providing excellent documentation and tutorials
- The DS4S team for their evoformer kernel implementation used in benchmarking