Unofficial PyTorch implementation of TokenLearner by Google AI
tokenlearner-pytorch
Unofficial PyTorch implementation of TokenLearner by Ryoo et al. from Google AI (abs, pdf)
Installation
You can install TokenLearner via pip:
pip install tokenlearner-pytorch
Usage
You can access the TokenLearner class from the tokenlearner_pytorch package. You can use this layer with a Vision Transformer, MLPMixer, or Video Vision Transformer as done in the paper.
import torch
from tokenlearner_pytorch import TokenLearner
tklr = TokenLearner(S=8)
x = torch.rand(512, 32, 32, 3)
y = tklr(x) # [512, 8, 3]
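To make the shape change in the usage example above more concrete, here is a minimal sketch of the general idea described in the paper: each of the S tokens comes from a learned spatial attention map that weights the input, followed by average pooling over the spatial positions. This is not the code shipped in this package. The class name `TinyTokenLearner`, its argument names, and the single 1x1 convolution are assumptions made for illustration only.

```python
import torch
import torch.nn as nn


class TinyTokenLearner(nn.Module):
    """Illustrative sketch: S sigmoid-gated spatial maps, then spatial mean pooling."""

    def __init__(self, channels: int, num_tokens: int) -> None:
        super().__init__()
        # Assumption: a single 1x1 convolution produces all S attention maps at once.
        self.to_maps = nn.Conv2d(channels, num_tokens, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is channels-last, [B, H, W, C], matching the usage example.
        feats = x.permute(0, 3, 1, 2)                # [B, C, H, W]
        maps = torch.sigmoid(self.to_maps(feats))    # [B, S, H, W]
        # Weight the features by each map, then average over H and W.
        weighted = maps.unsqueeze(2) * feats.unsqueeze(1)  # [B, S, C, H, W]
        return weighted.mean(dim=(3, 4))             # [B, S, C]


x = torch.rand(4, 32, 32, 3)
tokens = TinyTokenLearner(channels=3, num_tokens=8)(x)
print(tokens.shape)  # torch.Size([4, 8, 3])
```

The output has S tokens per image regardless of H and W, which is why the layer reduces the number of tokens passed to later Transformer blocks.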
You can also use TokenLearner and TokenFuser together with Multi-head Self-Attention as done in the paper:
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser
mhsa = nn.MultiheadAttention(3, 1)
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)
x = torch.rand(512, 32, 32, 3) # a batch of images
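y = tklr(x) # [512, 8, 3]
y = y.transpose(0, 1) # [8, 512, 3]; mhsa expects [seq, batch, embed], and view() would not swap the axes
y, _ = mhsa(y, y, y) # ignore attn weights
y = y.transpose(0, 1).contiguous() # back to [512, 8, 3]
out = tkfr(y, x) # [512, 32, 32, 3]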
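By default, `nn.MultiheadAttention` expects inputs laid out as [sequence, batch, embedding]. As an alternative to rearranging the token tensor around the attention call, the layer can be built with `batch_first=True` (available in PyTorch 1.9 and later) so that it accepts [batch, tokens, channels] directly. The sketch below uses a random tensor as a stand-in for the TokenLearner output, so it runs without this package installed; the variable names are illustrative.

```python
import torch
import torch.nn as nn

# Stand-in for the TokenLearner output: [batch, tokens, channels].
tokens = torch.rand(4, 8, 3)

# batch_first=True makes the layer take and return [batch, seq, embed].
mhsa = nn.MultiheadAttention(embed_dim=3, num_heads=1, batch_first=True)
attended, _ = mhsa(tokens, tokens, tokens)  # attention weights ignored

print(attended.shape)  # torch.Size([4, 8, 3])
```

With this layout the attended tokens keep the same shape as the TokenLearner output, so they could be handed to TokenFuser without any further reshaping.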
TODO
- Add support for temporal dimension T
- Implement TokenFuser with ViT
- Implement TokenFuser with ViViT
Contributions
If I've made any errors or you have any suggestions, feel free to raise an Issue or PR. All contributions are welcome!
License