Project description

Dual Attention Transformer

[Badges: arXiv Paper · Project Webpage · PyPI · Documentation Status · PRs Welcome · PyPI License]

This repository contains a Python package (hosted on PyPI) implementing the Dual Attention Transformer (DAT), as proposed by the paper Disentangling and Integrating Relational and Sensory Information in Transformer Architectures by Awni Altabaa, John Lafferty.

Abstract. The Transformer architecture processes sequences by implementing a form of neural message-passing that consists of iterative information retrieval (attention), followed by local processing (position-wise MLP). Two types of information are essential under this general computational paradigm: "sensory" information about individual objects, and "relational" information describing the relationships between objects. Standard attention naturally encodes the former, but does not explicitly encode the latter. In this paper, we present an extension of Transformers where multi-head attention is augmented with two distinct types of attention heads, each routing information of a different type. The first type is the standard attention mechanism of Transformers, which captures object-level features, while the second type is a novel attention mechanism we propose to explicitly capture relational information. The two types of attention heads each possess different inductive biases, giving the resulting architecture greater efficiency and versatility. The promise of this approach is demonstrated empirically across a range of tasks.

Please see the project webpage for an overview of the paper and its main results.

Installation

Via pip

pip install dual_attention

From source

git clone https://github.com/Awni00/dual-attention.git
cd dual-attention
pip install .

Requirements

  • torch
  • einops
  • tiktoken [optional] (tokenizer used for pre-trained language models)
  • huggingface_hub [optional] (for loading pre-trained model checkpoints from the HuggingFace Hub; see the sketch below)
  • safetensors [optional] (also for loading pre-trained model checkpoints from the HuggingFace Hub)
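
For illustration, these optional dependencies can be used to pull a checkpoint file from the HuggingFace Hub. A minimal sketch follows; the repository id and filename are placeholders, not actual released checkpoints, so consult the documentation for the published checkpoints.

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# placeholder repository id and filename -- substitute a real released checkpoint
ckpt_path = hf_hub_download(repo_id="awni00/DAT-example", filename="model.safetensors")
state_dict = load_file(ckpt_path)   # dict mapping parameter names to tensors

# the weights can then be loaded into a model built with a matching configuration
# via load_state_dict (model construction is shown in the Usage Examples below)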

Summary of Paper

The Transformer architecture can be understood as an instantiation of a broader computational paradigm implementing a form of neural message-passing that iterates between two operations: 1) information retrieval (self-attention), and 2) local processing (feedforward block). To process a sequence of objects $x_1, \ldots, x_n$, this general neural message-passing paradigm has the form

$$
\begin{align*}
x_i &\gets \mathrm{Aggregate}\big(x_i, \{m_{j \to i}\}_{j=1}^n\big) \\
x_i &\gets \mathrm{Process}(x_i).
\end{align*}
$$

In the case of Transformers, the self-attention mechanism can be seen as sending messages from object $j$ to object $i$ that are encodings of the sender's features, with the message from sender $j$ to receiver $i$ given by $m_{j \to i} = \phi_v(x_j)$. These messages are then aggregated according to some selection criterion based on the receiver's features, typically given by the softmax attention scores.
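
In this notation, standard self-attention can be written as follows, using $\phi_q, \phi_k, \phi_v$ for the query, key, and value maps:

$$ x_i \gets \sum_{j=1}^{n} \alpha_{ij}\, \phi_v(x_j), \qquad \alpha_{ij} = \mathrm{softmax}_{j}\big(\langle \phi_q(x_i), \phi_k(x_j) \rangle\big), $$

so each message depends only on the sender's features, and the receiver only controls how the messages are weighted.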

We posit that two types of information are essential under this general computational paradigm: 1) sensory information describing the features and attributes of individual objects, and 2) relational information about the relationships between objects. The standard attention mechanism of Transformers naturally encodes the former, but does not explicitly encode the latter.

In this paper, we propose Relational Attention as a novel attention mechanism that enables routing of relational information between objects. We then introduce Dual Attention, a variant of multi-head attention combining two distinct attention mechanisms: 1) standard Self-Attention for routing sensory information, and 2) Relational Attention for routing relational information. This in turn defines an extension of the Transformer architecture with an explicit ability to reason over both types of information.
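
At a high level (the paper gives the precise formulation and several variants), relational attention replaces the encoding of the sender's features with a message pairing a learned relation between receiver and sender with a symbol identifying the sender, schematically:

$$ x_i \gets \sum_{j=1}^{n} \alpha_{ij} \big( r(x_i, x_j)\, W_r + s_j\, W_s \big), $$

where $r(x_i, x_j)$ is a vector of learned relations between the two objects, $s_j$ is a symbol assigned to sender $j$ by one of the symbol assignment mechanisms described below, and $\alpha_{ij}$ are attention scores.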

Outline of Main Modules

  • relational_attention.py: This module implements Relational Attention, an attention mechanism for routing relational information between objects.
  • symbol_retrieval.py: This module implements different symbol assignment mechanisms used in relational attention, including symbolic attention, positional symbols, and position-relative symbols.
  • dual_attention.py: This module implements Dual Attention, a variant of multi-head attention combining two distinct attention mechanisms: standard Self-Attention for routing sensory information and Relational Attention for routing relational information.
  • dual_attn_blocks.py: This module implements Dual Attention variants of encoder and decoder Transformer blocks, which are used to build language models, seq2seq models, vision models, etc.
  • transformer_blocks.py: This module implements standard Transformer encoder and decoder blocks, and is used as a baseline in our experiments.
  • language_models.py: This module implements a Dual Attention Transformer language model (as well as a standard Transformer language model as a baseline).
  • seq2seq_models.py: This module implements a seq2seq encoder-decoder Dual Attention Transformer.
  • vision_models.py: This module implements a Vision Dual Attention Transformer model, in the style of a Vision Transformer (i.e., image is split up into patches and fed to an encoder).

Usage Examples

All layers and models are implemented in PyTorch as nn.Module objects, so they are compatible with typical PyTorch workflows, training code, and packages such as PyTorch Lightning and torchinfo.

The following code demonstrates the creation of a Dual Attention Transformer Language Model.

import torch
from dual_attention.language_models import DualAttnTransformerLM

dat_lm = DualAttnTransformerLM(
    vocab_size=32_000,    # vocabulary size
    d_model=512,          # model dimension
    n_layers=6,           # number of layers
    n_heads_sa=4,         # number of self-attention heads
    n_heads_ra=4,         # number of relational attention heads
    dff=2048,             # feedforward intermediate dimension
    dropout_rate=0.1,     # dropout rate
    activation='swiglu',  # activation function of feedforward block
    norm_first=True,      # whether to use pre-norm or post-norm
    max_block_size=1024,  # max context length
    symbol_retrieval='symbolic_attention', # type of symbol assignment mechanism
    symbol_retrieval_kwargs=dict(d_model=512, n_heads=4, n_symbols=512),
    pos_enc_type='RoPE'   # type of positional encoding to use
)

idx = torch.randint(0, 32_000, (1, 128+1))
x, y = idx[:, :-1], idx[:, 1:]
logits, loss = dat_lm(x, y)
logits # shape: (1, 128, 32000)
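
Because the forward pass returns both the logits and the loss when targets are provided, the model plugs directly into an ordinary PyTorch training loop. Below is a minimal sketch reusing the dat_lm model from above; the optimizer choice and learning rate are illustrative, not recommendations.

optimizer = torch.optim.AdamW(dat_lm.parameters(), lr=3e-4)  # illustrative choice

dat_lm.train()
for step in range(10):                           # a few steps on random token data
    idx = torch.randint(0, 32_000, (1, 128 + 1))
    x, y = idx[:, :-1], idx[:, 1:]               # inputs and next-token targets

    logits, loss = dat_lm(x, y)                  # forward pass returns (logits, loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()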

The following code demonstrates the creation of a Vision Dual Attention Transformer model.

import torch
from dual_attention.vision_models import VisionDualAttnTransformer

img_shape = (3, 224, 224)
patch_size = (16, 16)
n_patches = (img_shape[1] // patch_size[0]) * (img_shape[2] // patch_size[1])


dat_vision = VisionDualAttnTransformer(
    image_shape=img_shape,     # shape of input image
    patch_size=patch_size,     # size of patch
    num_classes=1000,          # number of classes
    d_model=512,               # model dimension
    n_layers=6,                # number of layers
    n_heads_sa=4,              # number of self-attention heads
    n_heads_ra=4,              # number of relational attention heads
    dff=2048,                  # feedforward intermediate dimension
    dropout_rate=0.1,          # dropout rate
    activation='swiglu',       # activation function of feedforward block
    norm_first=True,           # whether to use pre-norm or post-norm
    symbol_retrieval='position_relative', # type of symbol assignment mechanism
    symbol_retrieval_kwargs=dict(symbol_dim=512, max_rel_pos=n_patches+1),
    ra_kwargs=dict(symmetric_rels=True, use_relative_positional_symbols=True),
    pool='cls',                # type of pooling (class token)
)

img = torch.randn(1, *img_shape)
logits = dat_vision(img)
logits.shape # shape: (1, 1000)
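
At inference time, predictions can be read off the logits in the usual way. A minimal sketch reusing the dat_vision model and img tensor from above:

dat_vision.eval()
with torch.no_grad():
    logits = dat_vision(img)          # (1, 1000) class logits
    probs = logits.softmax(dim=-1)    # class probabilities
    pred = probs.argmax(dim=-1)       # predicted class index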

If you have questions, feel free to file an issue or send an email.

Citation

If you use natbib or bibtex, please use the following citation (as provided by Google Scholar).

@article{altabaa2024disentangling,
    title={Disentangling and Integrating Relational and Sensory Information in Transformer Architectures},
    author={Awni Altabaa and John Lafferty},
    year={2024},
    journal={arXiv preprint arXiv:2405.16727}
}

If you use biblatex, please use the following citation (as provided by arXiv).

@misc{altabaa2024disentangling,
    title={Disentangling and Integrating Relational and Sensory Information in Transformer Architectures},
    author={Awni Altabaa and John Lafferty},
    year={2024},
    eprint={2405.16727},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

dual_attention-0.0.7.tar.gz (40.4 kB)

Uploaded Source

Built Distribution

dual_attention-0.0.7-py3-none-any.whl (46.8 kB)

Uploaded Python 3

File details

Details for the file dual_attention-0.0.7.tar.gz.

File metadata

  • Download URL: dual_attention-0.0.7.tar.gz
  • Upload date:
  • Size: 40.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for dual_attention-0.0.7.tar.gz

  • SHA256: f9eec6e0d1eb56758558e1043421fa5106062a276d68ec86e8ad8d2e01055a3d
  • MD5: e82e63b3d0a648569a8a6d5c74ae1203
  • BLAKE2b-256: 85f2e60534e28d5a74975bcba312bd6ea1f2e5cb2f10f3acfd62505872fcac16


File details

Details for the file dual_attention-0.0.7-py3-none-any.whl.

File hashes

Hashes for dual_attention-0.0.7-py3-none-any.whl

  • SHA256: 8984d5afff45daa0e9799643b5e8223b08f893b214693ca0ee72cec3bdae7043
  • MD5: 008a23f0b4fa5bd6ea25d0b923db235d
  • BLAKE2b-256: 1a38bb59d0689145dca08f000f296d9b2b775c6393d02ff51f4890618d8d2531

