Kernels for symmetric-power-based linear transformers
Power Attention
A PyTorch extension implementing symmetric power transformers - a variant of linear transformers that achieves transformer-level performance while scaling linearly with sequence length. This package provides efficient CUDA kernels that make it possible to process much longer sequences compared to standard quadratic attention.
For details on the approach, see our paper: Symmetric Power Transformers
Installation
From PyPI (Recommended)
pip install power-attention
From Source
Requirements:
- Python 3.11 or 3.12 (support for 3.13 depends on the upcoming Triton 3.2 release)
- CUDA Toolkit 12.4
- GCC/G++ with C++17 support
- Linux (Windows/MacOS not supported)
git clone https://github.com/manifest-ai/power-attention.git
cd power-attention
pip install -e .
All other dependencies (PyTorch, Ninja build system, etc.) will be automatically installed through pip.
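To verify the install, the following minimal sketch checks that a CUDA device is visible and that the main entry point imports (it does not run the kernels themselves):

import torch
from power_attention.power_full import power_full

# The kernels require a CUDA-capable GPU at runtime
print("CUDA available:", torch.cuda.is_available())
print("power_full imported:", callable(power_full))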
Usage
The main entry point is the power_full function, which implements symmetric power attention. Here's a basic example:
import torch
from power_attention.power_full import power_full
# Create input tensors
batch_size = 2
seq_len = 1024
num_heads = 8
head_dim = 64
Q = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
# Optional gating tensor
log_G = torch.nn.functional.logsigmoid(
torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32, device='cuda')
)
# Compute attention
output = power_full(
Q=Q, K=K, V=V,
log_G=log_G, # Optional gating tensor
deg=4, # Power parameter p
chunk_size=128, # Size of chunks for processing long sequences
)
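The output keeps the same (batch, seq_len, heads, head_dim) layout as the inputs, so it can be reshaped and projected exactly like a standard attention output. A quick sanity check (assuming the output mirrors the input layout, as the example suggests):

# Expect torch.Size([2, 1024, 8, 64]), matching Q, K, and V
print(output.shape)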
Integration with Transformer Models
The package includes a drop-in replacement for standard attention in transformer models. See training/model.py for a complete example of using power attention in a GPT-style model:
import torch.nn as nn
from power_attention.power_full import power_full
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
# ... initialization code ...
def forward(self, x):
# ... projection code ...
# Use power attention instead of standard attention
y = power_full(
Q=q, K=k, V=v,
log_G=log_g,
deg=self.degree,
chunk_size=self.chunk_size
)
# ... output projection ...
return y
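For reference, here is a fuller, self-contained sketch of such a module. It is illustrative only: the class name, the config fields (n_embd, n_head, degree, chunk_size), and the fused gating projection are assumptions rather than the package's API, so consult training/model.py for the actual implementation.

import torch.nn as nn
import torch.nn.functional as F
from power_attention.power_full import power_full

class PowerCausalSelfAttention(nn.Module):
    # Illustrative attention block built around power_full (hypothetical config fields)
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.degree = config.degree
        self.chunk_size = config.chunk_size
        # Fused projection for Q, K, V plus one gating logit per head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd + config.n_head)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v, g = self.c_attn(x).split([C, C, C, self.n_head], dim=-1)
        # power_full expects (batch, seq, heads, head_dim) layout
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        log_g = F.logsigmoid(g.float())  # optional per-head gating, (B, T, n_head)

        y = power_full(
            Q=q, K=k, V=v,
            log_G=log_g,
            deg=self.degree,
            chunk_size=self.chunk_size,
        )
        return self.c_proj(y.reshape(B, T, C))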
Features
- Efficient chunked algorithm for linear scaling with sequence length (O(t) cost vs O(t²) for standard attention); see the timing sketch after this list
- Support for gated attention and rotary embeddings
- CUDA kernels optimized for A100
- FP16 and BF16 support
- Can be swapped in for standard attention in existing transformer models, e.g. for fine-tuning
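To get a feel for the linear-vs-quadratic scaling claim above, a rough timing sketch like the one below can be run. It is illustrative only and not part of the package: the deg/chunk_size values, the layout transpose for SDPA, and the timing helper are assumptions, and absolute numbers depend on hardware.

import time
import torch
import torch.nn.functional as F
from power_attention.power_full import power_full

def time_fn(fn, *args, iters=10, **kwargs):
    # Warm up once, then average over `iters` runs with CUDA synchronization
    fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

batch, heads, head_dim = 1, 8, 64
for seq_len in (1024, 4096, 16384):
    Q = torch.randn(batch, seq_len, heads, head_dim, device='cuda', dtype=torch.float16)
    K, V = torch.randn_like(Q), torch.randn_like(Q)

    t_power = time_fn(power_full, Q=Q, K=K, V=V, deg=2, chunk_size=128)

    # Standard attention expects (batch, heads, seq, head_dim)
    q, k, v = (x.transpose(1, 2) for x in (Q, K, V))
    t_sdpa = time_fn(F.scaled_dot_product_attention, q, k, v, is_causal=True)

    print(f"seq_len={seq_len}: power_full {t_power * 1e3:.1f} ms, sdpa {t_sdpa * 1e3:.1f} ms")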
Development
Setup
The package uses pip's editable install mode for development. First, activate your Python virtual environment, then:
# Install base package in editable mode
pip install -e .
# Install development dependencies
pip install -e .[dev]
Testing & Benchmarking
Run correctness tests:
pytest perf/tests
Run benchmarks:
python3 -m perf.create_report # will only run on clean commits
python3 -m perf.plot_reports
Documentation
To view the documentation locally, run:
.venv/bin/mkdocs serve -a 0.0.0.0:8000
Training Example
To see the kernel in action right away, cd into the training directory and run:
# Single GPU training
python train.py \
--batch_size=32 \
--attention_kernel=power \
--degree=2 \
--chunk_size=128 \
--disable_gating=False \
--out_dir=out/my_model
# Multi-GPU training with DDP (example with 4 GPUs)
torchrun --standalone --nproc_per_node=4 train.py \
--batch_size=32 \
--attention_kernel=power \
--degree=2 \
--chunk_size=128 \
--disable_gating=False \
--out_dir=out/my_model
Key training parameters:
- attention_kernel: Use 'power' for symmetric power attention (default is 'sdpa' for standard attention)
- degree: Power attention degree (default: 1)
- chunk_size: Size of chunks for processing long sequences (default: None)
- disable_gating: Set to True to disable the gating mechanism (default: False)
- batch_size: Batch size per GPU (default: 12)
- block_size: Sequence length (default: 1024)
- out_dir: Output directory for checkpoints and logs (default: 'out')
- compile: Whether to use PyTorch 2.0 compilation for speed (default: True)
- dtype: Data type for training - 'float32', 'bfloat16', or 'float16' (default: 'bfloat16' if supported, else 'float16')
For distributed training across multiple nodes:
# On the first (master) node with IP 123.456.123.456:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
# On the worker node:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
Note: If your cluster does not have an InfiniBand interconnect, prepend NCCL_IB_DISABLE=1 to the commands.
Contributing
We welcome contributions! Here's how you can help:
Getting Started
- Fork the repository
- Create a new branch for your feature/fix:
git checkout -b feature-name
- Install development dependencies:
pip install -e .[dev]
Guidelines
- Code Style: Follow PEP 8 for Python code. For CUDA code, follow the existing style in the codebase
- Documentation: Add docstrings to new functions and update README if needed
- Testing: Add tests for new features and ensure all tests pass
- Commits: Write clear, concise commit messages
- Performance: For CUDA kernels, include benchmarks showing performance impact
Pull Request Process
- Update documentation for any new features
- Add or update tests as needed
- Ensure all tests pass:
python3 -m pytest perf/tests
- Run benchmarks if performance-critical code was changed:
python3 -m perf.create_report && python3 -m perf.plot_reports
- Create a Pull Request with a clear description of changes
- Wait for review and address any feedback
Areas for Contribution
- Performance optimizations for different GPU architectures
- Documentation improvements
- Bug fixes
- Test coverage improvements
For major changes, please open an issue first to discuss what you would like to change.
Release Process
- Update the version in pyproject.toml
- Run python3 -m pytest tests/ and benchmarks if applicable
- Run make release-test to build & push to Test PyPI for all Python targets
- Run make release to build & push to PyPI for all Python targets
Citation
If you use this code in your research, please cite:
@article{buckman2024symmetric,
title={Symmetric Power Transformers},
author={Buckman, Jacob and Gelada, Carles and Zhang, Sean},
publisher={Manifest AI},
year={2024},
month={8},
url={https://manifestai.com/articles/symmetric-power-transformers/}
}
License
MIT License