Retention

Retention-based next-generation kernels for long-context models

This repository contains a PyTorch layer implementing power retention, a linear-cost variant of attention whose state size can be controlled independently of context length and parameter count.

For details on the approach, see our paper: Scaling Context Requires Rethinking Attention

Documentation: https://m-a-n-i-f-e-s-t.github.io/retention/

Features

  • Efficient chunked algorithm for linear scaling with sequence length (O(t) cost vs O(t²) for standard attention; see the sketch after this list)
  • Support for gated attention and rotary embeddings
  • CUDA kernels optimized for A100
  • FP16 and BF16 support
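To build intuition for the linear cost and the fixed-size state, here is an illustrative recurrent view of gated power retention. It is a hedged sketch, not the package's chunked CUDA algorithm: the feature map below is the full tensor power rather than the compact symmetric embedding the kernels use, and the normalization is an assumption.

import torch

def phi(x, p):
    # Degree-p polynomial feature map (full tensor power, flattened).
    # For even p, (q . k)**p == phi(q) @ phi(k).
    out = x
    for _ in range(p - 1):
        out = torch.einsum('i,j->ij', out, x).reshape(-1)
    return out

def recurrent_power_retention(q, k, v, log_g, p=2):
    # q, k, v: (seq, d); log_g: (seq,) with entries <= 0 (log-gates)
    d_phi = phi(k[0], p).numel()
    S = torch.zeros(d_phi, v.shape[-1])  # running state: size independent of seq len
    z = torch.zeros(d_phi)               # running normalizer (assumed form)
    ys = []
    for t in range(q.shape[0]):
        g = log_g[t].exp()               # gate in (0, 1)
        S = g * S + torch.outer(phi(k[t], p), v[t])
        z = g * z + phi(k[t], p)
        num = phi(q[t], p) @ S
        den = (phi(q[t], p) @ z).clamp_min(1e-6)
        ys.append(num / den)
    return torch.stack(ys)

For deg=2 and head dimension d, the flattened state here has d² rows per head; the symmetric embedding used by the actual kernels needs only about d(d+1)/2 rows, and that fixed size is what decouples memory from context length.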

Installation

From PyPI (Recommended)

pip install retention

From Source

Requirements:

  • Python 3.11 or 3.12 (support for 3.13 awaits the upcoming Triton 3.2 release)
  • CUDA Toolkit 12.4
  • GCC/G++ with C++17 support
  • Linux (Windows and macOS are not supported)

git clone https://github.com/manifest-ai/retention.git
cd retention
pip install -e .

All other dependencies (PyTorch, the Ninja build system, etc.) are installed automatically by pip.
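As a quick smoke test of the build, the following should run without errors (it mirrors the Usage example below and assumes a CUDA-capable GPU):

import torch
from retention import power_retention

Q = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
out = power_retention(Q=Q, K=torch.randn_like(Q), V=torch.randn_like(Q), deg=2, chunk_size=128)
print(out.shape)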

Usage

The main entry point is the power_retention function, which implements symmetric power retention. Here's a basic example:

import torch
from retention import power_retention

# Create input tensors
batch_size = 2
seq_len = 1024
num_heads = 8
head_dim = 64

Q = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Optional gating tensor
log_G = torch.nn.functional.logsigmoid(
    torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32, device='cuda')
)

# Compute retention results
output = power_retention(
    Q=Q, K=K, V=V, 
    log_G=log_G,          # Optional gating tensor
    deg=2,                # Power parameter p
    chunk_size=128,       # Size of chunks for processing long sequences
)
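The output keeps the same (batch, seq, heads, head_dim) layout as the inputs (an assumption consistent with the drop-in usage below):

assert output.shape == (batch_size, seq_len, num_heads, head_dim)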

Integration with Transformer Models

The package includes a drop-in replacement for standard attention in transformer models. See train/model.py for a complete example of using power retention in a GPT-style model:

import torch.nn as nn
from retention import power_retention

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... initialization code ...
        
    def forward(self, x):
        # ... projection code ...
        
        # Use power retention instead of standard attention
        y = power_retention(
            Q=q, K=k, V=v, 
            log_G=log_g,
            deg=self.degree,
            chunk_size=self.chunk_size
        )
        
        # ... output projection ...
        return y
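For reference, here is a minimal self-contained version of the same pattern. The config fields and layer names (n_embd, n_head, gate, and so on) are hypothetical stand-ins, not the ones in train/model.py, and dtype/device handling is omitted:

import torch
import torch.nn as nn
import torch.nn.functional as F
from retention import power_retention

class MinimalPowerAttention(nn.Module):
    def __init__(self, n_embd=512, n_head=8, degree=2, chunk_size=128):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.degree = degree
        self.chunk_size = chunk_size
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.gate = nn.Linear(n_embd, n_head)  # per-head gating logits
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        # power_retention expects (batch, seq, heads, head_dim) tensors
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        log_g = F.logsigmoid(self.gate(x).float())  # (B, T, n_head)
        y = power_retention(Q=q, K=k, V=v, log_G=log_g,
                            deg=self.degree, chunk_size=self.chunk_size)
        return self.proj(y.reshape(B, T, C))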

Development

Setup

The package uses pip's editable install mode for development. First, activate your Python virtual environment, then:

# Install base package in editable mode
pip install -e .

# Install development dependencies
pip install psutil
pip install flash_attn==2.7.3 --no-build-isolation
pip install -e .[dev]

Testing & Benchmarking

Run correctness tests:

pytest

Run benchmarks:

python -m perf.benchmark fwd          # forward pass
python -m perf.benchmark bwd          # backward pass
python -m perf.benchmark fwd+bwd      # forward + backward pass

See the perf.benchmark module for details.

Documentation

To view the documentation locally, run:

pip install mkdocs mkdocs-material
mkdocs serve -a 0.0.0.0:8000

To publish the documentation to GitHub Pages, run:

mkdocs gh-deploy

Training Example

To see the kernel in action right away, cd into the train directory and run:

# Create the dataset first
python prepare_owt.py

# Single GPU training
python train.py \
  --batch_size=32 \
  --attention_kernel=power \
  --degree=2 \
  --chunk_size=128 \
  --disable_gating=False

# Multi-GPU training with DDP (example with 4 GPUs)
torchrun --standalone --nproc_per_node=4 train.py \
  --batch_size=32 \
  --attention_kernel=power \
  --degree=2 \
  --chunk_size=128 \
  --disable_gating=False

For distributed training across multiple nodes:

# On the first (master) node, assuming its address is 192.0.2.1:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=192.0.2.1 --master_port=1234 train.py

# On the worker node:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=192.0.2.1 --master_port=1234 train.py

Note: If your cluster does not have an InfiniBand interconnect, prepend NCCL_IB_DISABLE=1 to the commands.

Contributing

We welcome contributions! Here's how you can help:

Getting Started

  1. Fork the repository
  2. Create a new branch for your feature/fix: git checkout -b feature-name
  3. Install development dependencies: pip install -e .[dev]

Guidelines

  • Code Style: Follow PEP 8 for Python code. For CUDA code, follow the existing style in the codebase
  • Documentation: Add docstrings to new functions and update README if needed
  • Testing: Add tests for new features and ensure all tests pass
  • Benchmarking: If your changes affect performance, delete plots/benchmark_results and rerun the benchmarks with python -m perf.benchmark fwd+bwd
  • Commits: Write clear, concise commit messages
  • Performance: For CUDA kernels, include benchmarks showing performance impact

Pull Request Process

  1. Update documentation for any new features
  2. Add or update tests as needed
  3. Ensure all tests pass: pytest
  4. Run benchmarks if performance-critical code was changed: python -m perf.benchmark fwd+bwd
  5. Create a Pull Request with a clear description of changes
  6. Wait for review and address any feedback

Areas for Contribution

  • Performance optimizations for different GPU architectures
  • Documentation improvements
  • Bug fixes
  • Test coverage improvements

For major changes, please open an issue first to discuss what you would like to change.

Release Process

  1. Update the version in pyproject.toml
  2. Run pytest and benchmarks if applicable
  3. Run make release-test to build & push to Test PyPI for all Python targets
  4. Run make release to build & push to PyPI for all Python targets

Citation

If you use this code in your research, please cite:

@article{buckman2024symmetric,
  title={Symmetric Power Transformers},
  author={Buckman, Jacob and Gelada, Carles and Zhang, Sean},
  publisher={Manifest AI},
  year={2024},
  month={8},
  url={https://manifestai.com/articles/symmetric-power-transformers/}
}

License

Apache 2.0 (see LICENSE)
