
Universal Set Transformers for scalable set/multiset modeling


Universal Set Transformers

License: MIT | Python 3.8+ | PyTorch 2.5+

Universal Set Transformers (UST) is a PyTorch library for processing sets and multisets with transformer-based architectures. It provides efficient implementations of set attention mechanisms that are permutation invariant, meaning the output does not depend on the order of the elements, and that handle sets of varying sizes.

Features

  • Set and Multiset Processing: Process sets and multisets with explicit multiplicity handling
  • Minibatch Consistency: Process a large set in minibatches and obtain the same result as processing the whole set in one pass
  • Efficient Attention Mechanisms: Multiple attention implementations (softmax, sigmoid, flash attention)
  • State Management: Explicit state handling for incremental processing
  • Numerical Stability: Robust handling of large values and edge cases
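
To make the multiset and permutation-invariance features concrete, here is a minimal sketch in plain Python. It is not the library's code: the function name `softmax_attention` is hypothetical, and it assumes that an element with multiplicity m contributes m times its exponentiated score, which is the same as listing that element m times.

```python
import math

def softmax_attention(query, keys, values, multiplicities=None):
    """Attend from one query vector to a (multi)set of key/value vectors.

    Assumption: an element with multiplicity m contributes m * exp(score),
    which is equivalent to repeating that element m times.
    """
    if multiplicities is None:
        multiplicities = [1.0] * len(keys)
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)  # subtract the max score for numerical stability
    weights = [m * math.exp(s - top) for m, s in zip(multiplicities, scores)]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

query = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
counts = [1.0, 3.0, 1.0]

# The multiset {a, b, b, b, c} written compactly with multiplicities ...
compact = softmax_attention(query, keys, values, counts)

# ... and written out with explicit repeats.
expanded = softmax_attention(
    query,
    [keys[0], keys[1], keys[1], keys[1], keys[2]],
    [values[0], values[1], values[1], values[1], values[2]],
)

# Reversing the element order leaves the result unchanged.
shuffled = softmax_attention(query, keys[::-1], values[::-1], counts[::-1])
```

Under this assumption, `compact`, `expanded`, and `shuffled` agree up to floating-point rounding.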

Installation

pip install universal-set-transformers

Or install from source:

git clone https://github.com/yourusername/universal-set-transformers.git
cd universal-set-transformers
pip install -e .

Quick Start

Basic Usage

import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process a set
x = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = self_attn(x)

# Process a multiset: multiplicities[b, i] is the count of element i in batch item b
multiplicities = torch.randint(1, 5, (2, 10)).float()  # counts from 1 to 4
output_with_mult = self_attn(x, multiplicities)

Cross-Set Attention

import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process two sets
x = torch.randn(2, 5, 32)   # batch_size=2, set_size=5, d_model=32
y = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = cross_attn(x, y)

# Process with multiplicities
y_multiplicities = torch.randint(1, 5, (2, 10)).float()
output_with_mult = cross_attn(x, y, y_multiplicities)

Minibatch Processing

import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Query set
x = torch.randn(2, 5, 32)  # batch_size=2, set_size=5, d_model=32

# Initialize state
state = cross_attn.attention.scaled_dot_product_attention.initial_state()

# Process first minibatch
y1 = torch.randn(2, 10, 32)
m1 = torch.randint(1, 5, (2, 10)).float()
output1, state = cross_attn(x, y1, m1, state=state)

# Process second minibatch
y2 = torch.randn(2, 8, 32)
m2 = torch.randint(1, 5, (2, 8)).float()
output2, state = cross_attn(x, y2, m2, state=state)

# output2 is the final output: the state carries y1 forward, so it reflects both minibatches
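
The idea behind a carried state can be sketched without the library. The plain-Python code below keeps a running maximum score, a softmax denominator, and a weighted value sum, and rescales them whenever a new minibatch raises the maximum. The names `initial_state`, `update`, and `readout` are hypothetical, and the layout of the state is an assumption; the library's internal state may be organized differently.

```python
import math

def initial_state(dim):
    # (running max score, softmax denominator, weighted value sum)
    return (-math.inf, 0.0, [0.0] * dim)

def update(state, query, keys, values, multiplicities):
    """Fold one non-empty minibatch of keys/values into the running state."""
    running_max, denom, numer = state
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    new_max = max(running_max, max(scores))
    # Rescale what was accumulated under the old maximum.
    correction = math.exp(running_max - new_max) if denom > 0.0 else 0.0
    denom *= correction
    numer = [n * correction for n in numer]
    for score, value, mult in zip(scores, values, multiplicities):
        weight = mult * math.exp(score - new_max)
        denom += weight
        numer = [n + weight * v for n, v in zip(numer, value)]
    return (new_max, denom, numer)

def readout(state):
    _, denom, numer = state
    return [n / denom for n in numer]

# Scores here are near 850, where a naive math.exp(score) would overflow.
query = [120.0, 120.0]
keys = [[5.0, 5.0], [5.0, 5.01], [5.01, 5.01], [5.02, 5.01], [5.01, 5.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
counts = [1.0, 2.0, 1.0, 1.0, 3.0]

# One pass over the whole multiset.
full = readout(update(initial_state(2), query, keys, values, counts))

# The same multiset in two minibatches of sizes 3 and 2.
state = initial_state(2)
state = update(state, query, keys[:3], values[:3], counts[:3])
state = update(state, query, keys[3:], values[3:], counts[3:])
chunked = readout(state)
```

In this sketch `full` and `chunked` agree up to rounding, which is the minibatch-consistency property, and subtracting the running maximum keeps every exponent at or below zero.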

Architecture

The library is organized into several modules:

  • ust.modules: Core modules for set processing
    • scaled_dot_product_attention.py: Various attention mechanisms
    • set_transformer.py: Set Transformer implementation
  • ust.api: Abstract interfaces and type definitions
  • ust.utils: Utility functions

Key Components

Attention Mechanisms

  • ScaledDotProductSoftmaxSetAttention: Standard softmax attention for sets
  • ScaledDotProductSoftmaxFlashSetAttention: Optimized flash attention for sets
  • ScaledDotProductSigmoidSetAttention: Sigmoid attention for sets

Attention Blocks

  • MultiheadSetAttention: Multi-head attention for sets
  • SelfSetAttentionBlock: Self-attention block for sets
  • CrossSetAttentionBlock: Cross-attention block for sets

Set Transformer Components

  • InducedSetAttentionBlock: Induced set attention block
  • PoolingByMultiheadSetAttention: Pooling mechanism for sets
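
As a rough illustration of pooling by attention, the sketch below uses a small number of seed vectors as queries over the set, so the output size depends on the number of seeds rather than on the set size. It is a single-head simplification in plain Python with no learned projections; `attend`, `pool`, and the fixed `seeds` are hypothetical stand-ins, not the library's `PoolingByMultiheadSetAttention` API.

```python
import math

def attend(query, elements):
    """Softmax-attention readout of `elements` for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * e for q, e in zip(query, element)) for element in elements]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(elements[0])
    return [sum(w * e[d] for w, e in zip(weights, elements)) / total for d in range(dim)]

def pool(seeds, elements):
    """Return one pooled vector per seed, whatever the set size."""
    return [attend(seed, elements) for seed in seeds]

# Hypothetical fixed seeds; in a trained model these would be learned parameters.
seeds = [[1.0, 0.0], [0.0, 1.0]]
small_set = [[0.2, 0.4], [1.0, -1.0], [0.5, 0.5]]
large_set = small_set + [[-0.3, 0.8], [0.9, 0.1], [0.0, 0.0], [2.0, 1.0]]

pooled_small = pool(seeds, small_set)           # 2 vectors from 3 elements
pooled_large = pool(seeds, large_set)           # 2 vectors from 7 elements
pooled_reordered = pool(seeds, small_set[::-1])  # same as pooled_small up to rounding
```

Both sets are reduced to two vectors, one per seed, and reordering the input does not change the pooled result.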

Advanced Features

Explicit State Handling

import torch
from ust.modules.scaled_dot_product_attention import ScaledDotProductSoftmaxSetAttention

# Create attention mechanism
attention = ScaledDotProductSoftmaxSetAttention()

# Initialize state
state = attention.initial_state()

# Process query and key-value pairs
query = torch.randn(2, 5, 32)
key = torch.randn(2, 10, 32)
value = torch.randn(2, 10, 32)

# Update state with first batch
state = attention.compute_aggregated_attention(query, key, value, None, state)

# Update state with second batch
key2 = torch.randn(2, 8, 32)
value2 = torch.randn(2, 8, 32)
state = attention.compute_aggregated_attention(query, key2, value2, None, state)

# Get final result
output = attention.get(state)

Multiset Processing

import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process a set with duplicates
x = torch.randn(2, 10, 32)
multiplicities = torch.tensor([
    [1, 2, 1, 3, 1, 1, 2, 1, 1, 1],
    [2, 1, 1, 1, 3, 2, 1, 1, 1, 2]
]).float()

# Process with multiplicities
output = self_attn(x, multiplicities)
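
When the raw data contains repeated elements, the unique elements and their multiplicities can be derived before calling the block. The helper below is a plain-Python sketch with a hypothetical name, `to_multiset`; for tensors, `torch.unique(x, dim=0, return_counts=True)` does the same job, although it returns the unique rows in sorted order.

```python
from collections import Counter

def to_multiset(rows):
    """Collapse duplicate rows into (unique_rows, multiplicities).

    Rows are compared exactly, so this suits discrete features rather than
    noisy floating-point data.
    """
    counts = Counter(tuple(row) for row in rows)
    # Counter keeps first-seen order, so unique rows appear in input order.
    unique_rows = [list(row) for row in counts]
    multiplicities = [float(counts[tuple(row)]) for row in unique_rows]
    return unique_rows, multiplicities

rows = [[1, 0], [0, 1], [1, 0], [1, 1], [1, 0]]
unique_rows, multiplicities = to_multiset(rows)
# unique_rows    -> [[1, 0], [0, 1], [1, 1]]
# multiplicities -> [3.0, 1.0, 1.0]
```

Passing unique elements with counts keeps the set dimension small when duplicates are common.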

License

This project is licensed under the MIT License - see the LICENSE file for details.


