Universal Set Transformers for scalable set/multiset modeling
Universal Set Transformers
Universal Set Transformers (UST) is a PyTorch library for processing sets and multisets with transformer-based architectures. It provides efficient implementations of set attention mechanisms whose results do not depend on the order of the set elements being attended over, and that handle sets of varying sizes.
Features
- Set and Multiset Processing: Process sets and multisets with explicit multiplicity handling
- Minibatch Consistency: Process large sets in minibatches and obtain the same result as processing the whole set at once
- Efficient Attention Mechanisms: Multiple attention implementations (softmax, sigmoid, flash attention)
- State Management: Explicit state handling for incremental processing
- Numerical Stability: Robust handling of large values and edge cases
Installation
```bash
pip install universal-set-transformers
```
Or install from source:
```bash
git clone https://github.com/yourusername/universal-set-transformers.git
cd universal-set-transformers
pip install -e .
```
Quick Start
Basic Usage
```python
import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64,
)

# Process a set
x = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = self_attn(x)

# Process a multiset with multiplicities (one per element, shape (batch_size, set_size))
multiplicities = torch.randint(1, 5, (2, 10)).float()
output_with_mult = self_attn(x, multiplicities)
```
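Conceptually, a multiplicity m tells attention to treat an element as if it occurred m times. The plain-Python sketch below illustrates that idea for a single query with no learned projections or heads, assuming the library follows this standard weighting; `multiset_attention` is a hypothetical helper, not part of `ust`. Scaling each element's softmax weight by its multiplicity gives the same result as physically repeating the element.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def multiset_attention(query, keys, values, multiplicities=None):
    """Scaled dot-product softmax attention of one query over a multiset.

    Hypothetical helper: each element's weight exp(score) is multiplied by its
    multiplicity, which is equivalent to repeating that element m times.
    """
    if multiplicities is None:
        multiplicities = [1.0] * len(keys)
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, k) * scale for k in keys]
    shift = max(scores)  # subtract the max score for numerical stability
    weights = [m * math.exp(s - shift) for m, s in zip(multiplicities, scores)]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


query = [0.5, -0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mult = [2, 1, 3]

# Compact form: each distinct element once, plus its multiplicity.
compact = multiset_attention(query, keys, values, mult)

# Expanded form: every element physically repeated m times.
expanded_keys = [k for k, m in zip(keys, mult) for _ in range(m)]
expanded_values = [v for v, m in zip(values, mult) for _ in range(m)]
expanded = multiset_attention(query, expanded_keys, expanded_values)
```

Keeping one copy per distinct element means the cost grows with the number of distinct elements rather than with the total count.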
Cross-Set Attention
```python
import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64,
)

# Process two sets: the elements of x attend over the elements of y
x = torch.randn(2, 5, 32)   # batch_size=2, set_size=5, d_model=32
y = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = cross_attn(x, y)

# Process with multiplicities for the elements of y
y_multiplicities = torch.randint(1, 5, (2, 10)).float()
output_with_mult = cross_attn(x, y, y_multiplicities)
```
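In cross-set attention every element of the query set `x` attends over the set `y`, so the output has one row per element of `x`, and reordering `y` (together with its multiplicities) should not change the result. The plain-Python sketch below checks both properties for a toy single-head version that uses `y` directly as keys and values, with no learned projections or feed-forward layer; `attend` and `cross_set_attention` are hypothetical helpers, not the `ust` API.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend(query, keys, values, mults):
    # Multiplicity-weighted softmax attention for a single query.
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, k) * scale for k in keys]
    shift = max(scores)
    weights = [m * math.exp(s - shift) for m, s in zip(mults, scores)]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def cross_set_attention(queries, keys, values, mults):
    # One output row per query: the result has the size of the query set x,
    # not of the attended set y.
    return [attend(q, keys, values, mults) for q in queries]


random.seed(0)
x = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]   # 5 queries
y = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]  # 10 set elements
m = [random.randint(1, 4) for _ in range(10)]

out = cross_set_attention(x, y, y, m)

# Shuffle the attended set, keeping each element paired with its multiplicity.
perm = list(range(10))
random.shuffle(perm)
y_perm = [y[i] for i in perm]
m_perm = [m[i] for i in perm]
out_perm = cross_set_attention(x, y_perm, y_perm, m_perm)
```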
Minibatch Processing
```python
import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64,
)

# Query set
x = torch.randn(2, 5, 32)  # batch_size=2, set_size=5, d_model=32

# Initialize state
state = cross_attn.attention.scaled_dot_product_attention.initial_state()

# Process first minibatch
y1 = torch.randn(2, 10, 32)
m1 = torch.randint(1, 5, (2, 10)).float()
output1, state = cross_attn(x, y1, m1, state=state)

# Process second minibatch
y2 = torch.randn(2, 8, 32)
m2 = torch.randint(1, 5, (2, 8)).float()
output2, state = cross_attn(x, y2, m2, state=state)

# output2 is the final output: x attending over all elements of y1 and y2
```
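Minibatch consistency for softmax attention is possible because the softmax normaliser can be accumulated incrementally. The sketch below shows one standard way to do it, the "online softmax" rescaling trick also used by FlashAttention, assuming a state made of the running maximum score, the running sum of weights and the running weighted sum of values. The `empty_state`, `update_state` and `read_out` functions are hypothetical plain-Python stand-ins for a single query; they are not the library's implementation, whose state layout may differ.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def empty_state(dim):
    # Hypothetical state: (running max score, sum of weights, weighted sum of values)
    return (-math.inf, 0.0, [0.0] * dim)


def update_state(state, query, keys, values, mults):
    run_max, denom, numer = state
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, k) * scale for k in keys]
    new_max = max(run_max, max(scores))
    # Rescale the totals accumulated so far to the new reference maximum
    # (for the empty state, exp(-inf) == 0.0, so nothing is carried over).
    correction = math.exp(run_max - new_max)
    denom *= correction
    numer = [n * correction for n in numer]
    for s, m, v in zip(scores, mults, values):
        w = m * math.exp(s - new_max)
        denom += w
        numer = [n + w * vd for n, vd in zip(numer, v)]
    return (new_max, denom, numer)


def read_out(state):
    _, denom, numer = state
    return [n / denom for n in numer]


random.seed(1)
dim = 4
query = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(18)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(18)]
mults = [random.randint(1, 4) for _ in range(18)]

# Whole set at once.
full = read_out(update_state(empty_state(dim), query, keys, values, mults))

# Two minibatches of 10 and 8 elements, mirroring y1 and y2 in the example above.
state = empty_state(dim)
state = update_state(state, query, keys[:10], values[:10], mults[:10])
state = update_state(state, query, keys[10:], values[10:], mults[10:])
chunked = read_out(state)
```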
Architecture
The library is organized into several modules:
- ust.modules: Core modules for set processing
  - scaled_dot_product_attention.py: Various attention mechanisms
  - set_transformer.py: Set Transformer implementation
- ust.api: Abstract interfaces and type definitions
- ust.utils: Utility functions
Key Components
Attention Mechanisms
- ScaledDotProductSoftmaxSetAttention: Standard softmax attention for sets
- ScaledDotProductSoftmaxFlashSetAttention: Optimized flash attention for sets
- ScaledDotProductSigmoidSetAttention: Sigmoid attention for sets
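The softmax variants must carry a running normaliser between minibatches, whereas an unnormalised sigmoid attention weights each element independently, so per-minibatch outputs simply add up. The plain-Python sketch below assumes that formulation, out = sum_j m_j * sigmoid(q·k_j / sqrt(d) + b) * v_j with a fixed bias b; whether `ScaledDotProductSigmoidSetAttention` uses exactly this form is an assumption, and its choice of bias or any normalisation may differ. `sigmoid_attention` and the bias value are hypothetical.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sigmoid(z):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_attention(query, keys, values, mults, bias):
    """Hypothetical unnormalised sigmoid attention of one query over a multiset."""
    scale = 1.0 / math.sqrt(len(query))
    out = [0.0] * len(values[0])
    for k, v, m in zip(keys, values, mults):
        w = m * sigmoid(dot(query, k) * scale + bias)
        out = [o + w * vd for o, vd in zip(out, v)]
    return out


random.seed(2)
dim = 4
query = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(18)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(18)]
mults = [random.randint(1, 4) for _ in range(18)]
bias = -2.0  # hypothetical fixed bias

full = sigmoid_attention(query, keys, values, mults, bias)
part1 = sigmoid_attention(query, keys[:10], values[:10], mults[:10], bias)
part2 = sigmoid_attention(query, keys[10:], values[10:], mults[10:], bias)
# With unnormalised weights, the minibatch "state" is just a running sum.
chunked = [a + b for a, b in zip(part1, part2)]
```

Because the weights are not normalised, the output is not a weighted average of the values and its magnitude grows with the set's total multiplicity; a negative bias keeps the individual weights small.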
Attention Blocks
- MultiheadSetAttention: Multi-head attention for sets
- SelfSetAttentionBlock: Self-attention block for sets
- CrossSetAttentionBlock: Cross-attention block for sets
Set Transformer Components
- InducedSetAttentionBlock: Induced set attention block
- PoolingByMultiheadSetAttention: Pooling mechanism for sets
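Pooling by multihead attention (PMA), from the original Set Transformer, turns a set of any size into a fixed number of vectors by letting a small set of learned seed vectors attend over the set elements; the induced set attention block applies the same idea with learned inducing points as an intermediate step. The plain-Python sketch below strips PMA down to a single head with fixed (not learned) seeds and no projections or feed-forward layer; `pma` is a hypothetical helper, not `PoolingByMultiheadSetAttention` itself. It shows that the output has one vector per seed, does not depend on element order, and that an all-zero seed reduces to a multiplicity-weighted mean.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def pma(seeds, elements, mults):
    """Pool a multiset of vectors into len(seeds) vectors, one per seed."""
    dim = len(elements[0])
    pooled = []
    for seed in seeds:
        scale = 1.0 / math.sqrt(len(seed))
        scores = [dot(seed, e) * scale for e in elements]
        shift = max(scores)  # for numerical stability
        weights = [m * math.exp(s - shift) for m, s in zip(mults, scores)]
        total = sum(weights)
        pooled.append([sum(w * e[d] for w, e in zip(weights, elements)) / total
                       for d in range(dim)])
    return pooled


seeds = [[0.0, 0.0], [1.0, -1.0]]  # stand-ins for learned seed vectors
elements = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mults = [1, 2, 1]

pooled = pma(seeds, elements, mults)
# The zero seed scores every element equally, so pooled[0] is the
# multiplicity-weighted mean: [3.0, 4.0]
pooled_reversed = pma(seeds, elements[::-1], mults[::-1])
```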
Advanced Features
Explicit State Handling
```python
import torch
from ust.modules.scaled_dot_product_attention import ScaledDotProductSoftmaxSetAttention

# Create attention mechanism
attention = ScaledDotProductSoftmaxSetAttention()

# Initialize state
state = attention.initial_state()

# Process query and key-value pairs
query = torch.randn(2, 5, 32)
key = torch.randn(2, 10, 32)
value = torch.randn(2, 10, 32)

# Update state with first batch
state = attention.compute_aggregated_attention(query, key, value, None, state)

# Update state with second batch
key2 = torch.randn(2, 8, 32)
value2 = torch.randn(2, 8, 32)
state = attention.compute_aggregated_attention(query, key2, value2, None, state)

# Get final result
output = attention.get(state)
```
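Because each minibatch can be summarised by a small state, two states computed independently (for example on different devices, or in either order) can also be combined afterwards. The plain-Python sketch below illustrates this for a softmax state made of (maximum score, sum of weights, weighted sum of values) for a single query; `chunk_state`, `merge` and `finalize` are hypothetical helpers, the library's state layout may differ, and nothing here implies that `ust` exposes a merge function.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def chunk_state(query, keys, values):
    """Summarise one minibatch as (max score, sum of weights, weighted value sum)."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, k) * scale for k in keys]
    mx = max(scores)
    weights = [math.exp(s - mx) for s in scores]
    numer = [sum(w * v[d] for w, v in zip(weights, values))
             for d in range(len(values[0]))]
    return (mx, sum(weights), numer)


def merge(a, b):
    """Combine two partial states by rescaling both to a common maximum."""
    ma, da, na = a
    mb, db, nb = b
    m = max(ma, mb)
    ca, cb = math.exp(ma - m), math.exp(mb - m)
    return (m, da * ca + db * cb, [x * ca + y * cb for x, y in zip(na, nb)])


def finalize(state):
    _, denom, numer = state
    return [n / denom for n in numer]


random.seed(3)
dim = 4
query = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(18)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(18)]

full = finalize(chunk_state(query, keys, values))
state_a = chunk_state(query, keys[:10], values[:10])  # first 10 elements
state_b = chunk_state(query, keys[10:], values[10:])  # remaining 8 elements
merged_ab = finalize(merge(state_a, state_b))
merged_ba = finalize(merge(state_b, state_a))  # merge order does not matter
```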
Multiset Processing
```python
import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64,
)

# Process a set with duplicates
x = torch.randn(2, 10, 32)
multiplicities = torch.tensor([
    [1, 2, 1, 3, 1, 1, 2, 1, 1, 1],
    [2, 1, 1, 1, 3, 2, 1, 1, 1, 2],
]).float()

# Process with multiplicities
output = self_attn(x, multiplicities)
```
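When raw data arrives as a list with repeated elements, the multiplicities can be derived by counting duplicates. The sketch below does this in plain Python with `collections.Counter`; `to_multiset` is a hypothetical helper, and it assumes elements can be compared exactly (after conversion to tuples), which only makes sense for discrete or exactly repeated feature vectors.

```python
from collections import Counter


def to_multiset(elements):
    """Collapse duplicate elements into (unique_elements, multiplicities).

    Counter keeps first-occurrence order, so unique elements and their
    counts stay aligned.
    """
    counts = Counter(tuple(e) for e in elements)
    unique = [list(e) for e in counts]
    mults = [float(c) for c in counts.values()]
    return unique, mults


raw = [[1.0, 0.0], [0.5, 0.5], [1.0, 0.0], [1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
unique, mults = to_multiset(raw)
# unique -> [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
# mults  -> [3.0, 2.0, 1.0]
```

The two lists could then be turned into batch-of-one tensors such as `torch.tensor([unique])` and `torch.tensor([mults])`, after projecting the features to `d_model`.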
License
This project is licensed under the MIT License - see the LICENSE file for details.