
Universal Set Transformers for scalable set/multiset modeling


Universal Set Transformers

License: MIT | Python 3.8+ | PyTorch 2.5+

Universal Set Transformers (UST) is a PyTorch library for processing sets and multisets with transformer-based architectures. It provides efficient implementations of set attention mechanisms that do not depend on the order of set elements (permutation invariant or equivariant, depending on the block) and that handle sets of varying sizes.
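To make the order-independence concrete, here is a minimal pure-Python sketch (not UST's code, and without torch): each query's output is a softmax-weighted average over the attended set, and both the normalizer and the weighted sum are unaffected by reordering that set. The helper names dot and cross_attention are hypothetical and used only for illustration; UST's modules work on batched tensors with learned projections.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cross_attention(queries, keys, values):
    """Softmax attention: every query attends over the whole (key, value) set."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    out = []
    for q in queries:
        logits = [dot(q, k) * scale for k in keys]
        top = max(logits)  # subtract the max logit for numerical stability
        w = [math.exp(l - top) for l in logits]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, values)) / z
                    for j in range(len(values[0]))])
    return out

random.seed(0)
d = 4
x = [[random.gauss(0, 1) for _ in range(d)] for _ in range(2)]  # query set: 2 elements
y = [[random.gauss(0, 1) for _ in range(d)] for _ in range(5)]  # attended set: 5 elements

y_perm = [y[i] for i in [3, 0, 4, 1, 2]]  # same set, different order
out = cross_attention(x, y, y)
out_perm = cross_attention(x, y_perm, y_perm)

same = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for ra, rb in zip(out, out_perm) for a, b in zip(ra, rb))
print(same)  # True: reordering y does not change the output
print(len(cross_attention(x, y[:3], y[:3])))  # 2: output size follows the query set
```

The attended set can have any size. The number of outputs is fixed by the query set, which is why sets of varying sizes need no padding in this formulation.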

Features

  • Set and Multiset Processing: Process sets and multisets with explicit multiplicity handling
  • Minibatch Consistency: Process large sets in minibatches and obtain the same result as processing the whole set at once
  • Efficient Attention Mechanisms: Multiple attention implementations (softmax, sigmoid, flash attention)
  • State Management: Explicit state handling for incremental processing
  • Numerical Stability: Robust handling of large values and edge cases

Installation

pip install universal-set-transformers

Or install from source:

git clone https://github.com/yourusername/universal-set-transformers.git
cd universal-set-transformers
pip install -e .

Quick Start

Basic Usage

import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process a set
x = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = self_attn(x)

# Process a multiset with multiplicities
multiplicities = torch.randint(1, 5, (2, 10)).float()
output_with_mult = self_attn(x, multiplicities)

Cross-Set Attention

import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process two sets
x = torch.randn(2, 5, 32)   # batch_size=2, set_size=5, d_model=32
y = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = cross_attn(x, y)

# Process with multiplicities
y_multiplicities = torch.randint(1, 5, (2, 10)).float()
output_with_mult = cross_attn(x, y, y_multiplicities)

Minibatch Processing

import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Query set
x = torch.randn(2, 5, 32)  # batch_size=2, set_size=5, d_model=32

# Initialize state
state = cross_attn.attention.scaled_dot_product_attention.initial_state()

# Process first minibatch
y1 = torch.randn(2, 10, 32)
m1 = torch.randint(1, 5, (2, 10)).float()
output1, state = cross_attn(x, y1, m1, state=state)

# Process second minibatch
y2 = torch.randn(2, 8, 32)
m2 = torch.randint(1, 5, (2, 8)).float()
output2, state = cross_attn(x, y2, m2, state=state)

# output2 is the result for x attending over both minibatches (y1 and y2, with multiplicities m1 and m2)
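One standard way to make softmax attention minibatch-consistent is the online-softmax technique used by flash attention. It keeps a running maximum logit, a running normalizer, and a running weighted sum of values, and rescales them whenever a new minibatch raises the maximum. The pure-Python sketch below assumes a state of that shape. It is not UST's actual state object: initial_state, update, and get here are hypothetical stand-ins, and multiplicities are assumed to scale each element's softmax weight.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def initial_state(dim):
    # (running max logit, running normalizer, running weighted sum of values)
    return (float("-inf"), 0.0, [0.0] * dim)

def update(state, q, keys, values, mults):
    """Fold one minibatch of (key, value, multiplicity) triples into the state."""
    old_max, z, acc = state
    scale = 1.0 / math.sqrt(len(q))
    logits = [dot(q, k) * scale for k in keys]
    new_max = max(old_max, max(logits))
    rescale = math.exp(old_max - new_max)  # 0.0 on the first update (old_max = -inf)
    z *= rescale
    acc = [a * rescale for a in acc]
    for l, v, m in zip(logits, values, mults):
        w = m * math.exp(l - new_max)  # multiplicity m scales the element's weight
        z += w
        acc = [a + w * vi for a, vi in zip(acc, v)]
    return (new_max, z, acc)

def get(state):
    _, z, acc = state
    return [a / z for a in acc]

random.seed(1)
d = 4
q = [random.gauss(0, 1) for _ in range(d)]
ys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(18)]
ms = [random.randint(1, 4) for _ in range(18)]

# Whole set in one go
full = get(update(initial_state(d), q, ys, ys, ms))

# Same set in two minibatches of 10 and 8 elements
state = initial_state(d)
state = update(state, q, ys[:10], ys[:10], ms[:10])
state = update(state, q, ys[10:], ys[10:], ms[10:])
chunked = get(state)

print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
          for a, b in zip(full, chunked)))  # True
```

Memory stays constant in the number of minibatches because only the three running quantities per query are kept, never the full set.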

Architecture

The library is organized into several modules:

  • ust.modules: Core modules for set processing
    • scaled_dot_product_attention.py: Various attention mechanisms
    • set_transformer.py: Set Transformer implementation
  • ust.api: Abstract interfaces and type definitions
  • ust.utils: Utility functions

Key Components

Attention Mechanisms

  • ScaledDotProductSoftmaxSetAttention: Standard softmax attention for sets
  • ScaledDotProductSoftmaxFlashSetAttention: Optimized flash attention for sets
  • ScaledDotProductSigmoidSetAttention: Sigmoid attention for sets
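The softmax and sigmoid variants behave differently under minibatching. Softmax weights are normalized over the whole set, so partial results must be rescaled before they can be combined. Sigmoid attention scores each element independently, so results from separate minibatches simply add. The pure-Python sketch below assumes the form sum_j m_j * sigmoid(q·k_j / sqrt(d) + b) * v_j with a fixed bias b. The bias parameter and the function names are hypothetical, and UST's exact parameterization of sigmoid attention may differ.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def sigmoid_attention(q, keys, values, mults, bias=0.0):
    """Each element contributes m_j * sigmoid(q.k_j / sqrt(d) + bias) * v_j; no normalization."""
    scale = 1.0 / math.sqrt(len(q))
    out = [0.0] * len(values[0])
    for k, v, m in zip(keys, values, mults):
        w = m * sigmoid(dot(q, k) * scale + bias)
        out = [o + w * vi for o, vi in zip(out, v)]
    return out

random.seed(2)
d = 4
q = [random.gauss(0, 1) for _ in range(d)]
ys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(12)]
ms = [random.randint(1, 3) for _ in range(12)]

full = sigmoid_attention(q, ys, ys, ms)

# Two minibatches: partial outputs add up, no rescaling state needed
part1 = sigmoid_attention(q, ys[:7], ys[:7], ms[:7])
part2 = sigmoid_attention(q, ys[7:], ys[7:], ms[7:])
summed = [a + b for a, b in zip(part1, part2)]

print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
          for a, b in zip(full, summed)))  # True
```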

Attention Blocks

  • MultiheadSetAttention: Multi-head attention for sets
  • SelfSetAttentionBlock: Self-attention block for sets
  • CrossSetAttentionBlock: Cross-attention block for sets
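In self-attention over a set, each element attends over the whole set, so reordering the inputs reorders the outputs in the same way (permutation equivariance). Multi-head attention splits each d_model-dimensional vector into nhead slices and attends per slice. The pure-Python sketch below shows both effects. It omits the learned projections, residual connections, and feedforward layer a real block has, and it assumes the usual multi-head requirement that d_model be divisible by nhead (for example, d_model=32 with nhead=4 gives 8 dimensions per head). All names are hypothetical.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(queries, keys, values):
    """Single-head softmax attention (no learned projections in this sketch)."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    out = []
    for q in queries:
        logits = [dot(q, k) * scale for k in keys]
        top = max(logits)
        w = [math.exp(l - top) for l in logits]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, values)) / z
                    for j in range(len(values[0]))])
    return out

def multihead_self_attention(xs, nhead):
    d_model = len(xs[0])
    assert d_model % nhead == 0, "d_model must be divisible by nhead"
    hd = d_model // nhead  # dimensions per head
    heads = []
    for h in range(nhead):
        part = [x[h * hd:(h + 1) * hd] for x in xs]  # this head's slice of every element
        heads.append(attend(part, part, part))
    # Concatenate each element's per-head outputs back into one d_model vector
    return [sum((heads[h][i] for h in range(nhead)), []) for i in range(len(xs))]

random.seed(3)
xs = [[random.gauss(0, 1) for _ in range(8)] for _ in range(5)]
perm = [2, 4, 0, 3, 1]
out = multihead_self_attention(xs, nhead=2)
out_perm = multihead_self_attention([xs[i] for i in perm], nhead=2)

# Equivariance: row i of the permuted run matches row perm[i] of the original run
equivariant = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
                  for i, p in enumerate(perm)
                  for a, b in zip(out_perm[i], out[p]))
print(equivariant)  # True
```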

Set Transformer Components

  • InducedSetAttentionBlock: Induced set attention block
  • PoolingByMultiheadSetAttention: Pooling mechanism for sets
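The two Set Transformer components can be sketched in pure Python with attention alone. Pooling lets k seed vectors attend over the set, giving k outputs regardless of set size or order. The induced block first lets m inducing points summarize the n elements and then lets each element read back from that summary, so cost grows as O(n·m) instead of O(n²). This is a stripped-down illustration with hypothetical names. In the Set Transformer design, seeds and inducing points are learned parameters and each attention step is wrapped with projections, residuals, and a feedforward layer, all omitted here.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(queries, keys, values):
    """Single-head softmax attention (no learned projections in this sketch)."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    out = []
    for q in queries:
        logits = [dot(q, k) * scale for k in keys]
        top = max(logits)
        w = [math.exp(l - top) for l in logits]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, values)) / z
                    for j in range(len(values[0]))])
    return out

def pma(seeds, xs):
    """Pooling: k seed vectors attend over the set -> k outputs, whatever the set size."""
    return attend(seeds, xs, xs)

def isab(inducing, xs):
    """Induced attention: m points summarize n elements, then each element reads back."""
    h = attend(inducing, xs, xs)  # m summary vectors: O(n*m) scores
    return attend(xs, h, h)       # n outputs, one per input element: O(n*m) scores

random.seed(4)
d = 4
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(6)]
seeds = [[random.gauss(0, 1) for _ in range(d)] for _ in range(1)]     # k = 1 (random here, learned in practice)
inducing = [[random.gauss(0, 1) for _ in range(d)] for _ in range(2)]  # m = 2 (random here, learned in practice)
perm = [5, 3, 1, 0, 2, 4]
xs_perm = [xs[i] for i in perm]

pooled, pooled_perm = pma(seeds, xs), pma(seeds, xs_perm)
enc, enc_perm = isab(inducing, xs), isab(inducing, xs_perm)
print(len(pooled), len(enc))  # 1 6
```

Pooling is permutation invariant (pooled equals pooled_perm), while the induced block is equivariant (enc_perm[i] equals enc[perm[i]]).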

Advanced Features

Explicit State Handling

import torch
from ust.modules.scaled_dot_product_attention import ScaledDotProductSoftmaxSetAttention

# Create attention mechanism
attention = ScaledDotProductSoftmaxSetAttention()

# Initialize state
state = attention.initial_state()

# Process query and key-value pairs
query = torch.randn(2, 5, 32)
key = torch.randn(2, 10, 32)
value = torch.randn(2, 10, 32)

# Update state with first batch
state = attention.compute_aggregated_attention(query, key, value, None, state)

# Update state with second batch
key2 = torch.randn(2, 8, 32)
value2 = torch.randn(2, 8, 32)
state = attention.compute_aggregated_attention(query, key2, value2, None, state)

# Get final result
output = attention.get(state)
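A softmax-attention state of the form (running max logit, normalizer, weighted sum of values) can also be combined across chunks that were processed independently, for example in parallel. Both states are rescaled to the larger maximum before adding. This README does not say whether ScaledDotProductSoftmaxSetAttention exposes such a merge operation. The pure-Python sketch below illustrates the technique only, and summarize, merge, and get are hypothetical helpers.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def summarize(q, keys, values):
    """State for one chunk: (max logit, sum of exp weights, weighted sum of values)."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [dot(q, k) * scale for k in keys]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    acc = [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(len(values[0]))]
    return (top, sum(w), acc)

def merge(s1, s2):
    """Combine two chunk states by rescaling both to the larger max logit."""
    m1, z1, a1 = s1
    m2, z2, a2 = s2
    top = max(m1, m2)
    c1, c2 = math.exp(m1 - top), math.exp(m2 - top)
    return (top, z1 * c1 + z2 * c2, [x * c1 + y * c2 for x, y in zip(a1, a2)])

def get(state):
    _, z, acc = state
    return [a / z for a in acc]

random.seed(5)
d = 4
q = [random.gauss(0, 1) for _ in range(d)]
keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(18)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(18)]

whole = get(summarize(q, keys, values))

# Summarize the two chunks independently, then merge
s_a = summarize(q, keys[:10], values[:10])
s_b = summarize(q, keys[10:], values[10:])
merged = get(merge(s_a, s_b))
```

Because merge is symmetric and associative up to floating-point rounding, chunks can be combined in any order or as a tree.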

Multiset Processing

import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process a set with duplicates
x = torch.randn(2, 10, 32)
multiplicities = torch.tensor([
    [1, 2, 1, 3, 1, 1, 2, 1, 1, 1],
    [2, 1, 1, 1, 3, 2, 1, 1, 1, 2]
]).float()

# Process with multiplicities
output = self_attn(x, multiplicities)
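Suppose multiplicities enter softmax attention as weights on each element's exponentiated score, which is equivalent to adding log m_j to the logits. Under that assumption, attending over a multiset gives the same result as attending over a set in which every element is physically repeated m_j times, while storing each distinct element only once. The pure-Python check below illustrates that equivalence. multiset_attention is a hypothetical helper, not UST's implementation.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def multiset_attention(queries, keys, values, mults):
    """Softmax attention where element j's weight is mults[j] * exp(score_j)."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    out = []
    for q in queries:
        logits = [dot(q, k) * scale for k in keys]
        top = max(logits)
        w = [m * math.exp(l - top) for l, m in zip(logits, mults)]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, values)) / z
                    for j in range(len(values[0]))])
    return out

random.seed(6)
d = 4
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(5)]
mults = [1, 2, 1, 3, 1]

# Self-attention over the multiset, storing each distinct element once
compact = multiset_attention(xs, xs, xs, mults)

# The same multiset with every element physically repeated mults[j] times
expanded = [x for x, m in zip(xs, mults) for _ in range(m)]
dense = multiset_attention(expanded, expanded, expanded, [1] * len(expanded))

# Row i of `compact` should match the first copy of element i in `dense`
first_copy = [sum(mults[:i]) for i in range(len(xs))]
```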

License

This project is licensed under the MIT License - see the LICENSE file for details.


Download files

Download the file for your platform.

Source Distribution

universal_set_transformers-0.1.0.tar.gz (24.4 kB)

Uploaded Source

Built Distribution


universal_set_transformers-0.1.0-py3-none-any.whl (24.6 kB)

Uploaded Python 3

File details

Details for the file universal_set_transformers-0.1.0.tar.gz.


File hashes

Hashes for universal_set_transformers-0.1.0.tar.gz
Algorithm Hash digest
SHA256 d7cb513b16bd31e3963d256af2b25c8612771873d13880e7ebf4e6ad1f4ad45b
MD5 294e92d34aa81fbbd622aa964a1a4fb6
BLAKE2b-256 e982483b3d44d49a6feed67b985f4702472955e90b874768edac23e00c0c1589


File details

Details for the file universal_set_transformers-0.1.0-py3-none-any.whl.


File hashes

Hashes for universal_set_transformers-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 23efbc079a3b32db9538be678f528266239d6db221b233e1b415cb3c460aff52
MD5 528ae3dc534ab79a6d0a5d1214cb4240
BLAKE2b-256 70cacf9673a233f89efe52def99d151d1367cc02dd801010c427a39fb466bb8a
