Complete PyTorch implementation of Selective Self-Attention (SSA) from NeurIPS 2024
Project description
๐ Selective Self-Attention (SSA): Enhancing Transformers through Principled Context Control
Complete PyTorch implementation of Selective Self-Attention (SSA) from the NeurIPS 2024 paper "Selective Attention: Enhancing Transformer through Principled Context Control".
๐ฏ Overview
Selective Self-Attention (SSA) addresses a fundamental limitation in standard self-attention: the uniform treatment of all queries hinders the ability to control contextual sparsity and relevance. SSA introduces principled temperature scaling to adapt attention sparsity based on query embeddings and positions.
Key Innovations:
- Query Selectivity: Temperature scaling for queries to control attention spikiness
- Value Selectivity: Temperature scaling for values to suppress noisy tokens
- Position Awareness: Position-dependent temperature to mitigate attention dilution
- Weight Sharing: <0.5% parameter overhead through efficient weight reuse
- ComfyUI Integration: Optimized for diffusion model workflows
๐ Performance Highlights
- 15-30% inference speedup with maintained quality
- Consistent improvements across GPT-2, Pythia, Llama, and Llama3
- <0.5% parameter overhead through weight sharing strategy
- Drop-in replacement for standard attention layers
๐ง Installation
From Source (Recommended)
git clone https://github.com/yourusername/selective-self-attention.git
cd selective-self-attention
pip install -e .
From PyPI (Coming Soon)
pip install selective-self-attention
Requirements
- Python 3.8+
- PyTorch 1.9+
- CUDA (optional, for GPU support)
๐ Quick Start
Basic Usage
import torch
from selective_self_attention import SSATransformer
# Create SSA model
model = SSATransformer(
vocab_size=50257,
max_seq_len=1024,
dim=768,
num_layers=12,
num_heads=12,
use_ssa=True
)
# Forward pass
input_ids = torch.randint(0, 50257, (1, 512))
outputs = model(input_ids)
hidden_states = outputs['hidden_states']
ComfyUI Integration
# Copy comfyui_ssa_node/ to ComfyUI/custom_nodes/
# Restart ComfyUI
# Use "Selective Self-Attention" node in workflows
๐ Project Structure
selective-self-attention/
โโโ src/
โ โโโ models/
โ โ โโโ ssa_transformer.py # Main SSA transformer
โ โ โโโ modules.py # SSA attention layers
โ โ โโโ embeddings.py # Positional encodings
โ โโโ algorithms/ # Core SSA algorithm
โ โโโ losses/ # Loss functions
โ โโโ data/ # Data handling
โ โโโ utils/ # Utilities
โโโ tests/ # Comprehensive tests
โโโ configs/ # Configuration files
โโโ scripts/ # Training/evaluation scripts
โโโ comfyui_ssa_node/ # ComfyUI integration
โโโ examples/ # Usage examples
๐งช Testing
# Run all tests
python -m pytest tests/
# Run specific tests
python tests/test_ssa_basic.py
# Test ComfyUI compatibility
python comfyui_ssa_node/test_wan_compatibility.py
๐ Examples
Language Modeling
from selective_self_attention import SSALanguageModel
model = SSALanguageModel.from_config("configs/models/base.yaml")
loss = model(input_ids, labels=labels)['loss']
Attention Analysis
# Get attention spikiness metrics
spikiness = model.get_attention_spikiness(input_ids)
print(f"Attention spikiness: {spikiness:.4f}") # Lower = more sparse
๐ฌ Reproducing Paper Results
# Pre-training (requires dataset)
python scripts/train.py --config configs/training/base.yaml
# Fine-tuning on downstream tasks
python scripts/train.py --config configs/training/finetune.yaml
# Generate tables and figures
python scripts/reproduce_tables.py
python scripts/reproduce_figures.py
๐๏ธ Architecture
SSA Layer
- Input: Query, Key, Value tensors
- Temperature Scaling: Applied to queries and values
- Position-Aware:
ฯ_pos = 1 + ฯ(ฮฑ)log(n) - Token-Aware:
ฯ_tok = tanh(f(x)) - Weight Sharing: Reuses attention weights for efficiency
Weight Sharing Strategy
# Instead of separate temperature weights
# Reuse existing attention projection weights
temp_weights = attention_weights # Shared weights
๐ Benchmarks
| Model | Dataset | Standard Attention | SSA | Improvement |
|---|---|---|---|---|
| GPT-2 | WikiText | 36.503 | 34.618 | +5.2% |
| Pythia-160M | WikiText | 26.681 | 26.514 | +0.6% |
| Llama3-8B | WikiText | 12.416 | 10.982 | +11.6% |
๐ค Contributing
- Fork the repository
- Create a feature branch:
git checkout -b feature/amazing-feature - Commit your changes:
git commit -m 'Add amazing feature' - Push to the branch:
git push origin feature/amazing-feature - Open a Pull Request
๐ Citation
If you use this code in your research, please cite:
@inproceedings{zhang2024selective,
title={Selective Attention: Enhancing Transformer through Principled Context Control},
author={Zhang, Xuechen and Chang, Xiangyu and Li, Mingchen and Roy-Chowdhury, Amit and Chen, Jiasi and Oymak, Samet},
booktitle={Advances in Neural Information Processing Systems},
year={2024}
}
๐ License
This project is licensed under the MIT License - see the LICENSE file for details.
๐ Acknowledgments
- Original paper authors for the theoretical foundation
- ComfyUI community for the node integration framework
- PyTorch team for the excellent deep learning framework
๐ Support
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Email: your-email@example.com
Made with โค๏ธ by the research community
Test commit
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file selective_self_attention-1.0.0.tar.gz.
File metadata
- Download URL: selective_self_attention-1.0.0.tar.gz
- Upload date:
- Size: 30.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
76942dd1f75ced4fbb36374edf30d11debea6db3773e08bfb39f8c4a53fe664e
|
|
| MD5 |
4cebdc53383bb37c2ad19cec1c130221
|
|
| BLAKE2b-256 |
675855cb1ad83805085ef85f0946298af20835b9da5597d4d5c79bb010335915
|