Re-implementation of FocalNet for TensorFlow 2.x

FocalNet: Focal Modulation Networks for TensorFlow

This repository contains a TensorFlow implementation of the paper Focal Modulation Networks. The paper proposes an attention-free architecture called focal modulation, which can dynamically adjust the focus of convolutional neural networks on different regions of the input. Focal modulation can improve the performance of various vision tasks, such as image classification, object detection, semantic segmentation and face recognition.

Focal Modulation brings several merits:

  • Translation-Invariance: It is performed for each target token with the context centered around it.
  • Explicit input-dependency: The modulator is computed by aggregating the short- and long-range context from the input and is then applied to the target token.
  • Spatial- and channel-specific: It first aggregates the context spatial-wise and then channel-wise, followed by an element-wise modulation.
  • Decoupled feature granularity: The query token preserves individual information at the finest level, while coarser context is extracted from its surroundings. The two are decoupled but connected through the modulation operation.
  • Easy to implement: We can implement both context aggregation and interaction in a very simple and light-weight way. It does not need softmax, multiple attention heads, feature map rolling or unfolding, etc.
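The properties above can be illustrated with a minimal NumPy sketch of a single focal modulation step on one feature map. This is not the code shipped in this package: the function names, the `params` dictionary layout, and the weight initialization are assumptions made for illustration, and optional details of the official implementation (such as modulator normalization) are left out. The kernel size per level follows the `focal_factor * level + focal_window` rule used by the official PyTorch code.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def depthwise_conv2d(x, kernel):
    """Zero-padded 'same' depthwise convolution.

    x: (H, W, C) feature map, kernel: (k, k, C) with odd k.
    """
    k = kernel.shape[0]
    pad = k // 2
    h, w, _ = x.shape
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    out = np.zeros_like(x)
    # Sum shifted copies of the input, each weighted per channel
    for i in range(k):
        for j in range(k):
            out += xp[i:i + h, j:j + w, :] * kernel[i, j, :]
    return out

def init_params(channels, focal_levels=2, focal_window=3, focal_factor=2, seed=0):
    # Hypothetical parameter layout, random weights for demonstration only
    rng = np.random.default_rng(seed)
    sizes = [focal_factor * level + focal_window for level in range(focal_levels)]
    n_out = 2 * channels + focal_levels + 1
    return {
        "f_w": rng.normal(0.0, 0.02, (channels, n_out)),
        "f_b": np.zeros(n_out),
        "dw": [rng.normal(0.0, 0.02, (k, k, channels)) for k in sizes],
        "h_w": rng.normal(0.0, 0.02, (channels, channels)),
        "h_b": np.zeros(channels),
        "proj_w": rng.normal(0.0, 0.02, (channels, channels)),
        "proj_b": np.zeros(channels),
    }

def focal_modulation(x, params):
    """Apply one focal modulation step to x of shape (H, W, C)."""
    c = x.shape[-1]
    focal_levels = len(params["dw"])
    # One linear projection yields the query, the initial context, and the gates
    proj = x @ params["f_w"] + params["f_b"]
    q, ctx, gates = np.split(proj, [c, 2 * c], axis=-1)
    ctx_all = np.zeros_like(ctx)
    # Spatial aggregation: growing receptive field, gated per location and level
    for level in range(focal_levels):
        ctx = gelu(depthwise_conv2d(ctx, params["dw"][level]))
        ctx_all += ctx * gates[..., level:level + 1]
    # Global context: average over all positions, gated by the last gate
    ctx_global = gelu(ctx.mean(axis=(0, 1), keepdims=True))
    ctx_all += ctx_global * gates[..., focal_levels:]
    # Channel aggregation (1x1 convolution), then element-wise modulation of the query
    modulator = ctx_all @ params["h_w"] + params["h_b"]
    out = q * modulator
    return out @ params["proj_w"] + params["proj_b"]

params = init_params(channels=8)
x = np.random.default_rng(1).normal(size=(6, 6, 8))
y = focal_modulation(x, params)
print(y.shape)  # (6, 6, 8)
```

Note that no softmax or attention heads appear anywhere: the interaction between a token and its context is a single element-wise product.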

This repository aims to reproduce the results of the paper using TensorFlow 2.4.1 and to provide a modular, easy-to-use implementation of focal modulation networks. The code is based on the official PyTorch implementation of the paper, available in the official repository. Only the classification part is implemented. The pretrained checkpoints have been converted to TensorFlow.

Installation

pip install focalnet-tf

Example


import cv2
import numpy as np
import tensorflow as tf
from focalnet import load_focalnet, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, imagenet1k, imagenet22k

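def preprocess_image(image):
    # Scale pixel values to [0, 1]
    image = image / 255.0
    # Normalize each channel with the ImageNet mean and standard deviation
    image = (image - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
    # Add a leading batch dimension: (H, W, C) -> (1, H, W, C)
    return np.expand_dims(image, axis=0)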

def center_crop(image, output_shape):
    # Get the input shape
    h, w, c = image.shape

    # Get the output shape
    h_desired, w_desired = output_shape

    # Check that the requested crop fits inside the image
    if h_desired > h or w_desired > w:
        raise ValueError("Output shape must be smaller than or equal to the input height and width.")

    # Compute the crop coordinates
    h_start = (h - h_desired) // 2
    h_end = h_start + h_desired
    w_start = (w - w_desired) // 2
    w_end = w_start + w_desired

    # Crop the image and return it
    return image[h_start:h_end, w_start:w_end, :]

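# OpenCV loads images as BGR; convert to RGB
image = cv2.cvtColor(cv2.imread("tests/dog.jpg"), cv2.COLOR_BGR2RGB)
# Take a centered 768x768 crop (the image must be at least that large)
image_crop = center_crop(image, (768, 768))
output_shape = (224, 224)
image_resized = cv2.resize(image_crop, output_shape)
# Preprocess the resized image, not the full-size crop
inputs = preprocess_image(image_resized)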

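model = load_focalnet(model_name='focalnet_tiny_srf', pretrained=True, return_model=False, act_head="softmax")
output = model.predict(inputs)
# Index and probability of the most likely class
class_id = int(np.argmax(output))
print(output[0, class_id])
# Note: focalnet_tiny_srf is an ImageNet-1K model in the official release; if the
# converted checkpoint has 1000 outputs, imagenet1k is presumably the matching label table.
print(imagenet22k[class_id])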
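The example above prints only the single most probable class. To inspect several candidates, a small helper like the following could be used. It is a sketch in plain NumPy: `top_k_predictions` is a hypothetical name, not part of `focalnet`, and the probabilities and labels below are made up. It assumes the label table can be indexed by an integer class id.

```python
import numpy as np

def top_k_predictions(probs, labels, k=5):
    """Return the k most probable (label, probability) pairs, best first.

    probs  : array of shape (num_classes,) or (1, num_classes)
    labels : list or dict indexed by integer class id (assumed layout)
    """
    probs = np.asarray(probs).reshape(-1)
    k = min(k, probs.size)
    # argsort is ascending, so reverse it and keep the first k indices
    top = np.argsort(probs)[::-1][:k]
    return [(labels[int(i)], float(probs[i])) for i in top]

# Toy data standing in for model.predict(...) output and a label table
toy_probs = np.array([[0.05, 0.70, 0.10, 0.15]])
toy_labels = ["cat", "dog", "fox", "wolf"]
for name, p in top_k_predictions(toy_probs, toy_labels, k=3):
    print(f"{name}: {p:.2f}")
# dog: 0.70
# wolf: 0.15
# fox: 0.10
```

With the real model, the call would take the `output` array and whichever label table matches the checkpoint in place of the toy values.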

Acknowledgement
