
An implementation of Transformer in Transformer for image classification, with attention inside local patches


Transformer-in-Transformer


A TensorFlow implementation of the Transformer in Transformer (TNT) paper by Han et al. for image classification. TNT pairs pixel-level attention inside each local patch with patch-level attention across the whole image.
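To make the two attention levels concrete, here is a minimal pure-Python sketch of one simplified TNT step. It is not this package's code: `self_attention` and `tnt_block` are hypothetical names, the attention uses identity projections with no learned weights, and a plain mean stands in for the learned linear projection that the paper uses to fold pixel tokens into patch tokens. Residual connections, layer normalization, feedforward layers and the class token are all left out.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens: List[Vector]) -> List[Vector]:
    """Single-head self-attention with identity Q/K/V (assumption: no learned weights)."""
    dim = len(tokens[0])
    out = []
    for q in tokens:
        # Scaled dot-product score of this query against every token.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in tokens]
        weights = softmax(scores)
        # Weighted sum of the value vectors (here the tokens themselves).
        out.append([sum(w * v[i] for w, v in zip(weights, tokens)) for i in range(dim)])
    return out


def tnt_block(
    pixel_tokens: List[List[Vector]], patch_tokens: List[Vector]
) -> Tuple[List[List[Vector]], List[Vector]]:
    """One simplified TNT step: inner attention per patch, fuse, outer attention."""
    # Inner level: pixel tokens attend only to pixel tokens of the same patch.
    pixel_tokens = [self_attention(pixels) for pixels in pixel_tokens]
    # Fuse: add a summary of each patch's pixel tokens to its patch token.
    # The mean is a stand-in for the paper's learned linear projection.
    fused = []
    for patch, pixels in zip(patch_tokens, pixel_tokens):
        mean = [sum(col) / len(pixels) for col in zip(*pixels)]
        fused.append([a + b for a, b in zip(patch, mean)])
    # Outer level: patch tokens attend to each other across the image.
    return pixel_tokens, self_attention(fused)


# 2 patches, each with 2 pixel tokens of dimension 2.
pixels = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [0.5, 0.5]]]
patches = [[0.0, 0.0], [0.0, 0.0]]
new_pixels, new_patches = tnt_block(pixels, patches)
print(len(new_pixels), len(new_pixels[0]), len(new_patches))
```

The sketch keeps pixel and patch tokens at the same dimension so they can be added directly; the real model uses separate `pixel_dim` and `patch_dim` sizes bridged by a projection.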

Installation

Run the following to install:

pip install tnt-tensorflow
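To confirm the install worked, you can ask Python's standard `importlib.metadata` for the installed version. This is a small sketch, not part of the package; `installed_version` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("tnt-tensorflow")
print(version if version is not None else "tnt-tensorflow is not installed")
```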

Developing tnt-tensorflow

To install tnt-tensorflow along with the tools you need to develop and test it, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork

cd Transformer-in-Transformer
pip install -e ".[dev]"

Usage

import tensorflow as tf
from tnt import TNT

tnt = TNT(
    image_size=256,  # size of image
    patch_dim=512,  # dimension of patch token
    pixel_dim=24,  # dimension of pixel token
    patch_size=16,  # patch size
    pixel_size=4,  # pixel size
    depth=5,  # depth
    num_classes=1000,  # output number of classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,  # feedforward dropout
)

img = tf.random.uniform(shape=[5, 3, 256, 256])  # batch of 5 channels-first RGB images
logits = tnt(img)  # (5, 1000)
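The size parameters above determine how many tokens each attention level sees. The sketch below works through that arithmetic, assuming square images and non-overlapping tiling (each patch is cut into `pixel_size` x `pixel_size` sub-patches with no overlap). `tnt_token_counts` is a hypothetical helper, not part of the library, and the library's internal unfolding may differ.

```python
from typing import Tuple


def tnt_token_counts(image_size: int, patch_size: int, pixel_size: int) -> Tuple[int, int]:
    """Return (patches per image, pixel tokens per patch) for square inputs."""
    # Assumption: sizes must tile evenly, with no padding or overlap.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")
    num_patches = (image_size // patch_size) ** 2
    pixels_per_patch = (patch_size // pixel_size) ** 2
    return num_patches, pixels_per_patch


num_patches, pixels_per_patch = tnt_token_counts(image_size=256, patch_size=16, pixel_size=4)
print(num_patches, pixels_per_patch)  # 256 16
```

Under these assumptions, the settings in the usage example give 256 patch tokens for the outer attention and 16 pixel tokens per patch for the inner attention.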

Want to Contribute 🙋‍♂️?

Awesome! Contributions to this project are always welcome. See the Contributing Guidelines, and take a look at the open issues for more information about current or upcoming tasks.

Want to discuss? 💬

Have questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

Citation

@misc{han2021transformer,
      title={Transformer in Transformer}, 
      author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
      year={2021},
      eprint={2103.00112},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

