TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights

TensorFlow-Image-Models

Introduction

TensorFlow-Image-Models (tfimm) is a collection of image models with pretrained weights, obtained by porting architectures from timm to TensorFlow. The hope is that the number of available architectures will grow over time. For now, it contains vision transformers (ViT, DeiT, CaiT and Swin Transformers), MLP-Mixer models (MLP-Mixer, ResMLP, gMLP and ConvMixer) and ResNets.

This work would not have been possible without Ross Wightman's timm library and the work on PyTorch/TensorFlow interoperability in HuggingFace's transformers repository. I have tried to make sure all source material is acknowledged. Please let me know if I have missed something.

Usage

Installation

The package can be installed via pip:

pip install tfimm

To load pretrained weights, timm needs to be installed. This is an optional dependency and can be installed via

pip install tfimm[timm]

Creating models

To load pretrained models use

import tfimm

model = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm")

We can list available models with pretrained weights via

import tfimm

print(tfimm.list_models(pretrained="timm"))

Most models are pretrained on ImageNet or ImageNet-21k. If we want to use them for other tasks we need to change the number of classes in the classifier or remove the classifier altogether. We can do this by setting the nb_classes parameter in create_model. If nb_classes=0, the model will have no classification layer. If nb_classes is set to a value different from the default model config, the classification layer will be randomly initialized, while all other weights will be copied from the pretrained model.

The preprocessing function for each model can be created via

import tensorflow as tf
import tfimm

preprocess = tfimm.create_preprocessing("vit_tiny_patch16_224", dtype="float32")
img = tf.ones((1, 224, 224, 3), dtype="uint8")
img_preprocessed = preprocess(img)

Saving and loading models

All models are subclassed from tf.keras.Model (they are not functional models). They can still be saved and loaded using the SavedModel format.

>>> import tensorflow as tf
>>> import tfimm
>>> model = tfimm.create_model("vit_tiny_patch16_224")
>>> type(model)
<class 'tfimm.architectures.vit.ViT'>
>>> model.save("/tmp/my_model")
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'tfimm.architectures.vit.ViT'>

For this to work, tfimm must be imported before the model is loaded, because importing tfimm registers its custom model classes with Keras. Otherwise, we obtain the following output:

>>> import tensorflow as tf
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'keras.saving.saved_model.load.Custom>ViT'>
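
The Custom>ViT name in the output above resembles the names Keras assigns to classes in its serialization registry. The sketch below illustrates that general Keras mechanism with a hypothetical TinyModel class; it is not tfimm's actual registration code, and the assumption that tfimm relies on this same registry is mine.

```python
import tensorflow as tf


# Register a subclassed model with Keras; the default package name is "Custom".
@tf.keras.utils.register_keras_serializable()
class TinyModel(tf.keras.Model):  # hypothetical stand-in for a tfimm architecture
    def __init__(self, units=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, x):
        return self.dense(x)

    def get_config(self):
        return {"units": self.units}


# Keras stores the class under "<package>><class name>".
registered_name = tf.keras.utils.get_registered_name(TinyModel)
print(registered_name)  # Custom>TinyModel

# Loading can map the stored name back to the class only if it was registered,
# which happens when the defining module is imported.
resolved = tf.keras.utils.get_registered_object(registered_name)
print(resolved is TinyModel)
```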

Models

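The following architectures are currently available:

  • Vision transformers: ViT, DeiT, CaiT and Swin Transformers
  • MLP-Mixer models: MLP-Mixer, ResMLP, gMLP and ConvMixer
  • ResNets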

Profiling

To understand how big each of the models is, I have done some profiling to measure

  • maximum batch size that fits in GPU memory and
  • throughput in images/second for both inference and backpropagation on K80 and V100 GPUs. For V100, measurements were done for both float32 and mixed precision.

The results can be found in the results/profiling_{k80, v100}.csv files.
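
The exact profiling procedure is not described here, so the following is only a minimal sketch of how the two quantities could be measured. Both helper names and their parameters are hypothetical. The fits callback is assumed to run one step at the given batch size and return False when it runs out of memory (in TensorFlow, typically by catching tf.errors.ResourceExhaustedError), and step is assumed to process exactly one batch.

```python
import time


def find_max_batch_size(fits, start=1, limit=4096):
    """Return the largest batch size of the form start * 2**k (<= limit) that fits.

    `fits(batch_size)` is a hypothetical callback returning True if one step
    at that batch size succeeds. Returns 0 if even `start` does not fit.
    """
    best = 0
    batch_size = start
    while batch_size <= limit and fits(batch_size):
        best = batch_size
        batch_size *= 2
    return best


def measure_throughput(step, batch_size, nb_batches=10, nb_warmup=2):
    """Return images/second for `step`, a zero-argument callable handling one batch."""
    # Warm-up iterations absorb one-off costs such as graph tracing.
    for _ in range(nb_warmup):
        step()
    start = time.perf_counter()
    for _ in range(nb_batches):
        step()
    elapsed = time.perf_counter() - start
    return batch_size * nb_batches / elapsed


# Stand-ins for a real model: pretend that at most 100 images fit in memory
# and that one batch takes about a millisecond.
max_batch = find_max_batch_size(lambda b: b <= 100)
images_per_second = measure_throughput(lambda: time.sleep(0.001), batch_size=max_batch)
print(max_batch, images_per_second)
```

With a real model, step could be something like lambda: model(x, training=False).numpy() for inference or the backprop function for training. Pulling a result back to the host is one way to make sure queued GPU work has finished before the timer stops.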

For backpropagation, we use the mean of the model outputs as the loss:

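import tensorflow as tf

# `model`, `x` and `optimizer` are assumed to be defined by the profiling script.
def backprop():
    with tf.GradientTape() as tape:
        output = model(x, training=True)
        # Dummy loss: mean over all model outputs.
        loss = tf.reduce_mean(output)
    # Compute gradients outside the tape context so the gradient ops are not recorded.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))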

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.
