Attention Free Transformer - PyTorch
aft-pytorch
Unofficial PyTorch implementation of the layers from An Attention Free Transformer by Zhai et al. [abs, pdf] of Apple Inc.
Installation
You can install aft-pytorch via pip:

```
pip install aft-pytorch
```
Usage
You can import the AFT-Full, AFT-Simple, or AFT-Local layers (as described in the paper) from the package like so:
AFTFull
```python
import torch
from aft_pytorch import AFTFull

layer = AFTFull(
    max_seqlen=20,
    dim=512,
    hidden_dim=64
)

# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x)  # [32, 10, 512]
```
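For intuition, here is a minimal, unoptimized sketch of the AFT-Full operation as defined in the paper; it is not the internals of this package, and `Q`, `K`, `V`, and `w` are toy placeholders rather than the layer's learned projections and position biases:

```python
import torch

# Toy tensors: B = batch, T = timesteps, d = hidden dim. In the real layer,
# Q, K, V come from learned linear projections of x and w is a learned
# T x T matrix of pairwise position biases.
B, T, d = 2, 10, 64
Q, K, V = torch.rand(B, T, d), torch.rand(B, T, d), torch.rand(B, T, d)
w = torch.zeros(T, T)  # stand-in for the learned pairwise position biases

# AFT-Full: Y_t = sigmoid(Q_t) * (sum_t' exp(K_t' + w_{t,t'}) * V_t')
#                               / (sum_t' exp(K_t' + w_{t,t'}))
weights = torch.exp(K.unsqueeze(1) + w.unsqueeze(0).unsqueeze(-1))  # (B, T, T, d)
num = (weights * V.unsqueeze(1)).sum(dim=2)                         # (B, T, d)
den = weights.sum(dim=2)                                            # (B, T, d)
Y = torch.sigmoid(Q) * num / den                                    # (B, T, d)
```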
AFTSimple
```python
import torch
from aft_pytorch import AFTSimple

layer = AFTSimple(
    max_seqlen=20,
    dim=512,
    hidden_dim=64
)

# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x)  # [32, 10, 512]
```
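AFT-Simple is AFT-Full with the pairwise position biases removed (w = 0), which collapses the context term into a single softmax-weighted pooling over time. A minimal sketch, again with toy tensors rather than this package's internals:

```python
import torch

B, T, d = 2, 10, 64
Q, K, V = torch.rand(B, T, d), torch.rand(B, T, d), torch.rand(B, T, d)

# With w = 0 the weighted sum no longer depends on the query position t,
# so it reduces to one softmax-weighted average of V along the time axis.
context = (torch.softmax(K, dim=1) * V).sum(dim=1, keepdim=True)  # (B, 1, d)
Y = torch.sigmoid(Q) * context                                    # (B, T, d)
```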
AFTLocal
```python
import torch
from aft_pytorch import AFTLocal

layer = AFTLocal(
    max_seqlen=20,
    dim=512,
    hidden_dim=64
)

# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x)  # [32, 10, 512]
```
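AFT-Local, as described in the paper, keeps the learned position biases of AFT-Full but zeroes them outside a local window; the rest of the computation matches the AFT-Full sketch above. A minimal sketch of just the masking, where the window size `s` is a hypothetical value and not a parameter shown in the constructor above:

```python
import torch

T, s = 10, 3                        # sequence length and a hypothetical window size
w_full = torch.randn(T, T)          # stand-in for the learned position biases
t = torch.arange(T)
mask = (t[:, None] - t[None, :]).abs() < s  # True inside the local window
w_local = w_full * mask             # biases outside the window are set to zero
```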
These layers are 'plug-and-play' with your existing networks/Transformers: you can swap out the self-attention layer for any of the layers in this package with minimal changes.
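As an illustration only, here is one way such a swap could look in a pre-norm encoder block; `EncoderBlock` and `ff_dim` are hypothetical names for this sketch, not part of this package:

```python
import torch
from torch import nn
from aft_pytorch import AFTFull

class EncoderBlock(nn.Module):
    def __init__(self, dim=512, max_seqlen=20, hidden_dim=64, ff_dim=2048):
        super().__init__()
        # AFTFull replaces the usual multi-head self-attention as the token mixer
        self.mixer = AFTFull(max_seqlen=max_seqlen, dim=dim, hidden_dim=hidden_dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, dim),
        )

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))  # token mixing via AFT instead of self-attention
        x = x + self.ff(self.norm2(x))     # standard position-wise feed-forward
        return x

block = EncoderBlock()
x = torch.rand(32, 10, 512)
y = block(x)  # [32, 10, 512]
```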
TODO
- Add full AFT architecture
- Add variants like `AFTConv`
Contributing
If you like this repo, please leave a star! If you have any amendments or suggestions, feel free to raise a PR or an issue.
Credits
```bibtex
@misc{attention-free-transformer,
  title  = {An Attention Free Transformer},
  author = {Shuangfei Zhai and Walter Talbott and Nitish Srivastava and Chen Huang and Hanlin Goh and Ruixiang Zhang and Josh Susskind},
  year   = {2021},
  url    = {https://arxiv.org/pdf/2105.14103.pdf}
}
```
License
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
aft_pytorch-0.2.2.tar.gz (3.8 kB)
Built Distribution
aft_pytorch-0.2.2-py3-none-any.whl
Hashes for aft_pytorch-0.2.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 863fcd6d52f82e5d876adfc4f0072ce97ea6793c7a121aca3babc1ddd752f08d
MD5 | cecb34d1b1fe6b1a89663c4ccd0241b1
BLAKE2b-256 | 329ac5f1f35a6425e36a050c0c8c1a533208be418edf27cbdce33e81f1e10e2b