Power Attention

Kernels for symmetric-power-based linear transformers

A PyTorch extension implementing symmetric power transformers - a variant of linear transformers that achieves transformer-level performance while scaling linearly with sequence length. This package provides efficient CUDA kernels that make it possible to process much longer sequences compared to standard quadratic attention.

For details on the approach, see our paper: Symmetric Power Transformers

Installation

From PyPI (Recommended)

pip install power-attention

From Source

Requirements:

  • Python 3.11 or 3.12 (3.13 depends on the upcoming Triton 3.2 release)
  • CUDA Toolkit 12.4
  • GCC/G++ with C++17 support
  • Linux (Windows/macOS not supported)

git clone https://github.com/manifest-ai/power-attention.git
cd power-attention
pip install -e .

All other dependencies (PyTorch, Ninja build system, etc.) will be automatically installed through pip.
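To quickly verify the install, the following minimal check should run without errors (a sanity-check sketch; the CUDA kernels themselves require an NVIDIA GPU at runtime):

import torch
from power_attention.power_full import power_full  # import succeeds if the build is intact
print("CUDA available:", torch.cuda.is_available())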

Usage

The main entry point is the power_full function, which implements symmetric power attention. Here's a basic example:

import torch
from power_attention.power_full import power_full

# Create input tensors
batch_size = 2
seq_len = 1024
num_heads = 8
head_dim = 64

Q = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Optional gating tensor
log_G = torch.nn.functional.logsigmoid(
    torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32, device='cuda')
)

# Compute attention
output = power_full(
    Q=Q, K=K, V=V, 
    log_G=log_G,          # Optional gating tensor
    deg=4,                # Power parameter p
    chunk_size=128,       # Size of chunks for processing long sequences
)
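The returned tensor keeps the (batch, seq_len, heads, head_dim) layout of the inputs. A quick shape check, assuming (as in the example above) that power_full preserves the input layout:

assert output.shape == (batch_size, seq_len, num_heads, head_dim)  # (2, 1024, 8, 64)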

Integration with Transformer Models

The package includes a drop-in replacement for standard attention in transformer models. See training/model.py for a complete example of using power attention in a GPT-style model:

import torch.nn as nn
from power_attention.power_full import power_full

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... initialization code ...
        
    def forward(self, x):
        # ... projection code ...
        
        # Use power attention instead of standard attention
        y = power_full(
            Q=q, K=k, V=v, 
            log_G=log_g,
            deg=self.degree,
            chunk_size=self.chunk_size
        )
        
        # ... output projection ...
        return y
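For reference, here is a minimal self-contained sketch of such a module. This is illustrative only, not the code from training/model.py: the fused QKV projection, the per-head gate projection (gate_proj), and the constructor arguments are assumptions made for the example.

import torch.nn as nn
import torch.nn.functional as F
from power_attention.power_full import power_full

class PowerSelfAttention(nn.Module):
    """Illustrative sketch of a causal self-attention block built on power_full."""
    def __init__(self, embed_dim, num_heads, degree=2, chunk_size=128):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.degree = degree
        self.chunk_size = chunk_size
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.gate_proj = nn.Linear(embed_dim, num_heads)  # hypothetical per-head gate
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):  # x: (batch, seq, embed_dim)
        B, T, C = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # power_full expects (batch, seq, heads, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim)
        k = k.view(B, T, self.num_heads, self.head_dim)
        v = v.view(B, T, self.num_heads, self.head_dim)
        # Gating values in log space, float32, one per (position, head), as in the usage example
        log_g = F.logsigmoid(self.gate_proj(x).float())
        y = power_full(Q=q, K=k, V=v, log_G=log_g,
                       deg=self.degree, chunk_size=self.chunk_size)
        return self.out_proj(y.reshape(B, T, C))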

Features

  • Efficient chunked algorithm for linear scaling with sequence length (O(t) cost vs O(t²) for standard attention)
  • Support for gated attention and rotary embeddings
  • CUDA kernels optimized for A100
  • FP16 and BF16 support
  • Drop-in replacement for standard attention in transformer models, enabling fine-tuning of existing models
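To get a rough sense of the linear-vs-quadratic scaling, one can time power_full against PyTorch's scaled_dot_product_attention at growing sequence lengths. A minimal sketch, assuming log_G may be omitted as the usage example's comment suggests; note that SDPA expects a (batch, heads, seq, dim) layout while power_full takes (batch, seq, heads, dim), and absolute timings will depend on hardware:

import time
import torch
import torch.nn.functional as F
from power_attention.power_full import power_full

def bench(fn, iters=10):
    fn()  # warmup
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

for seq_len in (1024, 4096, 16384):
    q = torch.randn(1, seq_len, 8, 64, device='cuda', dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    t_power = bench(lambda: power_full(Q=q, K=k, V=v, deg=2, chunk_size=128))
    qs, ks, vs = (t.transpose(1, 2) for t in (q, k, v))  # SDPA layout
    t_sdpa = bench(lambda: F.scaled_dot_product_attention(qs, ks, vs, is_causal=True))
    print(f"T={seq_len}: power={t_power * 1e3:.2f} ms, sdpa={t_sdpa * 1e3:.2f} ms")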

Development

Setup

The package uses pip's editable install mode for development. First, activate your Python virtual environment, then:

# Install base package in editable mode
pip install -e .

# Install development dependencies
pip install -e .[dev]

Testing & Benchmarking

Run correctness tests:

pytest perf/tests

Run benchmarks:

python3 -m perf.create_report # will only run on clean commits
python3 -m perf.plot_reports

Documentation

To view the documentation locally, run:

.venv/bin/mkdocs serve -a 0.0.0.0:8000

Training Example

To see the kernels in action right away, cd into the training directory and run:

# Single GPU training
python train.py \
  --batch_size=32 \
  --attention_kernel=power \
  --degree=2 \
  --chunk_size=128 \
  --disable_gating=False \
  --out_dir=out/my_model

# Multi-GPU training with DDP (example with 4 GPUs)
torchrun --standalone --nproc_per_node=4 train.py \
  --batch_size=32 \
  --attention_kernel=power \
  --degree=2 \
  --chunk_size=128 \
  --disable_gating=False \
  --out_dir=out/my_model

Key training parameters:

  • attention_kernel: Use 'power' for symmetric power attention (default is 'sdpa' for standard attention)
  • degree: Power attention degree (default: 1)
  • chunk_size: Size of chunks for processing long sequences (default: None)
  • disable_gating: Set to true to disable gating mechanism (default: False)
  • batch_size: Batch size per GPU (default: 12)
  • block_size: Sequence length (default: 1024)
  • out_dir: Output directory for checkpoints and logs (default: 'out')
  • compile: Whether to use PyTorch 2.0 compilation for speed (default: True)
  • dtype: Data type for training - 'float32', 'bfloat16', or 'float16' (default: 'bfloat16' if supported, else 'float16')

For distributed training across multiple nodes:

# On the first (master) node with IP 123.456.123.456:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py

# On the worker node:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py

Note: If your cluster does not have Infiniband interconnect, prepend NCCL_IB_DISABLE=1 to the commands.

Contributing

We welcome contributions! Here's how you can help:

Getting Started

  1. Fork the repository
  2. Create a new branch for your feature/fix: git checkout -b feature-name
  3. Install development dependencies: pip install -e .[dev]

Guidelines

  • Code Style: Follow PEP 8 for Python code. For CUDA code, follow the existing style in the codebase
  • Documentation: Add docstrings to new functions and update README if needed
  • Testing: Add tests for new features and ensure all tests pass
  • Commits: Write clear, concise commit messages
  • Performance: For CUDA kernels, include benchmarks showing performance impact

Pull Request Process

  1. Update documentation for any new features
  2. Add or update tests as needed
  3. Ensure all tests pass: python3 -m pytest perf/tests
  4. Run benchmarks if performance-critical code was changed: python3 -m perf.create_report && python3 -m perf.plot_reports
  5. Create a Pull Request with a clear description of changes
  6. Wait for review and address any feedback

Areas for Contribution

  • Performance optimizations for different GPU architectures
  • Documentation improvements
  • Bug fixes
  • Test coverage improvements

For major changes, please open an issue first to discuss what you would like to change.

Release Process

  1. Update the version in pyproject.toml
  2. Run python3 -m pytest perf/tests and benchmarks if applicable
  3. Run make release-test to build & push to Test PyPI for all Python targets
  4. Run make release to build & push to PyPI for all Python targets

Citation

If you use this code in your research, please cite:

@article{buckman2024symmetric,
  title={Symmetric Power Transformers},
  author={Buckman, Jacob and Gelada, Carles and Zhang, Sean},
  publisher={Manifest AI},
  year={2024},
  month={8},
  url={https://manifestai.com/articles/symmetric-power-transformers/}
}

License

MIT License
