An Implementation of Transformer in Transformer for image classification, with attention inside local patches
Project description
Transformer-in-Transformer
An implementation of the Transformer in Transformer (TNT) paper by Han et al. for image classification, with attention inside local patches. Transformer in Transformer pairs pixel-level attention within each patch with patch-level attention across patches, and this implementation is written in TensorFlow.
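As a rough illustration of how the two attention levels divide up an image, the snippet below just walks through the token counts implied by the hyperparameters used in the usage example further down; it is illustrative arithmetic, not taken from the library's internals:

# Illustrative only: token counts for a 256x256 image with the example hyperparameters.
image_size, patch_size, pixel_size = 256, 16, 4

num_patch_tokens = (image_size // patch_size) ** 2   # 16 x 16 = 256 patch tokens per image
pixels_per_patch = (patch_size // pixel_size) ** 2   # 4 x 4 = 16 pixel tokens inside each patch

print(num_patch_tokens, pixels_per_patch)  # 256 16

Patch-level attention mixes the 256 patch tokens with each other, while pixel-level attention runs independently over the 16 pixel tokens inside each patch.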
Installation
Run the following to install:
pip install tnt-tensorflow
Developing tnt-tensorflow
To install tnt-tensorflow, along with the tools you need to develop and test it, run the following in your virtualenv:
git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork
cd Transformer-in-Transformer
pip install -e .[dev]
To run rank and shape tests, run the following:
pytest -v --disable-warnings --cov
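For reference, a minimal rank-and-shape style test might look like the sketch below. The test class and the scaled-down constructor values are illustrative assumptions, not the project's actual test suite:

import tensorflow as tf
from tnt import TNT

class TNTShapeTest(tf.test.TestCase):
    def test_logits_rank_and_shape(self):
        # Small, assumed-valid hyperparameters (image_size divisible by patch_size,
        # patch_size divisible by pixel_size), mirroring the usage example below.
        model = TNT(
            image_size=64,
            patch_dim=64,
            pixel_dim=8,
            patch_size=16,
            pixel_size=4,
            depth=2,
            num_classes=10,
        )
        img = tf.random.uniform(shape=[2, 3, 64, 64])
        logits = model(img)
        self.assertEqual(logits.shape.rank, 2)   # rank check
        self.assertEqual(logits.shape, (2, 10))  # shape check

if __name__ == "__main__":
    tf.test.main()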
Usage
import tensorflow as tf
from tnt import TNT
tnt = TNT(
    image_size=256,    # size of the input image
    patch_dim=512,     # dimension of the patch tokens
    pixel_dim=24,      # dimension of the pixel tokens
    patch_size=16,     # size of each patch
    pixel_size=4,      # size of each pixel-level sub-patch
    depth=5,           # number of transformer layers
    num_classes=1000,  # number of output classes
    attn_dropout=0.1,  # attention dropout rate
    ff_dropout=0.1,    # feedforward dropout rate
)
img = tf.random.uniform(shape=[5, 3, 256, 256])
logits = tnt(img) # (5, 1000)
An end-to-end training example for image classification on a dataset can be found in the training.ipynb notebook.
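If you want a quick script instead of the notebook, a minimal Keras-style training loop could look like the sketch below. It assumes TNT behaves like a standard tf.keras.Model (so compile and fit are available) and uses random tensors purely as stand-ins for a real dataset:

import tensorflow as tf
from tnt import TNT

model = TNT(
    image_size=256,
    patch_dim=512,
    pixel_dim=24,
    patch_size=16,
    pixel_size=4,
    depth=5,
    num_classes=1000,
    attn_dropout=0.1,
    ff_dropout=0.1,
)

# Placeholder data; replace with your own tf.data pipeline.
images = tf.random.uniform(shape=[8, 3, 256, 256])
labels = tf.random.uniform(shape=[8], maxval=1000, dtype=tf.int32)

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(images, labels, batch_size=4, epochs=1)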
Pre-trained model
A pre-trained model for the TNT-S variant (reproducing the paper's results: 81.4% top-1 and 95.7% top-5 accuracy on ImageNet-1K) is also available, paired with an example of running inference with it.
Model | TensorFlow Hub | Inference Tutorial
---|---|---
bucket | tfhub.dev |
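To run inference with the hosted weights instead of building the model yourself, loading from TensorFlow Hub generally looks like the sketch below. The handle URL is a placeholder; use the actual tfhub.dev link from the table above, and check the published model's documentation for the expected input layout:

import tensorflow as tf
import tensorflow_hub as hub

# Placeholder handle: substitute the real tfhub.dev URL for the TNT-S model.
model = hub.load("https://tfhub.dev/<publisher>/<tnt-s-model>/1")

img = tf.random.uniform(shape=[1, 3, 256, 256])  # layout follows the usage example above
logits = model(img)                              # expected shape: (1, 1000)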
Want to Contribute 🙋♂️?
Awesome! If you want to contribute to this project, you're always welcome! See the Contributing Guidelines. You can also take a look at the open issues to learn more about current or upcoming tasks.
Want to discuss? 💬
Have questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.
Citation
@misc{han2021transformer,
title={Transformer in Transformer},
author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
year={2021},
eprint={2103.00112},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
License
Copyright 2020 Rishit Dagli
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution: tnt-tensorflow-0.2.0.tar.gz (13.0 kB)
Built Distribution: tnt_tensorflow-0.2.0-py3-none-any.whl (12.2 kB)
File details
Details for the file tnt-tensorflow-0.2.0.tar.gz.
File metadata
- Download URL: tnt-tensorflow-0.2.0.tar.gz
- Upload date:
- Size: 13.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | be3305a09eb04350ea1e706c52ffd0e388fa4a92c2658aa42d32181b7c9c72a1
MD5 | 19a5b6e51e26214afe166d7eb05ec2fd
BLAKE2b-256 | db191841f680b2b5b20d52045c9a9e333ffe8ffefaed0520f5b5ac8b01cf6986
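To verify a downloaded archive against these digests, a small check along the following lines works (the path assumes the sdist was saved under its published name):

import hashlib

expected_sha256 = "be3305a09eb04350ea1e706c52ffd0e388fa4a92c2658aa42d32181b7c9c72a1"

with open("tnt-tensorflow-0.2.0.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == expected_sha256, "SHA256 mismatch: the download may be corrupted"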
File details
Details for the file tnt_tensorflow-0.2.0-py3-none-any.whl.
File metadata
- Download URL: tnt_tensorflow-0.2.0-py3-none-any.whl
- Upload date:
- Size: 12.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 39d5a538367b7072248fb87ce3e6a5a2d4ab2029fbf6eb30800da3928bddb0c3
MD5 | dbee41b61854c4694ceb66baf1eee6a4
BLAKE2b-256 | 05d654aa0b4d51635c3c47d9412e101fe0f2a642d182fc3392013d3f7eb2a6fc