Unofficial PyTorch implementation of TokenLearner by Google AI
tokenlearner-pytorch
Unofficial PyTorch implementation of TokenLearner by Ryoo et al. from Google AI (abs, pdf)
Installation
You can install TokenLearner via pip:
pip install tokenlearner-pytorch
Usage
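Import the TokenLearner class from the tokenlearner_pytorch package. As in the paper, the layer can be used inside a Vision Transformer (ViT), MLP-Mixer, or Video Vision Transformer (ViViT). It takes a channels-last batch of feature maps of shape [B, H, W, C] and returns S learned tokens of shape [B, S, C]: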
import torch
from tokenlearner_pytorch import TokenLearner
tklr = TokenLearner(S=8)  # learn S = 8 tokens per image
x = torch.rand(512, 32, 32, 3)  # [batch, height, width, channels]
y = tklr(x)  # [512, 8, 3]: 8 tokens with 3 channels each
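To make the shapes in this example less opaque, here is a minimal sketch of the idea behind TokenLearner as described in the paper: S learned spatial attention maps, each passed through a sigmoid, weight the input features, which are then averaged over the H x W grid, giving one C-dimensional token per map. The class name SketchTokenLearner, its hidden width, and the per-position MLP used to compute the maps (the paper uses convolutional layers for this step) are assumptions for illustration; this is not the tokenlearner_pytorch implementation.

```python
import torch
import torch.nn as nn


class SketchTokenLearner(nn.Module):
    """Hypothetical minimal TokenLearner: S spatial attention maps + weighted pooling.

    Input:  x of shape [B, H, W, C] (channels-last, as in the usage example)
    Output: tokens of shape [B, S, C]
    """

    def __init__(self, C: int, S: int, hidden: int = 64):
        super().__init__()
        # Per-position MLP that scores every pixel for each of the S tokens.
        # Assumption: swapped in for the paper's conv layers to keep this short.
        self.attn = nn.Sequential(
            nn.LayerNorm(C),
            nn.Linear(C, hidden),
            nn.GELU(),
            nn.Linear(hidden, S),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, C = x.shape
        # One sigmoid-gated attention map per output token: [B, H, W, S]
        maps = torch.sigmoid(self.attn(x))
        # Weighted spatial average: each token pools the C features under its map
        tokens = torch.einsum("bhws,bhwc->bsc", maps, x) / (H * W)
        return tokens


sketch = SketchTokenLearner(C=3, S=8)
x = torch.rand(4, 32, 32, 3)
tokens = sketch(x)
print(tokens.shape)  # torch.Size([4, 8, 3])
```

Because every token is an average over the whole grid, S stays fixed regardless of image size, which is what lets the downstream Transformer attend over 8 tokens instead of 32 x 32 = 1024 positions.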
You can also use TokenLearner and TokenFuser together with Multi-head Self-Attention, as done in the paper:
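import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser
mhsa = nn.MultiheadAttention(3, 1)  # embed_dim=3 (= C), num_heads=1
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)
x = torch.rand(512, 32, 32, 3)  # a batch of images: [B, H, W, C]
y = tklr(x)  # learned tokens: [512, 8, 3]
# nn.MultiheadAttention expects [seq_len, batch, embed] by default, so swap the
# batch and token axes; .view() would only reinterpret memory, not reorder it
y = y.transpose(0, 1)  # [8, 512, 3]
y, _ = mhsa(y, y, y)  # self-attention over the 8 tokens; ignore attn weights
y = y.transpose(0, 1)  # back to [512, 8, 3]
out = tkfr(y, x)  # [512, 32, 32, 3]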
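TokenFuser goes the other way: after the tokens have been processed (here by self-attention), it maps them back onto the H x W grid. The sketch below follows the paper's description as I understand it: mix information across the S tokens with a linear layer on the token axis, then let each spatial position compute sigmoid weights from the original feature map x and take a weighted sum of the processed tokens. The class name SketchTokenFuser and its constructor arguments are hypothetical and do not mirror the package's TokenFuser(H, W, C, S) signature; placing the residual connection inside the fuser is also an assumption.

```python
import torch
import torch.nn as nn


class SketchTokenFuser(nn.Module):
    """Hypothetical minimal TokenFuser: maps S tokens back onto an H x W grid.

    tokens: [B, S, C]    (e.g. the output of self-attention over learned tokens)
    x:      [B, H, W, C] (the feature map the tokens were learned from)
    returns [B, H, W, C]
    """

    def __init__(self, C: int, S: int):
        super().__init__()
        # Mixes information across the S tokens (linear layer on the token axis)
        self.mix_tokens = nn.Linear(S, S)
        # Per-position scores saying how much each token contributes there
        self.remap = nn.Linear(C, S)

    def forward(self, tokens: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Mix along the token axis: [B, S, C] -> [B, C, S] -> linear -> [B, S, C]
        tokens = self.mix_tokens(tokens.transpose(1, 2)).transpose(1, 2)
        # Sigmoid weights computed from the original features: [B, H, W, S]
        weights = torch.sigmoid(self.remap(x))
        # Each position becomes a weighted sum of the tokens: [B, H, W, C]
        fused = torch.einsum("bhws,bsc->bhwc", weights, tokens)
        # Residual connection back to the input feature map (assumption)
        return fused + x


fuser = SketchTokenFuser(C=3, S=8)
out = fuser(torch.rand(4, 8, 3), torch.rand(4, 32, 32, 3))
print(out.shape)  # torch.Size([4, 32, 32, 3])
```

The argument order (tokens first, original feature map second) matches the tkfr(y, x) call in the example above.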
TODO
- Add support for temporal dimension T
- Implement TokenFuser with ViT
- Implement TokenFuser with ViViT
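Until temporal support lands, one straightforward workaround for video input of shape [B, T, H, W, C] is to fold the time axis into the batch, run an image-level token learner on every frame, and concatenate the per-frame tokens, giving T * S tokens per clip. This is an assumption for illustration, not this package's planned design. The helper apply_per_frame and the GridPool stand-in (which just average-pools each frame to a 2 x 2 grid so the sketch runs without this package) are hypothetical; with the package installed, a TokenLearner instance could be passed in instead, assuming it accepts any batch size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_per_frame(module: nn.Module, video: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run a per-image token learner on every frame.

    video:   [B, T, H, W, C]
    module:  maps [N, H, W, C] -> [N, S, C']
    returns: [B, T * S, C'], i.e. S tokens per frame concatenated over time
    """
    B, T, H, W, C = video.shape
    # Fold time into the batch so the image-level module sees [B*T, H, W, C]
    frames = video.reshape(B * T, H, W, C)
    tokens = module(frames)  # [B*T, S, C']
    S = tokens.shape[1]
    # Unfold time; frame t's tokens occupy positions t*S .. t*S + S - 1
    return tokens.reshape(B, T * S, tokens.shape[-1])


class GridPool(nn.Module):
    """Stand-in token learner for the demo: 2 x 2 average pooling, so S = 4."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [N, H, W, C] -> [N, C, H, W] for pooling -> [N, C, 2, 2]
        pooled = F.adaptive_avg_pool2d(x.permute(0, 3, 1, 2), 2)
        # [N, C, 4] -> [N, 4, C]
        return pooled.flatten(2).transpose(1, 2)


video = torch.rand(2, 5, 32, 32, 3)  # 2 clips of 5 frames each
clip_tokens = apply_per_frame(GridPool(), video)
print(clip_tokens.shape)  # torch.Size([2, 20, 3])
```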
Contributions
If I've made any errors or you have any suggestions, feel free to open an issue or a PR. All contributions are welcome!
License
Download files
Source Distribution
Built Distribution
File details
Details for the file tokenlearner_pytorch-0.1.2.tar.gz.
File metadata
- Download URL: tokenlearner_pytorch-0.1.2.tar.gz
- Upload date:
- Size: 2.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 34d85f1c6bb51a819de3a57523027e392979c059c4cf3256ea65937d78bb9fac
MD5 | 5ab91817035065500e06ce166d64b20c
BLAKE2b-256 | 6328cc99beabee4d164f31dc34db74740dac444aa75d841bb80392b9594d45c3
File details
Details for the file tokenlearner_pytorch-0.1.2-py3-none-any.whl.
File metadata
- Download URL: tokenlearner_pytorch-0.1.2-py3-none-any.whl
- Upload date:
- Size: 4.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 09b56676d31c8d65cd227c9d611a0a54d64e2a62ad545916efc95e1042208949
MD5 | 7bf21d415deaa3fdf460cbba3c8fa7d6
BLAKE2b-256 | 8d588cabaf0482bc0f7344bb23f213c53b7ca578c6ce0752fa3856e1cddefd37
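To check a downloaded file against the SHA256 digests listed above, Python's standard hashlib module is enough. The helper name sha256_of is hypothetical; it reads the file in chunks so large archives need not fit in memory.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it chunk by chunk."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 of the source distribution, copied from its hash table
EXPECTED_SDIST = "34d85f1c6bb51a819de3a57523027e392979c059c4cf3256ea65937d78bb9fac"
# Usage, assuming the archive was downloaded to the current directory:
# assert sha256_of("tokenlearner_pytorch-0.1.2.tar.gz") == EXPECTED_SDIST
```

pip can also enforce this at install time through its hash-checking mode, by adding --hash=sha256:<digest> to the package's line in a requirements file.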