PyTorch implementation of Area Attention
Area Attention
PyTorch implementation of Area Attention [1]. This module lets a model attend to areas of the memory, where each area is a group of items that are either spatially or temporally adjacent. A TensorFlow implementation is also available.
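To make the notion of an "area" concrete, the framework-free sketch below enumerates every rectangular area up to `max_area_height` x `max_area_width` in a `memory_height` x `memory_width` grid and pools the item vectors inside one area. It assumes the memory items are laid out row-major on the grid and that the `'max'`/`'mean'` modes are element-wise reductions. The helper names (`to_grid`, `enumerate_areas`, `pool_area`) are hypothetical, and the library's internal ordering of areas may differ.

```python
from statistics import mean


def to_grid(flat_items, memory_height, memory_width):
    """Reshape a flat, row-major list of memory items into a height x width grid."""
    if len(flat_items) != memory_height * memory_width:
        raise ValueError("number of items must equal memory_height * memory_width")
    return [flat_items[r * memory_width:(r + 1) * memory_width]
            for r in range(memory_height)]


def enumerate_areas(memory_height, memory_width, max_area_height, max_area_width):
    """List every rectangular area as (top, left, height, width)."""
    areas = []
    for h in range(1, max_area_height + 1):
        for w in range(1, max_area_width + 1):
            for top in range(memory_height - h + 1):
                for left in range(memory_width - w + 1):
                    areas.append((top, left, h, w))
    return areas


def pool_area(grid, area, mode):
    """Reduce the item vectors inside one area element-wise ('mean', 'max' or 'sum')."""
    top, left, h, w = area
    items = [grid[r][c] for r in range(top, top + h) for c in range(left, left + w)]
    reducer = {"mean": mean, "max": max, "sum": sum}[mode]
    return [reducer(column) for column in zip(*items)]


grid = to_grid([[1.0], [2.0], [3.0], [4.0]], memory_height=2, memory_width=2)
print(len(enumerate_areas(4, 4, 2, 2)))       # 49
print(pool_area(grid, (0, 0, 2, 2), "mean"))  # [2.5]
print(pool_area(grid, (0, 0, 2, 2), "max"))   # [4.0]
```

For a 4x4 memory with a maximum area of 2x2, the count is 16 (1x1) + 12 (1x2) + 12 (2x1) + 9 (2x2) = 49 areas, so the number of candidates a query attends to grows well beyond the 16 raw memory items.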
Setup
$ pip install area_attention
Usage
Single-head Area Attention:
import torch
from area_attention import AreaAttention
area_attention = AreaAttention(
    key_query_size=32,
    area_key_mode='max',
    area_value_mode='mean',
    max_area_height=2,
    max_area_width=2,
    memory_height=4,
    memory_width=4,
    dropout_rate=0.2,
    top_k_areas=0
)
q = torch.rand(4, 8, 32)   # (batch, num_queries, key_query_size)
k = torch.rand(4, 16, 32)  # (batch, memory_height * memory_width, key_query_size)
v = torch.rand(4, 16, 64)  # (batch, memory_height * memory_width, value size)
x = area_attention(q, k, v)
x  # torch.Tensor with shape (4, 8, 64)
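The following plain-Python sketch shows what a single-head call computes for one query under the configuration above (`area_key_mode='max'`, `area_value_mode='mean'`). It is not the library's code: it ignores `dropout_rate` and `top_k_areas`, assumes the memory is flattened row-major, and assumes scaled dot-product scoring. The function name `area_attention_single_query` is hypothetical. In the batched PyTorch call, the same computation is repeated for every query in every batch element.

```python
import math


def area_attention_single_query(query, keys, values, memory_height, memory_width,
                                max_area_height, max_area_width):
    """Attend one query over all areas of a row-major memory grid (sketch)."""
    area_keys, area_values = [], []
    for h in range(1, max_area_height + 1):
        for w in range(1, max_area_width + 1):
            for top in range(memory_height - h + 1):
                for left in range(memory_width - w + 1):
                    # Flat row-major indices of the memory items covered by this area.
                    idx = [r * memory_width + c
                           for r in range(top, top + h)
                           for c in range(left, left + w)]
                    # area_key_mode='max': element-wise max over the member keys.
                    area_keys.append([max(col) for col in zip(*(keys[i] for i in idx))])
                    # area_value_mode='mean': element-wise mean over the member values.
                    area_values.append([sum(col) / len(idx)
                                        for col in zip(*(values[i] for i in idx))])
    # Scaled dot-product scores (the scaling is an assumption), then a stable softmax.
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in area_keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # The output is the attention-weighted sum of the area values.
    return [sum(wt * value[d] for wt, value in zip(weights, area_values))
            for d in range(len(area_values[0]))]


# With 1x1 areas only, this reduces to ordinary attention over the raw items.
print(area_attention_single_query([1.0], [[0.0], [0.0]], [[0.0], [4.0]], 1, 2, 1, 1))  # [2.0]
```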
Multi-head Area Attention:
import torch
from area_attention import AreaAttention, MultiHeadAreaAttention
area_attention = AreaAttention(
    key_query_size=32,
    area_key_mode='max',
    area_value_mode='mean',
    max_area_height=2,
    max_area_width=2,
    memory_height=4,
    memory_width=4,
    dropout_rate=0.2,
    top_k_areas=0
)
multi_head_area_attention = MultiHeadAreaAttention(
    area_attention=area_attention,
    num_heads=2,
    key_query_size=32,
    key_query_size_hidden=32,
    value_size=64,
    value_size_hidden=64
)
q = torch.rand(4, 8, 32)   # (batch, num_queries, key_query_size)
k = torch.rand(4, 16, 32)  # (batch, memory_height * memory_width, key_query_size)
v = torch.rand(4, 16, 64)  # (batch, memory_height * memory_width, value_size)
x = multi_head_area_attention(q, k, v)
x  # torch.Tensor with shape (4, 8, 64)
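A multi-head wrapper typically projects queries, keys and values, splits each projection into `num_heads` chunks, runs the attention once per chunk, and projects the concatenated result back to the output size. The plain-Python sketch below shows that generic pattern. It is an assumption that `MultiHeadAreaAttention` is arranged this way; the helpers (`linear`, `split_heads`, `merge_heads`, `multi_head`) and the weight layout are hypothetical, and `attend` stands in for the per-head `AreaAttention`.

```python
def linear(vector, weight):
    """Matrix-vector product; weight is a list of rows of shape (out_dim, in_dim)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weight]


def split_heads(vector, num_heads):
    """Split a projected vector of size num_heads * head_size into per-head chunks."""
    if len(vector) % num_heads:
        raise ValueError("projection size must be divisible by num_heads")
    size = len(vector) // num_heads
    return [vector[h * size:(h + 1) * size] for h in range(num_heads)]


def merge_heads(chunks):
    """Concatenate the per-head outputs back into one vector."""
    return [x for chunk in chunks for x in chunk]


def multi_head(query, keys, values, attend, num_heads, w_q, w_k, w_v, w_o):
    """Project, split into heads, attend per head, merge, and project the output."""
    q_heads = split_heads(linear(query, w_q), num_heads)
    k_heads = [split_heads(linear(k, w_k), num_heads) for k in keys]
    v_heads = [split_heads(linear(v, w_v), num_heads) for v in values]
    outputs = [attend(q_heads[h], [k[h] for k in k_heads], [v[h] for v in v_heads])
               for h in range(num_heads)]
    return linear(merge_heads(outputs), w_o)


def mean_attend(query, keys, values):
    """Toy per-head attention: uniform weights, i.e. the mean of the values."""
    return [sum(col) / len(values) for col in zip(*values)]


identity = [[1.0, 0.0], [0.0, 1.0]]
out = multi_head([1.0, 0.0], [[0.0, 1.0], [1.0, 0.0]], [[2.0, 4.0], [6.0, 8.0]],
                 mean_attend, 2, identity, identity, identity, identity)
print(out)  # [4.0, 6.0]
```

Each head sees only its own slice of the projected features, which is why the projection size has to be divisible by `num_heads`.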
Unit tests
$ python -m pytest tests
Bibliography
[1] Li, Yang, et al. "Area attention." International Conference on Machine Learning. PMLR, 2019.
Citations
@inproceedings{li2019area,
  title={Area attention},
  author={Li, Yang and Kaiser, Lukasz and Bengio, Samy and Si, Si},
  booktitle={International Conference on Machine Learning},
  pages={3846--3855},
  year={2019},
  organization={PMLR}
}