Area Attention
PyTorch implementation of Area Attention [1]. This module lets you attend over areas of the memory, where each area contains a group of items that are either spatially or temporally adjacent. A TensorFlow implementation is also available.
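To make the idea concrete, here is a minimal sketch in plain Python (not this package's API; helper names and scalar keys/values are illustrative) of how areas are formed over a 1-D memory: every run of up to max_width adjacent items becomes an area, its key is pooled with max (as in area_key_mode='max' below) and its value with mean (as in area_value_mode='mean'):

```python
# Illustrative sketch with hypothetical helper names: enumerate all areas of
# a 1-D memory and pool each area's key (max) and value (mean). Real keys and
# values are vectors; scalars are used here only for brevity.

def enumerate_areas(length, max_width):
    """All (start, width) spans of adjacent items, for width 1..max_width."""
    return [(start, width)
            for width in range(1, max_width + 1)
            for start in range(length - width + 1)]

def area_keys_and_values(keys, values, max_width):
    """Pool each area's key with max and its value with mean."""
    area_keys, area_values = [], []
    for start, width in enumerate_areas(len(keys), max_width):
        area_keys.append(max(keys[start:start + width]))
        area_values.append(sum(values[start:start + width]) / width)
    return area_keys, area_values

# A memory of 4 items with areas up to width 2 yields 4 + 3 = 7 areas:
# the 4 single items plus the 3 adjacent pairs.
ks, vs = area_keys_and_values([1, 4, 2, 3], [10, 20, 30, 40], max_width=2)
```

Attention is then computed over these pooled area keys and values instead of (or in addition to) the original items, so a single query can softly select a multi-item span in one step.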
Setup
$ pip install area_attention
Usage
Single-head Area Attention:
import torch
from area_attention import AreaAttention
area_attention = AreaAttention(
key_query_size=32,
area_key_mode='max',
area_value_mode='mean',
max_area_height=2,
max_area_width=2,
memory_height=4,
memory_width=4,
dropout_rate=0.2,
top_k_areas=0
)
q = torch.rand(4, 8, 32)
k = torch.rand(4, 16, 32)
v = torch.rand(4, 16, 64)
x = area_attention(q, k, v)
x  # torch.Tensor with shape (4, 8, 64): batch, queries, value size
Multi-head Area Attention:
import torch
from area_attention import AreaAttention, MultiHeadAreaAttention
area_attention = AreaAttention(
key_query_size=32,
area_key_mode='max',
area_value_mode='mean',
max_area_height=2,
max_area_width=2,
memory_height=4,
memory_width=4,
dropout_rate=0.2,
top_k_areas=0
)
multi_head_area_attention = MultiHeadAreaAttention(
area_attention=area_attention,
num_heads=2,
key_query_size=32,
key_query_size_hidden=32,
value_size=64,
value_size_hidden=64
)
q = torch.rand(4, 8, 32)
k = torch.rand(4, 16, 32)
v = torch.rand(4, 16, 64)
x = multi_head_area_attention(q, k, v)
x  # torch.Tensor with shape (4, 8, 64): batch, queries, value size
Unit tests
$ python -m pytest tests
Bibliography
[1] Li, Yang, et al. "Area attention." International Conference on Machine Learning. PMLR, 2019.
Citations
@inproceedings{li2019area,
title={Area attention},
author={Li, Yang and Kaiser, Lukasz and Bengio, Samy and Si, Si},
booktitle={International Conference on Machine Learning},
pages={3846--3855},
year={2019},
organization={PMLR}
}