
Axial Positional Embedding


A type of positional embedding that is very effective when working with attention networks on multi-dimensional data, or for language models in general.

Install

$ pip install axial-positional-embedding

Usage

import torch
from axial_positional_embedding import AxialPositionalEmbedding

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    axial_shape = (64, 64),          # the axial shape multiplies out to the maximum sequence length allowed (64 * 64 = 4096)
    axial_dims = (256, 256)          # if not specified, each axis defaults to 'dim' and the axial embeddings are summed at the end. if specified, each axis gets the given dimension and the embeddings are concatenated; the concatenated dimensions must sum to `dim` (256 + 256 = 512)
)

tokens = torch.randn(1, 1024, 512)  # assume these are token embeddings
tokens = pos_emb(tokens) + tokens   # add positional embedding to token embeddings
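
If `axial_dims` is left unspecified, the behavior described in the comment above applies: each axis uses the full `dim` and the axial embeddings are summed. A minimal sketch of that default, assuming only that the argument can simply be omitted:

import torch
from axial_positional_embedding import AxialPositionalEmbedding

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    axial_shape = (64, 64)          # axial_dims omitted: each axis defaults to 'dim' and the embeddings are summed
)

tokens = torch.randn(1, 2048, 512)  # any sequence length up to 64 * 64 = 4096
tokens = pos_emb(tokens) + tokens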

A continuous version with better extrapolation ability (each axis is parameterized by a 2-layer MLP)

import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

pos_emb = ContinuousAxialPositionalEmbedding(
    dim = 512,
    num_axial_dims = 3
)

tokens = torch.randn(1, 8, 16, 32, 512) # say a video with 8 frames at 16 x 32 spatial resolution

axial_pos_emb = pos_emb((8, 16, 32)) # pass in the size from above

tokens = axial_pos_emb + tokens   # add positional embedding to token embeddings
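
Because the shape is passed at call time, the same module can be queried at sizes other than the one above, which is where the better extrapolation ability comes in. A small sketch reusing `pos_emb` from the example above; the 16-frame shape is just an illustration:

longer_tokens = torch.randn(1, 16, 16, 32, 512)  # same spatial resolution, but 16 frames instead of 8

axial_pos_emb = pos_emb((16, 16, 32))            # query the per-axis MLPs at the new shape

longer_tokens = axial_pos_emb + longer_tokens    # add positional embedding as before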

Citations

@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{ho2019axial,
    title   = {Axial Attention in Multidimensional Transformers},
    author  = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year    = {2019},
    archivePrefix = {arXiv}
}



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

axial_positional_embedding-0.3.12.tar.gz (6.3 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

axial_positional_embedding-0.3.12-py3-none-any.whl (6.7 kB)

Uploaded Python 3

File details

Details for the file axial_positional_embedding-0.3.12.tar.gz.

File metadata

File hashes

Hashes for axial_positional_embedding-0.3.12.tar.gz
Algorithm Hash digest
SHA256 10d98f342f8fa0d4d3d20b36ed27fb0fc63f1d398202b9165e07bc9d6130bc17
MD5 d0e0374f05e17859daaf31a4b660381a
BLAKE2b-256 12e6294c09aa82ccfb8ea18276cbce61964f35758f5de28f2677acf4fd71fcce

See more details on using hashes here.
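
As an illustration of how these digests can be used, here is a minimal sketch (plain Python, assuming the source distribution has been downloaded to the current directory) that checks the SHA256 digest listed above:

import hashlib

# compute the SHA256 digest of the downloaded source distribution
with open("axial_positional_embedding-0.3.12.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# compare against the digest published above
assert digest == "10d98f342f8fa0d4d3d20b36ed27fb0fc63f1d398202b9165e07bc9d6130bc17"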

File details

Details for the file axial_positional_embedding-0.3.12-py3-none-any.whl.

File metadata

File hashes

Hashes for axial_positional_embedding-0.3.12-py3-none-any.whl
Algorithm Hash digest
SHA256 f5708106b18903c3b99aced0bfa79d1ae6f708dae34f56e641e8c24dbc06b02f
MD5 df620fb2bc1b8d5ca7553015b8bc5a7b
BLAKE2b-256 e7f925753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063

See more details on using hashes here.
