
Helpful tools and examples for working with flex-attention


Attention Gym

Attention Gym is a collection of helpful tools and examples for working with flex-attention.


📖 Overview

This repository aims to provide a playground for experimenting with various attention mechanisms using the FlexAttention API. It includes implementations of different attention variants, performance comparisons, and utility functions to help researchers and developers explore and optimize attention mechanisms in their models.


🎯 Features

  • Implementations of various attention mechanisms using FlexAttention
  • Utility functions for creating and combining attention masks
  • Examples of how to use FlexAttention in real-world scenarios
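To illustrate what combining attention masks involves: a FlexAttention mask_mod is a function `(b, h, q_idx, kv_idx) -> bool` that says whether a query position may attend to a key position, so masks compose with elementwise AND/OR. The plain-Python sketch below uses hypothetical helpers `all_of` and `any_of`; it is not attn_gym's utility code. PyTorch itself provides `and_masks` and `or_masks` in `torch.nn.attention.flex_attention` for this purpose.

```python
from typing import Callable

# FlexAttention-style mask_mod: (batch, head, q_idx, kv_idx) -> bool
MaskMod = Callable[[int, int, int, int], bool]


def all_of(*mods: MaskMod) -> MaskMod:
    """Hypothetical helper: allow a (q, kv) pair only if every mask_mod allows it."""
    if not mods:
        raise ValueError("need at least one mask_mod")

    def combined(b, h, q_idx, kv_idx):
        result = mods[0](b, h, q_idx, kv_idx)
        for mod in mods[1:]:
            # `&` rather than `and` so the same code also works elementwise on tensors
            result = result & mod(b, h, q_idx, kv_idx)
        return result

    return combined


def any_of(*mods: MaskMod) -> MaskMod:
    """Hypothetical helper: allow a (q, kv) pair if at least one mask_mod allows it."""
    if not mods:
        raise ValueError("need at least one mask_mod")

    def combined(b, h, q_idx, kv_idx):
        result = mods[0](b, h, q_idx, kv_idx)
        for mod in mods[1:]:
            result = result | mod(b, h, q_idx, kv_idx)
        return result

    return combined


def causal(b, h, q_idx, kv_idx):
    # A query may only see itself and earlier keys
    return q_idx >= kv_idx


def within(distance: int) -> MaskMod:
    # Key must be at most `distance` positions behind the query
    def mask_mod(b, h, q_idx, kv_idx):
        return q_idx - kv_idx <= distance
    return mask_mod


def first_token(b, h, q_idx, kv_idx):
    # Treat position 0 as a global "sink" token every query may see
    return kv_idx == 0


# Causal local window of 2, plus global access to token 0
local_with_sink = any_of(all_of(causal, within(2)), first_token)
```

Because only comparisons, subtraction, `&` and `|` are used, the same functions should also accept the integer tensors FlexAttention passes in, though that is worth confirming with `create_block_mask` on your PyTorch version.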

🚀 Getting Started

Prerequisites

  • PyTorch (version 2.5 or higher)

Installation

git clone https://github.com/pytorch-labs/attention-gym.git
cd attention-gym
pip install .
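The only listed prerequisite is PyTorch 2.5 or higher. A quick post-install check, sketched with the standard library's `importlib.metadata`, might look like the following; the helper names are hypothetical, and the parser assumes a conventional `major.minor[.patch][+local]` version string.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def parse_major_minor(ver: str) -> Tuple[int, int]:
    """Turn a version string like '2.5.1+cu121' into (2, 5)."""
    public = ver.split("+", 1)[0]  # drop the local build tag, e.g. '+cu121'
    major, minor = public.split(".")[:2]
    # Compare as integers so that 2.10 correctly counts as newer than 2.5
    return int(major), int(minor)


def torch_is_supported(minimum: Tuple[int, int] = (2, 5)) -> bool:
    """Return True if an installed torch meets the minimum (major, minor)."""
    try:
        installed = version("torch")
    except PackageNotFoundError:
        return False  # torch is not installed at all
    return parse_major_minor(installed) >= minimum
```

Calling `torch_is_supported()` right after `pip install .` gives a fast signal before running any of the example scripts.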

💻 Usage

There are two main ways to use Attention Gym:

  1. Run Example Scripts: Many files in the project can be executed directly to demonstrate their functionality:

    python attn_gym/masks/document_mask.py
    

    These scripts often generate visualizations to help you understand the attention mechanisms.

  2. Import in Your Projects: You can use Attention Gym components in your own work by importing them:

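    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
    from attn_gym.masks import generate_sliding_window
    
    # Example inputs: batch B, heads H, sequence length S, head dim D
    B, H, S, D = 1, 8, 4096, 64
    device = "cuda"
    query, key, value = (torch.randn(B, H, S, D, device=device) for _ in range(3))
    
    # Build a sliding-window mask_mod, then precompute its BlockMask (B=1, H=1 broadcast)
    sliding_window_mask_mod = generate_sliding_window(window_size=1024)
    block_mask = create_block_mask(sliding_window_mask_mod, 1, 1, S, S, device=device)
    out = flex_attention(query, key, value, block_mask=block_mask)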
    

For comprehensive examples of using FlexAttention in real-world scenarios, explore the examples/ directory. These end-to-end implementations showcase how to integrate various attention mechanisms into your models.
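The `generate_sliding_window` and `create_block_mask` snippet in the usage list above hides two ideas that a plain-Python sketch can make concrete: what a sliding-window mask_mod decides for each (query, key) pair, and why a BlockMask makes that cheap. Assumptions: the window here is causal (a query sees itself and the previous `window_size` keys), which may not match attn_gym's exact definition; `make_sliding_window` and `classify_blocks` are hypothetical names; and in the real API the tile size is `create_block_mask`'s `BLOCK_SIZE` argument (128 by default).

```python
from typing import Callable, Dict

MaskMod = Callable[[int, int, int, int], bool]


def make_sliding_window(window_size: int) -> MaskMod:
    """Assumed causal sliding window: query q may see keys q - window_size .. q."""
    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        in_window = q_idx - kv_idx <= window_size
        return causal & in_window  # `&` keeps it tensor-friendly
    return mask_mod


def classify_blocks(mask_mod: MaskMod, seq_len: int, block_size: int) -> Dict[str, int]:
    """Count (query-block, key-block) tiles that are full, partial, or empty.

    Empty tiles can be skipped entirely; full tiles need no per-element mask check.
    """
    counts = {"full": 0, "partial": 0, "empty": 0}
    n_blocks = (seq_len + block_size - 1) // block_size  # last tile may be ragged
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            allowed = [
                bool(mask_mod(0, 0, q, k))
                for q in range(qb * block_size, min((qb + 1) * block_size, seq_len))
                for k in range(kb * block_size, min((kb + 1) * block_size, seq_len))
            ]
            if all(allowed):
                counts["full"] += 1
            elif any(allowed):
                counts["partial"] += 1
            else:
                counts["empty"] += 1
    return counts
```

In this toy model, empty tiles are skipped outright and full tiles need no per-element check; only partial tiles pay for evaluating the mask_mod. That is the intuition behind building a `block_mask` once with `create_block_mask` and reusing it across `flex_attention` calls.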

Note

Attention Gym is under active development, and we do not currently offer any backward compatibility guarantees. APIs and functionalities may change between versions. We recommend pinning to a specific version in your projects and carefully reviewing changes when upgrading.

📁 Structure

Attention Gym is organized for easy exploration of attention mechanisms:

🔍 Key Locations

  • attn_gym.masks: Examples of creating BlockMasks
  • attn_gym.mods: Examples of creating score_mods
  • attn_gym.paged_attention: Examples using PagedAttention
  • examples/: Detailed implementations using FlexAttention
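The split between attn_gym.masks and attn_gym.mods mirrors FlexAttention's two hooks: a score_mod rewrites each raw attention score, and a mask_mod decides which scores are kept at all. The dense reference below is a readability-first sketch of those semantics for one batch and one head, not how FlexAttention is implemented; `reference_attention` and `relative_bias` are hypothetical names, and the 1/sqrt(head_dim) scale is assumed to match the default.

```python
import math
from typing import Callable, List, Optional

Matrix = List[List[float]]


def relative_bias(score, b, h, q_idx, kv_idx):
    """Example score_mod: bias each score by the signed query-key distance."""
    return score + (q_idx - kv_idx)


def causal(b, h, q_idx, kv_idx):
    """Example mask_mod: each query sees itself and earlier keys only."""
    return q_idx >= kv_idx


def reference_attention(q: Matrix, k: Matrix, v: Matrix,
                        score_mod: Optional[Callable] = None,
                        mask_mod: Optional[Callable] = None) -> Matrix:
    """Dense single-batch, single-head attention with score_mod/mask_mod hooks."""
    scale = 1.0 / math.sqrt(len(q[0]))  # assumed default: 1/sqrt(head_dim)
    out: Matrix = []
    for q_idx, q_row in enumerate(q):
        scores = []
        for kv_idx, k_row in enumerate(k):
            s = sum(x * y for x, y in zip(q_row, k_row)) * scale
            if score_mod is not None:
                s = score_mod(s, 0, 0, q_idx, kv_idx)  # batch 0, head 0
            if mask_mod is not None and not mask_mod(0, 0, q_idx, kv_idx):
                s = -math.inf  # masked scores get zero softmax weight
            scores.append(s)
        m = max(scores)
        if m == -math.inf:
            # Every key masked: this sketch emits zeros instead of NaNs
            out.append([0.0] * len(v[0]))
            continue
        weights = [math.exp(s - m) for s in scores]  # subtract max for stability
        total = sum(weights)
        out.append([sum(w * row[j] for w, row in zip(weights, v)) / total
                    for j in range(len(v[0]))])
    return out
```

For example, with a causal mask_mod the first query can only see the first key, so its output row equals the first row of `v`.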

🛠️ Dev

Install dev requirements

pip install -e ".[dev]"

Install pre-commit hooks

pre-commit install

🤝 Contributing

We welcome contributions to Attention Gym, especially new masks or score mods! Here's how you can contribute:

Contributing Mods

  1. Create a new file in attn_gym/masks/ for mask_mods or in attn_gym/mods/ for score_mods.
  2. Implement your function, and add a simple main function that showcases it.
  3. Update the corresponding attn_gym/*/__init__.py file to export your new function.
  4. [Optional] Add an end-to-end example that uses your new function in the examples/ directory.

See CONTRIBUTING.md for more details.
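Steps 1 to 3 of the contribution list can be sketched as a module skeleton. Everything here is hypothetical: the file name `strided.py`, `generate_strided_mask`, and the text-based `render` helper are illustrative only (the real modules in attn_gym/masks/ may rely on the project's own visualization utilities), and the `__init__.py` line is shown as a comment rather than a verified export list.

```python
"""Hypothetical attn_gym/masks/strided.py: a causal strided mask_mod."""
from typing import Callable

MaskMod = Callable[[int, int, int, int], bool]

# Step 3 (hypothetical): in attn_gym/masks/__init__.py, something like
# from attn_gym.masks.strided import generate_strided_mask


def generate_strided_mask(stride: int) -> MaskMod:
    """Let each query attend to itself and every `stride`-th earlier key."""
    if stride < 1:
        raise ValueError("stride must be >= 1")

    def strided_mask(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        on_stride = (q_idx - kv_idx) % stride == 0
        return causal & on_stride  # `&` keeps it tensor-friendly

    return strided_mask


def render(mask_mod: MaskMod, seq_len: int) -> str:
    """Draw the mask as text: '#' = attend, '.' = masked (rows are queries)."""
    return "\n".join(
        "".join("#" if mask_mod(0, 0, q, k) else "." for k in range(seq_len))
        for q in range(seq_len)
    )


def main(seq_len: int = 8, stride: int = 3) -> None:
    # Step 2: a small demo so the file can be run directly
    print(render(generate_strided_mask(stride), seq_len))


if __name__ == "__main__":
    main()
```

Keeping the mask_mod free of Python control flow on its inputs (only comparisons, `%` and `&`) is what lets the same function be traced by `create_block_mask`.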

⚖️ License

attention-gym is released under the BSD 3-Clause License.



Download files


Source Distribution

attn_gym-0.0.4.tar.gz (45.8 kB)

Uploaded Source

Built Distribution


attn_gym-0.0.4-py3-none-any.whl (32.5 kB)

Uploaded Python 3

File details

Details for the file attn_gym-0.0.4.tar.gz.

File metadata

  • Download URL: attn_gym-0.0.4.tar.gz
  • Upload date:
  • Size: 45.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for attn_gym-0.0.4.tar.gz
  • SHA256: 09015359f33685f3034e70f64c4bf1424d24f166c36382330a32f3567ddb06d9
  • MD5: 214ae182ee86fbf4eacd096c585c2804
  • BLAKE2b-256: d49de97776799ff06b7c1f6a672d0b12fafa33f2b4abb1a1dd5fd4958af4bc92


Provenance

The following attestation bundles were made for attn_gym-0.0.4.tar.gz:

Publisher: publish-to-pypi.yml on pytorch-labs/attention-gym

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file attn_gym-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: attn_gym-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 32.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for attn_gym-0.0.4-py3-none-any.whl
  • SHA256: 92edf94ee4baf8f8e75ebb04d12c82157717aa8c21e548959ef4b562951ce417
  • MD5: 0a2c49540cb49033c2c57d2013524ab7
  • BLAKE2b-256: 03e01062b559df763fa3da4bfe30e88582e0613ab291bcff983aa51ba20881b6


Provenance

The following attestation bundles were made for attn_gym-0.0.4-py3-none-any.whl:

Publisher: publish-to-pypi.yml on pytorch-labs/attention-gym

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
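The SHA256 digests listed above can be checked locally before installing a downloaded file. A minimal sketch using only `hashlib` and `pathlib`; the helper names are hypothetical, and the example paths assume the downloads sit in the current directory.

```python
import hashlib
from pathlib import Path
from typing import Union

# SHA256 digests published for the 0.0.4 release files (copied from the hashes listed above)
EXPECTED_SHA256 = {
    "attn_gym-0.0.4.tar.gz": "09015359f33685f3034e70f64c4bf1424d24f166c36382330a32f3567ddb06d9",
    "attn_gym-0.0.4-py3-none-any.whl": "92edf94ee4baf8f8e75ebb04d12c82157717aa8c21e548959ef4b562951ce417",
}


def sha256_of(path: Union[str, Path], chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large downloads never need to fit in memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Union[str, Path], expected_hex: str) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, `verify("attn_gym-0.0.4.tar.gz", EXPECTED_SHA256["attn_gym-0.0.4.tar.gz"])` should return True for an intact download. pip can enforce the same check at install time through its hash-checking mode (`--require-hashes`).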
