
Area Attention

PyTorch implementation of Area Attention [1]. This module lets a query attend to areas of the memory, where each area contains a group of items that are spatially or temporally adjacent. The original TensorFlow implementation can be found here.
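As a rough illustration of the idea (a standalone sketch in plain PyTorch, not the library's actual internals), the snippet below pools every 2x2 area of a 4x4 memory grid into a single key by max (cf. area_key_mode='max') and a single value by mean (cf. area_value_mode='mean'). The module itself also enumerates the smaller area shapes (1x1, 1x2, 2x1) up to the configured maxima and attends over all of them jointly.

import torch
import torch.nn.functional as F

batch, height, width, dim = 1, 4, 4, 8
memory = torch.rand(batch, height, width, dim)

# Treat the memory as an image: (batch, dim, height, width).
grid = memory.permute(0, 3, 1, 2)

# One 2x2 area per valid top-left corner (stride 1): max-pooled keys
# and mean-pooled values.
area_keys = F.max_pool2d(grid, kernel_size=2, stride=1)    # (1, 8, 3, 3)
area_values = F.avg_pool2d(grid, kernel_size=2, stride=1)  # (1, 8, 3, 3)

# Flatten the areas back into a sequence that attention can index.
area_keys = area_keys.flatten(2).transpose(1, 2)      # (1, 9, 8)
area_values = area_values.flatten(2).transpose(1, 2)  # (1, 9, 8)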

Setup

$ pip install area_attention

Usage

Single-head Area Attention:

import torch
from area_attention import AreaAttention

area_attention = AreaAttention(
    key_query_size=32,      # dimensionality of queries and keys
    area_key_mode='max',    # how each area is pooled into a single key
    area_value_mode='mean', # how each area is pooled into a single value
    max_area_height=2,      # maximum area height, in memory-grid rows
    max_area_width=2,       # maximum area width, in memory-grid columns
    memory_height=4,        # the k/v memory is treated as a 4x4 grid
    memory_width=4,
    dropout_rate=0.2,
    top_k_areas=0           # if > 0, attend only over the top-k scoring areas
)
q = torch.rand(4, 8, 32)   # (batch, query length, key_query_size)
k = torch.rand(4, 16, 32)  # (batch, memory_height * memory_width, key_query_size)
v = torch.rand(4, 16, 64)  # (batch, memory_height * memory_width, value size)
x = area_attention(q, k, v)  # torch.Tensor with shape (4, 8, 64)
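The same module also covers purely temporal (1D) memories. The configuration below is an assumption about how to express that with this API, treating the 16-step memory as a one-row grid; it is not documented upstream, so adjust if the actual API differs:

# Hypothetical 1D (temporal) setup: a 1 x 16 memory grid, so areas can
# only span adjacent timesteps. Parameter interpretation is an assumption.
temporal_attention = AreaAttention(
    key_query_size=32,
    area_key_mode='max',
    area_value_mode='mean',
    max_area_height=1,   # no spatial extent
    max_area_width=3,    # areas of up to 3 adjacent timesteps
    memory_height=1,
    memory_width=16,
    dropout_rate=0.0,
    top_k_areas=0
)
x = temporal_attention(q, k, v)  # reusing q, k, v from above; expected shape (4, 8, 64)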

Multi-head Area Attention:

import torch
from area_attention import AreaAttention, MultiHeadAreaAttention

area_attention = AreaAttention(
    key_query_size=32,
    area_key_mode='max',
    area_value_mode='mean',
    max_area_height=2,
    max_area_width=2,
    memory_height=4,
    memory_width=4,
    dropout_rate=0.2,
    top_k_areas=0
)
multi_head_area_attention = MultiHeadAreaAttention(
    area_attention=area_attention,  # single-head module applied per head
    num_heads=2,
    key_query_size=32,         # input size of queries and keys
    key_query_size_hidden=32,  # per-head projection size for queries and keys
    value_size=64,             # input size of values
    value_size_hidden=64       # per-head projection size for values
)
q = torch.rand(4, 8, 32)
k = torch.rand(4, 16, 32)
v = torch.rand(4, 16, 64)
x = multi_head_area_attention(q, k, v)  # torch.Tensor with shape (4, 8, 64)
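The hidden sizes configure the per-head projections. As a mental model only (an assumption mirroring standard multi-head attention, not this library's verified internals), the wrapper behaves roughly like the sketch below:

import torch
import torch.nn as nn

class MultiHeadSketch(nn.Module):
    """Hypothetical illustration of the multi-head pattern; the real
    MultiHeadAreaAttention may differ in details."""
    def __init__(self, area_attention, num_heads, key_query_size,
                 key_query_size_hidden, value_size, value_size_hidden):
        super().__init__()
        self.area_attention = area_attention
        self.num_heads = num_heads
        self.kq_hidden = key_query_size_hidden
        self.v_hidden = value_size_hidden
        # Per-head input projections and a final output projection.
        self.q_proj = nn.Linear(key_query_size, num_heads * key_query_size_hidden)
        self.k_proj = nn.Linear(key_query_size, num_heads * key_query_size_hidden)
        self.v_proj = nn.Linear(value_size, num_heads * value_size_hidden)
        self.out_proj = nn.Linear(num_heads * value_size_hidden, value_size)

    def forward(self, q, k, v):
        b, lq, lk, h = q.size(0), q.size(1), k.size(1), self.num_heads

        def split(x, length, dim):
            # (b, length, h * dim) -> (b * h, length, dim)
            return x.view(b, length, h, dim).transpose(1, 2).reshape(b * h, length, dim)

        qh = split(self.q_proj(q), lq, self.kq_hidden)
        kh = split(self.k_proj(k), lk, self.kq_hidden)
        vh = split(self.v_proj(v), lk, self.v_hidden)
        out = self.area_attention(qh, kh, vh)  # (b * h, lq, v_hidden)
        # Recombine the heads and project back to value_size.
        out = out.reshape(b, h, lq, self.v_hidden).transpose(1, 2).reshape(b, lq, h * self.v_hidden)
        return self.out_proj(out)              # (b, lq, value_size)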

Unit tests

$ python -m pytest tests

Bibliography

[1] Li, Yang, et al. "Area attention." International Conference on Machine Learning. PMLR, 2019.

Citations

@inproceedings{li2019area,
  title={Area attention},
  author={Li, Yang and Kaiser, Lukasz and Bengio, Samy and Si, Si},
  booktitle={International Conference on Machine Learning},
  pages={3846--3855},
  year={2019},
  organization={PMLR}
}
