
Fast kernel for triangle self attention.


TriFast

License: MIT

A Triton-based implementation of fused Triangle Self Attention kernels. TriFast provides an optimized version of triangle self attention, built on the ideas of Flash Attention.

🚀 Features

  • Memory Efficient: Achieves O(n²) memory complexity (compared to O(n³) for pure PyTorch implementations)
  • High Performance:
    • ⚡ ~4x faster forward pass than the next fastest implementation (DS4S evoformer kernel)
    • ⚡ ~2x faster backward pass than the next fastest implementation (DS4S evoformer kernel)
  • Multiple Precision Support: Works with float32, bfloat16, and float16 data types
  • GPU Accelerated: Benchmarked on NVIDIA GPUs with excellent scaling properties
  • Auto-tuning: Includes built-in kernel tuning capabilities to optimize for specific workloads
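
To make the memory claim concrete, here is a naive NumPy sketch of triangle self attention around the starting node, in the form used by AlphaFold 2 (for each row i, queries at (i, j) attend over keys at (i, k), with a pair bias indexed by (j, k); AlphaFold's output gating is omitted). The tensor layout (q, k, v as (h, n, n, d), bias as (h, n, n), a boolean mask of shape (n, n) where True means "attend") and the name triangle_attention_reference are assumptions for illustration, not TriFast's API. The intermediate logits array has shape (h, n, n, n), which is where the n³ memory of a straightforward implementation comes from.

```python
import numpy as np


def triangle_attention_reference(q, k, v, bias, mask):
    """Naive starting-node triangle self attention (illustrative sketch).

    Assumed (hypothetical) layout:
      q, k, v: (h, n, n, d)  pair representation split into h heads
      bias:    (h, n, n)     pair bias b[h, j, k], shared across rows i
      mask:    (n, n) bool   mask[i, k] is True if key k may be attended in row i
    Assumes every row i has at least one unmasked key (otherwise NaNs appear).
    """
    d = q.shape[-1]
    # logits[h, i, j, k] = <q[h, i, j], k[h, i, k]> / sqrt(d) + bias[h, j, k]
    # This (h, n, n, n) array is the O(n^3) memory cost of the naive approach.
    logits = np.einsum("hijd,hikd->hijk", q, k) / np.sqrt(d)
    logits = logits + bias[:, None, :, :]
    # Broadcast mask[i, k] over heads and the query index j.
    logits = np.where(mask[None, :, None, :], logits, -np.inf)
    # Numerically stable softmax over the key axis k.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # out[h, i, j] = sum_k weights[h, i, j, k] * v[h, i, k]
    return np.einsum("hijk,hikd->hijd", weights, v)


# Small random example (illustrative sizes).
rng = np.random.default_rng(0)
h, n, d = 2, 8, 4
q, k, v = (rng.standard_normal((h, n, n, d)) for _ in range(3))
bias = rng.standard_normal((h, n, n))
mask = rng.random((n, n)) > 0.3
np.fill_diagonal(mask, True)  # guarantee at least one valid key per row
out = triangle_attention_reference(q, k, v, bias, mask)
print(out.shape)  # (2, 8, 8, 4)
```

For n = 1024 and h = 4 in bfloat16, the logits intermediate alone holds 4 · 1024³ ≈ 4.3 × 10⁹ elements, or 8 GiB.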

📊 Benchmarks

All benchmarks were performed on an NVIDIA GeForce RTX 3090 GPU using BFloat16 precision.

Forward Pass Performance

Runtime

[Figure: TSA forward runtime]

Memory Usage

[Figure: TSA forward memory]

Backward Pass Performance

Runtime

[Figure: TSA backward runtime]

Memory Usage

[Figure: TSA backward memory]

🛠️ Installation

pip install trifast

📖 Usage

Basic usage of the triangle attention function:

import torch
from trifast import triangle_attention
from trifast.utils import gen_tensors

# Example sizes (illustrative): sequence length n, dimension d, number of heads h
n, d, h = 128, 32, 4
device, dtype, scale = "cuda", torch.bfloat16, 1.0

# Generate tensors (query, key, value, bias, mask)
q, k, v, bias, mask = gen_tensors(n, d, h, use_mask=True, device=device, dtype=dtype, std=scale)

# Apply triangle self attention
out = triangle_attention(q, k, v, bias, mask)

⚙️ Auto-tuning

TriFast modifies triton.autotune to cache the best config to disk. As a result, the first run of the kernel for a given input shape will generally be slower than subsequent runs.
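
The caching behavior described above can be illustrated with a small pure-Python decorator: the first call for a given key benchmarks every candidate config and writes the winner to a JSON file, and later calls (including ones from a fresh process that reads the same file) skip benchmarking. This is a hypothetical sketch of the general pattern, not TriFast's modified triton.autotune; cached_autotune, the JSON cache format, and keying on the input shape are all assumptions.

```python
import json
import tempfile
import time
from pathlib import Path


def cached_autotune(configs, cache_file, warmup=1, reps=3):
    """Hypothetical autotuner that persists the best config per input shape."""
    cache_file = Path(cache_file)

    def decorator(fn):
        def wrapper(shape, *args):
            # Re-read the file on every call so results persist across processes.
            cache = json.loads(cache_file.read_text()) if cache_file.exists() else {}
            key = "x".join(map(str, shape))
            if key not in cache:
                # Slow path: benchmark every config for this new shape.
                timings = {}
                for name, cfg in configs.items():
                    for _ in range(warmup):
                        fn(shape, *args, **cfg)
                    start = time.perf_counter()
                    for _ in range(reps):
                        fn(shape, *args, **cfg)
                    timings[name] = (time.perf_counter() - start) / reps
                cache[key] = min(timings, key=timings.get)
                cache_file.write_text(json.dumps(cache))
            wrapper.last_config = cache[key]
            # Fast path: run once with the cached winner.
            return fn(shape, *args, **configs[cache[key]])

        return wrapper

    return decorator


# Usage with a stand-in "kernel" (hypothetical; no GPU work involved).
cache_path = Path(tempfile.mkdtemp()) / "autotune_cache.json"
calls = []


@cached_autotune({"small": {"block": 16}, "large": {"block": 64}}, cache_path)
def kernel(shape, block):
    calls.append(block)  # record every launch, including benchmark runs
    return shape[0] // block  # placeholder for real work


kernel((128, 128))  # first call for this shape: benchmarks both configs
first_call_launches = len(calls)
kernel((128, 128))  # cache hit: a single launch with the stored config
```

Here the first call launches the kernel 2 × (1 + 3) + 1 = 9 times, while the second launches it once, which is the same first-run cost the auto-tuner pays per input shape.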

Using the auto-tuner

The package provides a command-line script for auto-tuning:

# Basic usage
trifast-tune

# For a more extensive tuning process
TRIFAST_FORCE_TUNE=1 trifast-tune

# With custom parameters
trifast-tune --min-n 32 --max-n 2048 --dtype bfloat16 --h 4,8 --d 32,64

Auto-tuner Parameters

--min-n          Minimum sequence length to tune (default: 16)
--max-n          Maximum sequence length to tune (default: 1024)
--dtype          PyTorch datatypes to use (comma-separated).
                 Options: float32, bfloat16, float16, all (default: bfloat16)
--h              Numbers of heads to tune (comma-separated integers, e.g., "1,2,4") (default: 4)
--d              Dimensions to tune (comma-separated integers, e.g., "16,32,64") (default: 32)
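
A hypothetical sketch of how the documented flags and defaults could be parsed and expanded into a tuning grid. It is not trifast-tune's source: the helper names are made up, dtypes are kept as strings (a real tool would map them to torch dtypes, e.g. via getattr(torch, name)), and sweeping n over powers of two between --min-n and --max-n is an assumption.

```python
import argparse
import itertools

DTYPES = ("float32", "bfloat16", "float16")


def int_list(text):
    """Parse a comma-separated list of integers, e.g. "1,2,4" -> [1, 2, 4]."""
    return [int(part) for part in text.split(",") if part.strip()]


def dtype_list(text):
    """Parse comma-separated dtype names; "all" expands to every supported dtype."""
    names = [part.strip() for part in text.split(",") if part.strip()]
    if "all" in names:
        return list(DTYPES)
    for name in names:
        if name not in DTYPES:
            raise argparse.ArgumentTypeError(f"unknown dtype: {name}")
    return names


def build_parser():
    # Mirrors the documented flags and defaults (hypothetical reimplementation).
    # String defaults are passed through `type`, so "4" becomes [4].
    p = argparse.ArgumentParser(prog="trifast-tune-sketch")
    p.add_argument("--min-n", type=int, default=16)
    p.add_argument("--max-n", type=int, default=1024)
    p.add_argument("--dtype", type=dtype_list, default="bfloat16")
    p.add_argument("--h", type=int_list, default="4")
    p.add_argument("--d", type=int_list, default="32")
    return p


def tuning_grid(args):
    # ASSUMPTION: n is swept over powers of two in [min_n, max_n].
    ns = []
    n = 1
    while n <= args.max_n:
        if n >= args.min_n:
            ns.append(n)
        n *= 2
    # One tuning job per (n, dtype, h, d) combination.
    return list(itertools.product(ns, args.dtype, args.h, args.d))
```

Under this assumption the defaults give 7 jobs (n = 16 to 1024), and the custom example above (--min-n 32 --max-n 2048 --h 4,8 --d 32,64) gives 7 × 2 × 2 = 28.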

After tuning, the best configurations are cached to disk using platformdirs (e.g. under ~/.config/trifast/<version> on Linux).

🧠 How It Works

TriFast implements Triangle Self Attention as optimized GPU kernels written in Triton. It essentially applies Flash Attention to triangle self attention: attention is computed block by block with an online softmax, so the full attention-weight tensor is never materialized. This yields significant performance gains compared to naive PyTorch implementations.
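
The Flash Attention idea can be sketched in NumPy by streaming over blocks of keys while keeping, for every (head, i, j) position, a running maximum, a running softmax denominator, and an unnormalized output accumulator. This is a hypothetical CPU illustration of the online-softmax recurrence, not TriFast's kernel; the layout (q, k, v as (h, n, n, d), bias as (h, n, n), boolean mask of shape (n, n) with True meaning "attend"), the function name, and block_k are assumptions. A real Triton kernel would also tile the query axis and keep tiles in on-chip memory.

```python
import numpy as np


def triangle_attention_tiled(q, k, v, bias, mask, block_k=4):
    """Starting-node triangle attention via an online softmax over key blocks.

    Assumed (hypothetical) layout: q, k, v: (h, n, n, d); bias: (h, n, n);
    mask: (n, n) bool, mask[i, k] True if key k is valid in row i.
    Only an (h, n, n, block_k) slice of logits exists at any time.
    """
    h, n, _, d = q.shape
    scale = 1.0 / np.sqrt(d)
    row_max = np.full((h, n, n), -np.inf)  # running max of logits
    denom = np.zeros((h, n, n))  # running softmax denominator
    acc = np.zeros((h, n, n, d))  # running unnormalized output
    for start in range(0, n, block_k):
        stop = min(start + block_k, n)
        kb, vb = k[:, :, start:stop], v[:, :, start:stop]
        # Logits for this key block: (h, i, j, block).
        s = np.einsum("hijd,hikd->hijk", q, kb) * scale
        s = s + bias[:, None, :, start:stop]
        s = np.where(mask[None, :, None, start:stop], s, -np.inf)
        new_max = np.maximum(row_max, s.max(axis=-1))
        # Avoid (-inf) - (-inf) = NaN while a row has seen only masked keys.
        safe_max = np.where(np.isneginf(new_max), 0.0, new_max)
        # Rescale earlier statistics to the new running max, then add this block.
        alpha = np.exp(row_max - safe_max)
        p = np.exp(s - safe_max[..., None])
        denom = alpha * denom + p.sum(axis=-1)
        acc = alpha[..., None] * acc + np.einsum("hijk,hikd->hijd", p, vb)
        row_max = new_max
    # Assumes every row has at least one unmasked key, so denom > 0.
    return acc / denom[..., None]
```

Because the rescaling by alpha is exact, the result matches a naive softmax implementation (up to floating-point rounding) for any block_k, while extra memory is O(h · n² · (d + block_k)) rather than O(h · n³).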

🔍 Implementation Details

This implementation draws inspiration from:

🗺️ Roadmap

  • Explore performance optimizations for dq/db/dkv transposed operations
  • Implement pipelined writes to global memory in backward kernels

👨‍💻 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgements

  • The FlagOpen team for their work on FlagAttention
  • The Triton team for providing excellent documentation and tutorials
  • The DS4S team for their evoformer kernel implementation used in benchmarking



Download files


Source Distribution

trifast-0.1.13.dev1.tar.gz (428.3 kB)

Uploaded Source

Built Distribution


trifast-0.1.13.dev1-py3-none-any.whl (16.3 kB)

Uploaded Python 3

File details

Details for the file trifast-0.1.13.dev1.tar.gz.

File metadata

  • Download URL: trifast-0.1.13.dev1.tar.gz
  • Size: 428.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.6.17

File hashes

Hashes for trifast-0.1.13.dev1.tar.gz
Algorithm Hash digest
SHA256 ba4d747d1dc1dc8807371c36d46b786056f8b53525f3d8aa8b5ac0fb41a505f7
MD5 1dd8810f126a82acc7270612e99c6eb7
BLAKE2b-256 8a82d18acf0e7a89edcd97a48e00aab877b6fdc3d88374329d5d856c151e57b2


File details

Details for the file trifast-0.1.13.dev1-py3-none-any.whl.


File hashes

Hashes for trifast-0.1.13.dev1-py3-none-any.whl
Algorithm Hash digest
SHA256 f7fc9be354ff3c1417be6a640ff92fcac394790f99293f9c2bc19987fc68cb62
MD5 c7ae9c366458b60d76d54d8349d05137
BLAKE2b-256 9276167040cfab59139bc013eae1744444a4e00bbe462789a4316b1cf4b4e704

