Unofficial PyTorch implementation of TokenLearner by Google AI

tokenlearner-pytorch

Unofficial PyTorch implementation of TokenLearner by Ryoo et al. from Google AI (abs, pdf)

Installation

You can install TokenLearner via pip:

pip install tokenlearner-pytorch

Usage

The TokenLearner class is exported from the tokenlearner_pytorch package. The layer can be used with a Vision Transformer, MLP-Mixer, or Video Vision Transformer, as done in the paper.

import torch
from tokenlearner_pytorch import TokenLearner

tklr = TokenLearner(S=8)  # learn S=8 tokens per image
x = torch.rand(512, 32, 32, 3)  # a batch of images, [B, H, W, C]
y = tklr(x)  # [512, 8, 3]
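
To make the shape change above concrete, here is a minimal sketch of the general idea: each of the S output tokens is a spatial-attention-weighted average of the input feature map. `SketchTokenLearner`, its `C` argument, and the single 1x1 convolution are assumptions made for illustration. This is not the package's implementation, which may compute its attention maps differently.

```python
import torch
import torch.nn as nn


class SketchTokenLearner(nn.Module):
    """Illustrative only: S spatial attention maps, each pooling the input into one token."""

    def __init__(self, C: int, S: int):
        super().__init__()
        # Assumption: a single 1x1 convolution produces the S attention maps.
        self.attn = nn.Conv2d(C, S, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, H, W, C] (channels-last, as in the usage example)
        B, H, W, C = x.shape
        maps = torch.sigmoid(self.attn(x.permute(0, 3, 1, 2)))  # [B, S, H, W]
        maps = maps.flatten(2)                                   # [B, S, H*W]
        feats = x.reshape(B, H * W, C)                           # [B, H*W, C]
        # Weighted sum over positions, divided by H*W (spatial average pooling).
        return torch.bmm(maps, feats) / (H * W)                  # [B, S, C]


sketch = SketchTokenLearner(C=3, S=8)
tokens = sketch(torch.rand(4, 32, 32, 3))
print(tokens.shape)  # torch.Size([4, 8, 3])
```

The output has S tokens per image regardless of H and W, which is why the number of tokens passed to later layers no longer depends on image resolution.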

You can also use TokenLearner and TokenFuser together with Multi-head Self-Attention as done in the paper:

import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

mhsa = nn.MultiheadAttention(3, 1)
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)

x = torch.rand(512, 32, 32, 3) # a batch of images

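y = tklr(x)  # [512, 8, 3]
y = y.transpose(0, 1)  # [8, 512, 3]; nn.MultiheadAttention expects (seq, batch, embed) by default
y, _ = mhsa(y, y, y)  # ignore attn weights
y = y.transpose(0, 1).contiguous()  # back to [512, 8, 3]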

out = tkfr(y, x)  # [512, 32, 32, 3]
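
A note on axis order: by default, nn.MultiheadAttention takes input shaped (sequence, batch, embedding), so the batch and token axes have to be swapped before and after attention. The small sketch below, using a toy tensor that is not part of the package, shows that `transpose` swaps the axes while `view` with the same target shape only reinterprets memory and mixes tokens from different samples.

```python
import torch

y = torch.arange(2 * 3 * 1).reshape(2, 3, 1)  # [batch=2, tokens=3, dim=1]

swapped = y.transpose(0, 1)  # [3, 2, 1]: token-major, each sample kept intact
viewed = y.view(3, 2, 1)     # [3, 2, 1]: same shape, but memory order is unchanged

# The tokens of sample 0 are 0, 1, 2 in the original layout.
print(swapped[:, 0, 0].tolist())  # [0, 1, 2]
print(viewed[:, 0, 0].tolist())   # [0, 2, 4]
```

Both results have the same shape, so the mix-up raises no error; only the values reveal it.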

TODO

  • Add support for temporal dimension T
  • Implement TokenFuser with ViT
  • Implement TokenFuser with ViViT
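
Until the temporal dimension is supported, one possible workaround is to fold T into the batch axis and tokenize each frame independently. The sketch below assumes video shaped [B, T, H, W, C] and a tokenizer that maps [N, H, W, C] to [N, S, C]. `tokenize_video` and `mean_tokenizer` are hypothetical helpers, not part of the package, and the stand-in tokenizer exists only so the example runs on its own.

```python
import torch


def tokenize_video(tokenizer, video: torch.Tensor) -> torch.Tensor:
    """Apply an image tokenizer to every frame of a [B, T, H, W, C] video."""
    B, T, H, W, C = video.shape
    frames = video.reshape(B * T, H, W, C)           # fold time into the batch axis
    tokens = tokenizer(frames)                       # [B*T, S, C]
    S = tokens.shape[1]
    return tokens.reshape(B, T * S, tokens.shape[-1])  # S tokens per frame, T*S per clip


def mean_tokenizer(frames: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: S=2 identical tokens per frame (the global mean).
    return frames.mean(dim=(1, 2)).unsqueeze(1).expand(-1, 2, -1)


clip_tokens = tokenize_video(mean_tokenizer, torch.rand(4, 5, 32, 32, 3))
print(clip_tokens.shape)  # torch.Size([4, 10, 3])
```

In practice the stand-in would be replaced by a TokenLearner instance, provided it accepts the same [N, H, W, C] input as in the usage examples.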

Contributions

If I've made any errors or you have any suggestions, feel free to open an issue or PR. All contributions are welcome!

License

MIT
