Python package implementing the Dual Attention Transformer (DAT), as proposed by the paper "Disentangling and Integrating Relational and Sensory Information in Transformer Architectures" by Awni Altabaa, John Lafferty.
Project description
Dual Attention Transformer
This repository contains a Python package (hosted on PyPI) implementing the Dual Attention Transformer (DAT), as proposed by the paper Disentangling and Integrating Relational and Sensory Information in Transformer Architectures by Awni Altabaa, John Lafferty.
Abstract. The Transformer architecture processes sequences by implementing a form of neural message-passing that consists of iterative information retrieval (attention), followed by local processing (position-wise MLP). Two types of information are essential under this general computational paradigm: "sensory" information about individual objects, and "relational" information describing the relationships between objects. Standard attention naturally encodes the former, but does not explicitly encode the latter. In this paper, we present an extension of Transformers where multi-head attention is augmented with two distinct types of attention heads, each routing information of a different type. The first type is the standard attention mechanism of Transformers, which captures object-level features, while the second type is a novel attention mechanism we propose to explicitly capture relational information. The two types of attention heads each possess different inductive biases, giving the resulting architecture greater efficiency and versatility. The promise of this approach is demonstrated empirically across a range of tasks.
Please see the project webpage for an overview of the paper and its main results.
Installation
Via pip
pip install dual_attention
From source
git clone https://github.com/Awni00/dual-attention.git
cd dual-attention
pip install .
Requirements
torch
einops
tiktoken [optional] (tokenizer used for pre-trained language models)
huggingface_hub [optional] (for loading pre-trained model checkpoints from HF)
safetensors [optional] (again, for loading pre-trained model checkpoints from HF)
Summary of Paper
The Transformer architecture can be understood as an instantiation of a broader computational paradigm implementing a form of neural message-passing that iterates between two operations: 1) information retrieval (self-attention), and 2) local processing (feedforward block). To process a sequence of objects $x_1, \ldots, x_n$, this general neural message-passing paradigm has the form
$$
\begin{align*}
x_i &\gets \mathrm{Aggregate}\big(x_i, {\{m_{j \to i}\}}_{j=1}^n\big) \\
x_i &\gets \mathrm{Process}(x_i).
\end{align*}
$$
In the case of Transformers, the self-attention mechanism can be seen as sending messages from object $j$ to object $i$ that are encodings of the sender's features, with the message from sender $j$ to receiver $i$ given by $m_{j \to i} = \phi_v(x_j)$. These messages are then aggregated according to some selection criterion based on the receiver's features, typically given by the softmax attention scores.
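For concreteness, here is a minimal sketch of this message-passing view of standard self-attention (illustrative only: single head, no batch dimension, plain matrices standing in for $\phi_q, \phi_k, \phi_v$; this is not the package's implementation).

import torch
import torch.nn.functional as F

def self_attention_as_message_passing(x, W_q, W_k, W_v):
    # messages are encodings of the sender's features: m_{j -> i} = phi_v(x_j)
    messages = x @ W_v                                                           # (n, d_v)
    # selection criterion based on the receiver's features: softmax attention scores
    scores = F.softmax((x @ W_q) @ (x @ W_k).T / W_k.shape[-1] ** 0.5, dim=-1)   # (n, n)
    # Aggregate(x_i, {m_{j -> i}}): attention-weighted sum of the messages
    return scores @ messages

x = torch.randn(8, 64)                                     # n = 8 objects of dimension 64
W_q, W_k, W_v = (torch.randn(64, 32) for _ in range(3))    # random stand-ins for learned maps
out = self_attention_as_message_passing(x, W_q, W_k, W_v)  # (8, 32)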
We posit that two types of information are essential under this general computational paradigm: 1) sensory information describing the features and attributes of individual objects, and 2) relational information about the relationships between objects. The standard attention mechanism of Transformers naturally encodes the former, but does not explicitly encode the latter.
In this paper, we propose Relational Attention as a novel attention mechanism that enables routing of relational information between objects. We then introduce Dual Attention, a variant of multi-head attention combining two distinct attention mechanisms: 1) standard Self-Attention for routing sensory information, and 2) Relational Attention for routing relational information. This in turn defines an extension of the Transformer architecture with an explicit ability to reason over both types of information.
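As a rough schematic of the second head type (a sketch of the idea only; the package's actual parameterization lives in relational_attention.py, and the symbols come from the mechanisms in symbol_retrieval.py): a relational head attends with the usual softmax scores, but the message routed from sender to receiver is a relation vector between the two objects together with a symbol identifying the sender, rather than the sender's features. All names and dimensions below are hypothetical.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d, d_r, d_head = 8, 64, 4, 32      # sequence length, model dim, relation dim, head dim (hypothetical)
x = torch.randn(n, d)                 # input objects
s = torch.randn(n, d_head)            # symbols assigned to each object (cf. symbol_retrieval.py)

# randomly initialized stand-ins for learned parameters
Wq_a, Wk_a = torch.randn(d, d_head), torch.randn(d, d_head)            # attention-score maps
Wq_r, Wk_r = torch.randn(d_r, d, d_head), torch.randn(d_r, d, d_head)  # relation maps
Wr, Ws = torch.randn(d_r, d_head), torch.randn(d_head, d_head)         # message projections

# where to attend: the usual softmax selection criterion
scores = F.softmax((x @ Wq_a) @ (x @ Wk_a).T / d_head ** 0.5, dim=-1)  # (n, n)

# relation vectors r(x_i, x_j): inner products under d_r learned query/key maps
q_r = torch.einsum('nd,ldh->nlh', x, Wq_r)       # (n, d_r, d_head)
k_r = torch.einsum('nd,ldh->nlh', x, Wk_r)       # (n, d_r, d_head)
rel = torch.einsum('ilh,jlh->ijl', q_r, k_r)     # (n, n, d_r); rel[i, j] relates receiver i to sender j

# message from sender j to receiver i: relation vector plus the sender's symbol
messages = rel @ Wr + s @ Ws                     # (n, n, d_head); symbols broadcast over receivers
out = torch.einsum('ij,ijh->ih', scores, messages)   # attention-weighted aggregation per receiver

In Dual Attention, heads of this relational type run alongside standard self-attention heads, and their outputs are combined as in ordinary multi-head attention.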
Outline of Main Modules
- relational_attention.py: implements Relational Attention, an attention mechanism for routing relational information between objects.
- symbol_retrieval.py: implements the different symbol assignment mechanisms used in relational attention, including symbolic attention, positional symbols, and position-relative symbols.
- dual_attention.py: implements Dual Attention, a variant of multi-head attention combining two distinct attention mechanisms: standard Self-Attention for routing sensory information and Relational Attention for routing relational information.
- dual_attn_blocks.py: implements Dual Attention variants of encoder and decoder Transformer blocks, which are used to build language models, seq2seq models, vision models, etc.
- transformer_blocks.py: implements standard Transformer encoder and decoder blocks, used as a baseline in our experiments.
- language_models.py: implements a Dual Attention Transformer language model (as well as a standard Transformer language model as a baseline).
- seq2seq_models.py: implements a seq2seq encoder-decoder Dual Attention Transformer.
- vision_models.py: implements a Vision Dual Attention Transformer model, in the style of a Vision Transformer (i.e., the image is split into patches and fed to an encoder).
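Assuming the standard package layout, the files above are importable as submodules of the dual_attention package. A small navigation sketch (only the two model classes used in the usage examples below are named; class names inside the other modules are not listed here):

import dual_attention.relational_attention   # Relational Attention layer
import dual_attention.symbol_retrieval       # symbol assignment mechanisms
import dual_attention.dual_attn_blocks       # Dual Attention encoder/decoder blocks
from dual_attention.language_models import DualAttnTransformerLM
from dual_attention.vision_models import VisionDualAttnTransformer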
Usage Examples
All layers and models are implemented in PyTorch as nn.Module objects, so they are compatible with typical PyTorch workflows, training code, and packages like PyTorch Lightning, torchinfo, etc.
The following code demonstrates the creation of a Dual Attention Transformer Language Model.
import torch
from dual_attention.language_models import DualAttnTransformerLM
dat_lm = DualAttnTransformerLM(
vocab_size=32_000, # vocabulary size
d_model=512, # model dimension
n_layers=6, # number of layers
n_heads_sa=4, # number of self-attention heads
n_heads_ra=4, # number of relational attention heads
dff=2048, # feedforward intermediate dimension
dropout_rate=0.1, # dropout rate
activation='swiglu', # activation function of feedforward block
norm_first=True, # whether to use pre-norm or post-norm
max_block_size=1024, # max context length
symbol_retrieval='symbolic_attention', # type of symbol assignment mechanism
symbol_retrieval_kwargs=dict(d_model=512, n_heads=4, n_symbols=512),
pos_enc_type='RoPE' # type of positional encoding to use
)
idx = torch.randint(0, 32_000, (1, 128+1))
x, y = idx[:, :-1], idx[:, 1:]
logits, loss = dat_lm(x, y)
logits # shape: (1, 128, 32000)
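Since the model is an ordinary nn.Module and returns the loss alongside the logits when targets are passed, it drops into a standard PyTorch training loop. A minimal sketch (the optimizer choice and hyperparameters are placeholders):

optimizer = torch.optim.AdamW(dat_lm.parameters(), lr=3e-4)  # placeholder hyperparameters

dat_lm.train()
for step in range(10):                          # toy loop on random token batches
    idx = torch.randint(0, 32_000, (8, 128 + 1))
    x, y = idx[:, :-1], idx[:, 1:]
    _, loss = dat_lm(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()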
The following code demonstrates the creation of a Vision Dual Attention Transformer model.
import torch
from dual_attention.vision_models import VisionDualAttnTransformer
img_shape = (3, 224, 224)
patch_size = (16, 16)
n_patches = (img_shape[1] // patch_size[0]) * (img_shape[2] // patch_size[1])
dat_vision = VisionDualAttnTransformer(
image_shape=img_shape, # shape of input image
patch_size=patch_size, # size of patch
num_classes=1000, # number of classes
d_model=512, # model dimension
n_layers=6, # number of layers
n_heads_sa=4, # number of self-attention heads
n_heads_ra=4, # number of relational attention heads
dff=2048, # feedforward intermediate dimension
dropout_rate=0.1, # dropout rate
activation='swiglu', # activation function of feedforward block
norm_first=True, # whether to use pre-norm or post-norm
symbol_retrieval='position_relative', # type of symbol assignment mechanism
symbol_retrieval_kwargs=dict(symbol_dim=512, max_rel_pos=n_patches+1),
ra_kwargs=dict(symmetric_rels=True, use_relative_positional_symbols=True),
pool='cls', # type of pooling (class token)
)
img = torch.randn(1, *img_shape)
logits = dat_vision(img)
logits.shape # shape: (1, 1000)
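The vision model returns class logits, so a standard classification loss applies directly. A minimal sketch with placeholder labels:

import torch.nn.functional as F

labels = torch.randint(0, 1000, (1,))            # placeholder label for the single image
loss = F.cross_entropy(dat_vision(img), labels)
loss.backward()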
If you have questions, feel free to file an issue or send an email.
Citation
If you use natbib or bibtex, please use the following citation (as provided by Google Scholar).
@article{altabaa2024disentangling,
title={Disentangling and Integrating Relational and Sensory Information in Transformer Architectures},
author={Awni Altabaa and John Lafferty},
year={2024},
journal={arXiv preprint arXiv:2402.08856}
}
If you use biblatex, please use the following citation (as provided by arXiv).
@misc{altabaa2024disentangling,
title={Disentangling and Integrating Relational and Sensory Information in Transformer Architectures},
author={Awni Altabaa and John Lafferty},
year={2024},
eprint={2405.16727},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Download files
File details
Details for the file dual_attention-0.0.7.tar.gz (source distribution).
File metadata
- Download URL: dual_attention-0.0.7.tar.gz
- Upload date:
- Size: 40.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | f9eec6e0d1eb56758558e1043421fa5106062a276d68ec86e8ad8d2e01055a3d
MD5 | e82e63b3d0a648569a8a6d5c74ae1203
BLAKE2b-256 | 85f2e60534e28d5a74975bcba312bd6ea1f2e5cb2f10f3acfd62505872fcac16
File details
Details for the file dual_attention-0.0.7-py3-none-any.whl (built distribution).
File metadata
- Download URL: dual_attention-0.0.7-py3-none-any.whl
- Upload date:
- Size: 46.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8984d5afff45daa0e9799643b5e8223b08f893b214693ca0ee72cec3bdae7043
MD5 | 008a23f0b4fa5bd6ea25d0b923db235d
BLAKE2b-256 | 1a38bb59d0689145dca08f000f296d9b2b775c6393d02ff51f4890618d8d2531