
Universal Set Transformers for scalable set/multiset modeling

Universal Set Transformers

License: MIT | Python 3.8+ | PyTorch 2.5+

Universal Set Transformers (UST) is a PyTorch library for processing sets and multisets with transformer-based architectures. It provides efficient implementations of set attention mechanisms that do not depend on the order of set elements (permutation-invariant or permutation-equivariant, depending on the block) and that handle sets of varying sizes.

Features

  • Set and Multiset Processing: Process sets and multisets with explicit multiplicity handling
  • Minibatch Consistency: Process large sets in minibatches, with a result guaranteed to match processing the whole set at once
  • Efficient Attention Mechanisms: Multiple attention implementations (softmax, sigmoid, flash attention)
  • State Management: Explicit state handling for incremental processing
  • Numerical Stability: Robust handling of large values and edge cases

Installation

pip install universal-set-transformers

Or install from source:

git clone https://github.com/yourusername/universal-set-transformers.git
cd universal-set-transformers
pip install -e .

Quick Start

Basic Usage

import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process a set
x = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = self_attn(x)

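# Process a multiset: multiplicities[b, i] is the number of times element i occurs in set b
multiplicities = torch.randint(1, 5, (2, 10)).float()  # shape (batch_size, set_size), counts from 1 to 4
output_with_mult = self_attn(x, multiplicities)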
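Without positional encodings, self-attention over a set is permutation-equivariant: reordering the input elements reorders the output rows in exactly the same way, so no element order is baked into the result. The plain-Python sketch below checks this with single-head softmax attention and no learned projections; it illustrates the property and is not the internals of `SelfSetAttentionBlock`.

```python
import math
import random


def softmax_attention(queries, keys, values):
    """Single-head scaled dot-product softmax attention on lists of vectors."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        # Scaled dot-product score of this query against every key.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        m = max(scores)  # subtract the max before exp() for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outputs.append([
            sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))
        ])
    return outputs


random.seed(0)
x = [[random.gauss(0, 1) for _ in range(4)] for _ in range(6)]  # a set of 6 elements, dim 4

out = softmax_attention(x, x, x)  # self-attention: queries, keys and values all come from x

perm = [3, 0, 5, 1, 4, 2]
x_perm = [x[p] for p in perm]
out_perm = softmax_attention(x_perm, x_perm, x_perm)

# Row i of the permuted output must equal row perm[i] of the original output.
equivariant = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for p, row in zip(perm, out_perm)
    for a, b in zip(out[p], row)
)
print(equivariant)  # True
```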

Cross-Set Attention

import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process two sets
x = torch.randn(2, 5, 32)   # batch_size=2, set_size=5, d_model=32
y = torch.randn(2, 10, 32)  # batch_size=2, set_size=10, d_model=32
output = cross_attn(x, y)

# Process with multiplicities
y_multiplicities = torch.randint(1, 5, (2, 10)).float()
output_with_mult = cross_attn(x, y, y_multiplicities)
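In cross-set attention the query set `x` determines the output size (one output per query element), while the key/value set `y` is only summarized, so its order does not matter. The sketch below checks both points with plain single-head softmax attention on random vectors; it assumes no learned projections and illustrates the property rather than reproducing `CrossSetAttentionBlock`.

```python
import math
import random


def cross_attention(queries, keys, values):
    """Single-head softmax attention: one output vector per query element."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)  # shift by the max before exp() for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * v[i] for w, v in zip(weights, values)) / total
                        for i in range(len(values[0]))])
    return outputs


random.seed(1)
x = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]   # query set: 5 elements
y = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]  # key/value set: 10 elements

out = cross_attention(x, y, y)

# Shuffle the key/value set (keys and values move together, since both are y).
y_shuffled = y[:]
random.shuffle(y_shuffled)
out_shuffled = cross_attention(x, y_shuffled, y_shuffled)

print(len(out))  # 5
invariant = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for row, row_s in zip(out, out_shuffled)
    for a, b in zip(row, row_s)
)
print(invariant)  # True
```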

Minibatch Processing

import torch
from ust.modules import CrossSetAttentionBlock

# Create a cross-attention block
cross_attn = CrossSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Query set
x = torch.randn(2, 5, 32)  # batch_size=2, set_size=5, d_model=32

# Initialize state
state = cross_attn.attention.scaled_dot_product_attention.initial_state()

# Process first minibatch
y1 = torch.randn(2, 10, 32)
m1 = torch.randint(1, 5, (2, 10)).float()
output1, state = cross_attn(x, y1, m1, state=state)

# Process second minibatch
y2 = torch.randn(2, 8, 32)
m2 = torch.randint(1, 5, (2, 8)).float()
output2, state = cross_attn(x, y2, m2, state=state)

# output2 is the output for the full key/value set (y1 and y2 combined); output1 covers y1 only

Architecture

The library is organized into several modules:

  • ust.modules: Core modules for set processing
    • scaled_dot_product_attention.py: Various attention mechanisms
    • set_transformer.py: Set Transformer implementation
  • ust.api: Abstract interfaces and type definitions
  • ust.utils: Utility functions

Key Components

Attention Mechanisms

  • ScaledDotProductSoftmaxSetAttention: Standard softmax attention for sets
  • ScaledDotProductSoftmaxFlashSetAttention: Optimized flash attention for sets
  • ScaledDotProductSigmoidSetAttention: Sigmoid attention for sets
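One reason sigmoid attention pairs naturally with minibatch processing: if each score goes through an elementwise sigmoid instead of a softmax across keys, the output is a plain sum over key/value elements, so partial sums from separate minibatches can simply be added. The sketch below assumes an unnormalized form, out = Σ_j m_j · σ(q·k_j/√d + b) · v_j, with a scalar bias b and optional multiplicities m_j; whether `ScaledDotProductSigmoidSetAttention` uses exactly this form (bias, normalization) is an assumption, not something stated here.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function: exp() only ever sees a non-positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_attention_sum(q, keys, values, mults=None, bias=0.0):
    """Unnormalized sigmoid attention for one query (assumed form, see lead-in)."""
    d = len(q)
    mults = mults if mults is not None else [1.0] * len(keys)
    acc = [0.0] * len(values[0])
    for k, v, m in zip(keys, values, mults):
        w = m * sigmoid(sum(a * b for a, b in zip(q, k)) / math.sqrt(d) + bias)
        for i, vi in enumerate(v):
            acc[i] += w * vi
    return acc


random.seed(2)
q = [random.gauss(0, 1) for _ in range(4)]
keys = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]
values = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]

full = sigmoid_attention_sum(q, keys, values)

# Process the same set as two minibatches and add the partial results.
part1 = sigmoid_attention_sum(q, keys[:6], values[:6])
part2 = sigmoid_attention_sum(q, keys[6:], values[6:])
combined = [a + b for a, b in zip(part1, part2)]

consistent = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(full, combined))
print(consistent)  # True
```

Softmax attention cannot be merged this simply, because its normalizer spans every key; that is what the explicit state handling shown further down is for.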

Attention Blocks

  • MultiheadSetAttention: Multi-head attention for sets
  • SelfSetAttentionBlock: Self-attention block for sets
  • CrossSetAttentionBlock: Cross-attention block for sets

Set Transformer Components

  • InducedSetAttentionBlock: Induced set attention block
  • PoolingByMultiheadSetAttention: Pooling mechanism for sets
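In the original Set Transformer (Lee et al., 2019), on which these component names appear to be based, an induced set attention block (ISAB) first lets m inducing points attend over the n set elements and then lets each element attend back to those m summaries, cutting the score count from n² to 2nm; pooling by multihead attention (PMA) lets k seed vectors attend over the set, producing k outputs regardless of n. The sketch below shows both steps with plain single-head attention and random stand-ins for the learned inducing points and seed. It omits the residual connections, layer norms, feedforward layers, and multiple heads of the real blocks and is not this library's implementation.

```python
import math
import random


def attend(queries, keys, values):
    """Single-head softmax attention; returns one output vector per query."""
    d = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * v[i] for wj, v in zip(w, values)) / z for i in range(len(values[0]))])
    return out


def isab(x, inducing):
    # Stage 1: m inducing points summarize the n set elements (m x n scores).
    h = attend(inducing, x, x)
    # Stage 2: every element attends to the m summaries (n x m scores).
    return attend(x, h, h)


def pma(x, seeds):
    # Pooling: k seed vectors attend over the set, giving k outputs regardless of n.
    return attend(seeds, x, x)


random.seed(3)
dim, n, m = 4, 50, 8
x = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
inducing = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(m)]  # learned in a real model
seed = [[random.gauss(0, 1) for _ in range(dim)]]  # one seed -> one pooled vector

pooled = pma(isab(x, inducing), seed)

x_shuffled = x[:]
random.shuffle(x_shuffled)
pooled_shuffled = pma(isab(x_shuffled, inducing), seed)

print(len(pooled), len(pooled[0]))  # 1 4
invariant = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
                for a, b in zip(pooled[0], pooled_shuffled[0]))
print(invariant)  # True
print(2 * n * m, "scores for ISAB vs", n * n, "for full self-attention")  # 800 scores for ISAB vs 2500 for full self-attention
```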

Advanced Features

Explicit State Handling

import torch
from ust.modules.scaled_dot_product_attention import ScaledDotProductSoftmaxSetAttention

# Create attention mechanism
attention = ScaledDotProductSoftmaxSetAttention()

# Initialize state
state = attention.initial_state()

# Process query and key-value pairs
query = torch.randn(2, 5, 32)
key = torch.randn(2, 10, 32)
value = torch.randn(2, 10, 32)

# Update state with first batch
state = attention.compute_aggregated_attention(query, key, value, None, state)

# Update state with second batch
key2 = torch.randn(2, 8, 32)
value2 = torch.randn(2, 8, 32)
state = attention.compute_aggregated_attention(query, key2, value2, None, state)

# Read out the attention output over all 18 key/value elements (10 + 8)
output = attention.get(state)
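The state in this example is what makes minibatch processing exact for softmax attention. A common way to build such a state, used by online-softmax and flash-attention-style kernels, is to keep per query a running maximum score, a weighted sum of values, and a sum of weights, rescaling both sums whenever the maximum grows. The sketch below implements that idea in plain Python; the helper names (`initial_state`, `aggregate`, `get`) only mirror the API above, and the state layout is an assumption, not necessarily what `ScaledDotProductSoftmaxSetAttention` stores internally.

```python
import math
import random

# Hypothetical helpers mirroring initial_state / compute_aggregated_attention / get.
# Per query, the state holds (running max score, weighted value sum, weight sum).


def initial_state(num_queries, dim_v):
    return [(-math.inf, [0.0] * dim_v, 0.0) for _ in range(num_queries)]


def aggregate(state, queries, keys, values, mults=None):
    """Fold one minibatch of key/value elements into the running softmax statistics."""
    d = len(keys[0])
    mults = mults if mults is not None else [1.0] * len(keys)
    new_state = []
    for (m_old, num, den), q in zip(state, queries):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        m_new = max(m_old, max(scores))
        # Rescale earlier sums to the new max; every exp() argument below is <= 0.
        scale = math.exp(m_old - m_new)
        num = [n * scale for n in num]
        den *= scale
        for s, v, c in zip(scores, values, mults):
            w = c * math.exp(s - m_new)
            den += w
            num = [n + w * vi for n, vi in zip(num, v)]
        new_state.append((m_new, num, den))
    return new_state


def get(state):
    return [[n / den for n in num] for _, num, den in state]


def rand_vecs(count, dim=4):
    return [[random.gauss(0, 1) for _ in range(dim)] for _ in range(count)]


random.seed(4)
query = rand_vecs(5)
key1, value1 = rand_vecs(10), rand_vecs(10)
key2, value2 = rand_vecs(8), rand_vecs(8)

# Two minibatches, as in the example above.
state = initial_state(5, 4)
state = aggregate(state, query, key1, value1)
state = aggregate(state, query, key2, value2)
chunked = get(state)

# All 18 elements in a single pass.
full = get(aggregate(initial_state(5, 4), query, key1 + key2, value1 + value2))

consistent = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for row_c, row_f in zip(chunked, full)
    for a, b in zip(row_c, row_f)
)
print(consistent)  # True

# Scaled-up keys give scores far beyond what a naive exp() can handle; the max-shift keeps results finite.
big_keys = [[1e4 * a for a in k] for k in key1]
stable = all(math.isfinite(v) for row in get(aggregate(initial_state(5, 4), query, big_keys, value1)) for v in row)
print(stable)  # True
```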

Multiset Processing

import torch
from ust.modules import SelfSetAttentionBlock

# Create a self-attention block
self_attn = SelfSetAttentionBlock(
    d_model=32,
    nhead=4,
    dim_feedforward=64
)

# Process a set with duplicates
x = torch.randn(2, 10, 32)
multiplicities = torch.tensor([
    [1, 2, 1, 3, 1, 1, 2, 1, 1, 1],
    [2, 1, 1, 1, 3, 2, 1, 1, 1, 2]
]).float()

# Process with multiplicities
output = self_attn(x, multiplicities)
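Multiplicities let a set with repeated elements be stored compactly. One natural way to support them in softmax attention, assumed here, is to multiply each element's exponentiated score by its count (equivalently, add log m_j to its score); this gives exactly the same output as physically repeating the element m_j times. The plain-Python sketch below checks that equivalence for one query, using the first row of multiplicities from the example above and letting the keys double as values; it is an assumption about the mechanism, not taken from the library's code.

```python
import math
import random


def attention(q, keys, values, mults):
    """Softmax attention for one query where element j counts mults[j] times."""
    d = len(q)
    # Adding log(m_j) to a score multiplies its exp() weight by m_j.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) + math.log(m)
              for k, m in zip(keys, mults)]
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    z = sum(w)
    return [sum(wj * v[i] for wj, v in zip(w, values)) / z for i in range(len(values[0]))]


random.seed(5)
q = [random.gauss(0, 1) for _ in range(4)]
keys = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]
mults = [1, 2, 1, 3, 1, 1, 2, 1, 1, 1]  # first row of the multiplicities above

compact = attention(q, keys, keys, mults)

# Materialize the multiset: repeat each element m_j times, then use multiplicity 1 everywhere.
expanded = [k for k, m in zip(keys, mults) for _ in range(m)]
explicit = attention(q, expanded, expanded, [1] * len(expanded))

print(len(expanded))  # 14
matches = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(compact, explicit))
print(matches)  # True
```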

License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files

Source Distribution

universal_set_transformers-0.2.0.tar.gz (26.8 kB)

Built Distribution

universal_set_transformers-0.2.0-py3-none-any.whl (25.0 kB, Python 3)

File details

Hashes for universal_set_transformers-0.2.0.tar.gz:

  • SHA256: a41a5b9c4e9aaf4511fc2c4809670873ce24415934ba9ba4d29929f040db1362
  • MD5: 94bd5b85350ad3ea7b11e48ed8e3d270
  • BLAKE2b-256: 16f80621cf92c3792af99d2fc7b62221456b6427abab9d6df16744292bbd3323

File details

Hashes for universal_set_transformers-0.2.0-py3-none-any.whl:

  • SHA256: 2fd4b800db754d5abd75a642a7ab5b1ecebc09cafb7b1cadc4f98d6dbe3ce86c
  • MD5: fed7bb750866050a7c48a615574e1f05
  • BLAKE2b-256: 66062b46beb2f5d22a33d9a4c7fa24b12313b6318d7b3122fa175c4ca504a42e
