Re-implementation of FocalNet for TensorFlow 2.x
FocalNet: Focal Modulation Networks for Tensorflow
This repository contains a TensorFlow implementation of the paper Focal Modulation Networks. The paper proposes an attention-free architecture called focal modulation, which can dynamically adjust the focus of convolutional neural networks on different regions of the input. Focal modulation can improve the performance of various vision tasks, such as image classification, object detection, semantic segmentation and face recognition.
Focal Modulation brings several merits:
- Translation invariance: modulation is performed for each target token with the context centered around it.
- Explicit input-dependency: the modulator is computed by aggregating short- and long-range context from the input and is then applied to the target token.
- Spatial- and channel-specific: the context is first aggregated spatially and then channel-wise, followed by an element-wise modulation.
- Decoupled feature granularity: the query token preserves individual information at the finest level, while coarser context is extracted around it. The two are decoupled but connected through the modulation operation.
- Easy to implement: both context aggregation and interaction can be implemented in a very simple and lightweight way, with no need for softmax, multiple attention heads, feature map rolling or unfolding, etc.
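The context-aggregation and modulation steps listed above can be sketched in a few lines of NumPy. This is a simplified illustration, not the library's implementation: the real FocalNet uses learned depthwise convolutions, learned gates, and linear projections, whereas here a mean filter stands in for the depthwise conv and the gates are fixed to uniform weights.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def avg_pool_context(x, k):
    """Aggregate local context with a k x k mean filter
    (a stand-in for the learned depthwise conv in the paper)."""
    H, W, C = x.shape
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    out = np.empty_like(x)
    for i in range(H):
        for j in range(W):
            out[i, j] = xp[i:i + k, j:j + k].mean(axis=(0, 1))
    return out

def focal_modulation(x, num_levels=2):
    """Simplified focal modulation for a single (H, W, C) feature map."""
    q = x  # query token (identity projection in this sketch)
    ctx = x
    modulator = np.zeros_like(x)
    # uniform gates; in FocalNet these are learned per spatial location
    gates = np.full(num_levels + 1, 1.0 / (num_levels + 1))
    # hierarchical context aggregation with growing receptive fields
    for level in range(num_levels):
        ctx = gelu(avg_pool_context(ctx, k=3 + 2 * level))
        modulator += gates[level] * ctx
    # final level: global average context
    modulator += gates[-1] * ctx.mean(axis=(0, 1), keepdims=True)
    # element-wise modulation of the query
    return q * modulator
```

Note that no softmax or attention heads appear anywhere: the interaction between the query and its context is a single element-wise product, which is what makes the operator cheap to implement.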
This repository aims to reproduce the results of the paper using TensorFlow 2.4.1 and to provide a modular, easy-to-use implementation of focal modulation networks. The code is based on the official PyTorch implementation of the paper, which can be found in the official repository here. Only the classification part is implemented; the pretrained checkpoints have been converted to TensorFlow.
Installation
pip install focalnet-tf
Example
import cv2
import numpy as np
from focalnet import load_focalnet, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, imagenet1k

def preprocess_image(image):
    # Scale to [0, 1], normalize with the ImageNet statistics, and add a batch axis
    image = image / 255.0
    image = (image - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
    return np.expand_dims(image, axis=0)

def center_crop(image, output_shape):
    # Get the input and output shapes
    h, w, _ = image.shape
    h_desired, w_desired = output_shape
    # Check that the output shape is valid
    if h_desired > h or w_desired > w:
        raise ValueError("Output shape must be smaller than or equal to the input shape.")
    # Compute the crop coordinates
    h_start = (h - h_desired) // 2
    w_start = (w - w_desired) // 2
    # Crop the image and return it
    return image[h_start:h_start + h_desired, w_start:w_start + w_desired, :]

image = cv2.cvtColor(cv2.imread("tests/dog.jpg"), cv2.COLOR_BGR2RGB)
image_crop = center_crop(image, (768, 768))
image_resized = cv2.resize(image_crop, (224, 224))
inputs = preprocess_image(image_resized)  # preprocess the resized image, not the full crop
model = load_focalnet(model_name='focalnet_tiny_srf', pretrained=True, return_model=False, act_head="softmax")
output = model.predict(inputs)
print(output[0, np.argmax(output)])  # top-1 probability
print(imagenet1k[np.argmax(output)])  # top-1 label (focalnet_tiny_srf is an ImageNet-1k model)
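To report the top few classes rather than only the argmax, a small pure-NumPy helper can be used. This helper is not part of the package; `top_k_predictions` is a hypothetical name, and it works on any probability vector paired with a label list such as `imagenet1k`.

```python
import numpy as np

def top_k_predictions(probs, labels, k=5):
    """Return the k highest-probability (label, probability) pairs for one image."""
    probs = np.asarray(probs).ravel()          # flatten a (1, num_classes) batch output
    idx = np.argsort(probs)[::-1][:k]          # indices sorted by descending probability
    return [(labels[i], float(probs[i])) for i in idx]
```

For example, `top_k_predictions(output, imagenet1k, k=5)` would list the five most likely ImageNet labels with their softmax scores.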
Acknowledgement
- Paper: https://arxiv.org/abs/2203.11926 by Jianwei Yang et al.
- PyTorch implementation: https://github.com/microsoft/FocalNet