Skip to main content

Complete PyTorch implementation of Selective Self-Attention (SSA) from NeurIPS 2024

Project description

๐Ÿš€ Selective Self-Attention (SSA): Enhancing Transformers through Principled Context Control

Python PyTorch License: MIT Paper

Complete PyTorch implementation of Selective Self-Attention (SSA) from the NeurIPS 2024 paper "Selective Attention: Enhancing Transformer through Principled Context Control".

๐ŸŽฏ Overview

Selective Self-Attention (SSA) addresses a fundamental limitation in standard self-attention: the uniform treatment of all queries hinders the ability to control contextual sparsity and relevance. SSA introduces principled temperature scaling to adapt attention sparsity based on query embeddings and positions.

Key Innovations:

  • Query Selectivity: Temperature scaling for queries to control attention spikiness
  • Value Selectivity: Temperature scaling for values to suppress noisy tokens
  • Position Awareness: Position-dependent temperature to mitigate attention dilution
  • Weight Sharing: <0.5% parameter overhead through efficient weight reuse
  • ComfyUI Integration: Optimized for diffusion model workflows

๐Ÿ“Š Performance Highlights

  • 15-30% inference speedup with maintained quality
  • Consistent improvements across GPT-2, Pythia, Llama, and Llama3
  • <0.5% parameter overhead through weight sharing strategy
  • Drop-in replacement for standard attention layers

๐Ÿ”ง Installation

From Source (Recommended)

git clone https://github.com/yourusername/selective-self-attention.git
cd selective-self-attention
pip install -e .

From PyPI (Coming Soon)

pip install selective-self-attention

Requirements

  • Python 3.8+
  • PyTorch 1.9+
  • CUDA (optional, for GPU support)

๐Ÿš€ Quick Start

Basic Usage

import torch
from selective_self_attention import SSATransformer

# Create SSA model
model = SSATransformer(
    vocab_size=50257,
    max_seq_len=1024,
    dim=768,
    num_layers=12,
    num_heads=12,
    use_ssa=True
)

# Forward pass
input_ids = torch.randint(0, 50257, (1, 512))
outputs = model(input_ids)
hidden_states = outputs['hidden_states']

ComfyUI Integration

# Copy comfyui_ssa_node/ to ComfyUI/custom_nodes/
# Restart ComfyUI
# Use "Selective Self-Attention" node in workflows

๐Ÿ“ Project Structure

selective-self-attention/
โ”œโ”€โ”€ src/
โ”‚   โ”œโ”€โ”€ models/
โ”‚   โ”‚   โ”œโ”€โ”€ ssa_transformer.py    # Main SSA transformer
โ”‚   โ”‚   โ”œโ”€โ”€ modules.py            # SSA attention layers
โ”‚   โ”‚   โ””โ”€โ”€ embeddings.py         # Positional encodings
โ”‚   โ”œโ”€โ”€ algorithms/               # Core SSA algorithm
โ”‚   โ”œโ”€โ”€ losses/                   # Loss functions
โ”‚   โ”œโ”€โ”€ data/                     # Data handling
โ”‚   โ””โ”€โ”€ utils/                    # Utilities
โ”œโ”€โ”€ tests/                        # Comprehensive tests
โ”œโ”€โ”€ configs/                      # Configuration files
โ”œโ”€โ”€ scripts/                      # Training/evaluation scripts
โ”œโ”€โ”€ comfyui_ssa_node/             # ComfyUI integration
โ””โ”€โ”€ examples/                     # Usage examples

๐Ÿงช Testing

# Run all tests
python -m pytest tests/

# Run specific tests
python tests/test_ssa_basic.py

# Test ComfyUI compatibility
python comfyui_ssa_node/test_wan_compatibility.py

๐Ÿ“š Examples

Language Modeling

from selective_self_attention import SSALanguageModel

model = SSALanguageModel.from_config("configs/models/base.yaml")
loss = model(input_ids, labels=labels)['loss']

Attention Analysis

# Get attention spikiness metrics
spikiness = model.get_attention_spikiness(input_ids)
print(f"Attention spikiness: {spikiness:.4f}")  # Lower = more sparse

๐Ÿ”ฌ Reproducing Paper Results

# Pre-training (requires dataset)
python scripts/train.py --config configs/training/base.yaml

# Fine-tuning on downstream tasks
python scripts/train.py --config configs/training/finetune.yaml

# Generate tables and figures
python scripts/reproduce_tables.py
python scripts/reproduce_figures.py

๐Ÿ—๏ธ Architecture

SSA Layer

  • Input: Query, Key, Value tensors
  • Temperature Scaling: Applied to queries and values
  • Position-Aware: ฯ„_pos = 1 + ฯƒ(ฮฑ)log(n)
  • Token-Aware: ฯ„_tok = tanh(f(x))
  • Weight Sharing: Reuses attention weights for efficiency

Weight Sharing Strategy

# Instead of separate temperature weights
# Reuse existing attention projection weights
temp_weights = attention_weights  # Shared weights

๐Ÿ“ˆ Benchmarks

Model Dataset Standard Attention SSA Improvement
GPT-2 WikiText 36.503 34.618 +5.2%
Pythia-160M WikiText 26.681 26.514 +0.6%
Llama3-8B WikiText 12.416 10.982 +11.6%

๐Ÿค Contributing

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature/amazing-feature
  3. Commit your changes: git commit -m 'Add amazing feature'
  4. Push to the branch: git push origin feature/amazing-feature
  5. Open a Pull Request

๐Ÿ“„ Citation

If you use this code in your research, please cite:

@inproceedings{zhang2024selective,
  title={Selective Attention: Enhancing Transformer through Principled Context Control},
  author={Zhang, Xuechen and Chang, Xiangyu and Li, Mingchen and Roy-Chowdhury, Amit and Chen, Jiasi and Oymak, Samet},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}

๐Ÿ“ License

This project is licensed under the MIT License - see the LICENSE file for details.

๐Ÿ™ Acknowledgments

  • Original paper authors for the theoretical foundation
  • ComfyUI community for the node integration framework
  • PyTorch team for the excellent deep learning framework

๐Ÿ“ž Support


Made with โค๏ธ by the research community

Test commit

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

selective_self_attention-1.0.0.tar.gz (30.7 kB view details)

Uploaded Source

File details

Details for the file selective_self_attention-1.0.0.tar.gz.

File metadata

  • Download URL: selective_self_attention-1.0.0.tar.gz
  • Upload date:
  • Size: 30.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.8

File hashes

Hashes for selective_self_attention-1.0.0.tar.gz
Algorithm Hash digest
SHA256 76942dd1f75ced4fbb36374edf30d11debea6db3773e08bfb39f8c4a53fe664e
MD5 4cebdc53383bb37c2ad19cec1c130221
BLAKE2b-256 675855cb1ad83805085ef85f0946298af20835b9da5597d4d5c79bb010335915

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page