
Extract image embeddings with a variety of pre-trained torch models

Project description

Installation:

pip install torch_image2vector

Install from GitHub:
pip install git+https://github.com/MajewskiLukasz/torch_image2vector.git

List of available models

ResNet Models

Model Vector Length
resnet-18 512
resnet-34 512
resnet-50 2048
resnet-101 2048
resnet-152 2048

ResNeXt Models

Model Vector Length
resnext50_32x4d 2048
resnext101_32x8d 2048
resnext101_32x8d_wsl 2048

EfficientNet Models

Model Vector Length
efficientnet_b0 1280
efficientnet_b1 1280
efficientnet_b2 1408
efficientnet_b3 1536
efficientnet_b4 1792
efficientnet_b5 2048
efficientnet_b6 2304
efficientnet_b7 2560

MobileNet Models

Model Vector Length
mobilenet_v3_large 960

ConvNext Models

Model Vector Length
convnext_tiny 768
convnext_small 768
convnext_base 1024
convnext_large 1536

Note: If you are looking for a specific model not listed above, you can check the PyTorch Hub for more pre-trained models.
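The vector lengths in the tables above determine how much memory a set of embeddings will need. The sketch below estimates that size for a dense (N, dim) matrix. The helper name embedding_storage_bytes and the 4-bytes-per-value default (float32) are assumptions made for illustration; they are not part of torch_image2vector, and only a subset of the listed models is included.

```python
# Vector lengths copied from the tables above (subset only).
VECTOR_LENGTH = {
    "resnet-18": 512,
    "resnet-50": 2048,
    "efficientnet_b0": 1280,
    "mobilenet_v3_large": 960,
    "convnext_base": 1024,
}


def embedding_storage_bytes(model: str, num_images: int, bytes_per_value: int = 4) -> int:
    """Bytes needed for a dense (num_images, dim) matrix.

    bytes_per_value=4 assumes float32 storage (an assumption, not a library fact).
    """
    return num_images * VECTOR_LENGTH[model] * bytes_per_value


if __name__ == "__main__":
    for name in ("mobilenet_v3_large", "resnet-50"):
        mib = embedding_storage_bytes(name, 1_000_000) / 2**20
        print(f"{name}: {mib:.0f} MiB for 1,000,000 images")
```

Shorter vectors such as the 960-dimensional mobilenet_v3_large output take proportionally less space than the 2048-dimensional ResNet outputs.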

Recommended setup

For lightweight embedding extraction that still performs well in similarity search, it is recommended to use model=mobilenet_v3_large with weights_version=mobilenet_v3_large-8738ca79, which keeps results consistent with other projects.

Refer to requirements.txt for the recommended versions of torch and torchvision. These versions are not mandatory, but loading a specific weights_version requires matching torch and torchvision versions (or higher).
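Since a specific weights_version depends on the installed torch and torchvision versions, it can help to print them before comparing against requirements.txt. This is a minimal sketch that uses only the Python standard library; the helper name installed_versions is hypothetical and not part of torch_image2vector.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "torchvision")):
    """Return {package name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Compare these values by hand against the versions in requirements.txt.
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```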

Test

bash run_test.sh

Use

1. Inference on loaded images

from PIL import Image
from torch_image2vector import Image2Vector


# Initialize Image2Vector
image2vector = Image2Vector(model="resnet50")

# Read in an image and make sure it is in RGB format
img = Image.open('sample.jpg').convert('RGB')
# Get a vector from image2vector
vec = image2vector.get_vec(img)
# [Alternative - batch processing] submit a list of images
# (img2 and img3 are further PIL images loaded the same way as img)
vectors = image2vector.get_vec([img, img2, img3])
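Once two vectors are available, a common way to compare them is cosine similarity. The sketch below uses random stand-in vectors rather than real embeddings. It assumes the output of get_vec can be converted with np.asarray into a 1-D numeric array; that is an assumption about the return type, and a tensor living on a GPU would first have to be moved to the CPU. The function cosine_similarity is defined here for illustration and is not part of torch_image2vector.

```python
import numpy as np


def cosine_similarity(a, b) -> float:
    """Cosine of the angle between two vectors, in the range [-1, 1]."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return float(np.dot(a, b) / denom)


# Stand-in vectors; in practice these would come from image2vector.get_vec(...).
rng = np.random.default_rng(0)
vec_a = rng.normal(size=2048)
vec_b = rng.normal(size=2048)

print("same vector:     ", cosine_similarity(vec_a, vec_a))
print("different vector:", cosine_similarity(vec_a, vec_b))
```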

2. Inference on a large number of images

For a large number of images, to improve inference speed, it is recommended to use the following pipeline:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch_image2vector import Image2Vector, ModelInference
from torch_image2vector.data import ImgPathsDataset

# Select batch size relative to your GPU capacity
batch_size = 128
num_workers = 8
# List of image files we want to run inference on.
image_files = ['sample.png', ...]

# Initialize Image2Vector with CUDA for speed.
image2vector = Image2Vector(cuda=True, model="mobilenet_v3_large", weights_version="mobilenet_v3_large-8738ca79")
# Create dataset from image paths and corresponding loader.
dataset = ImgPathsDataset(image_files, image2vector.transform)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
vec_list = []
for data in loader:
    vec_list.append(image2vector._loader_forward(data.to(image2vector.device)))
# (N, dim) array of embeddings.
vec = np.vstack(vec_list)
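With an (N, dim) array of embeddings in hand, a similarity search can be sketched as a brute-force top-k lookup by cosine similarity. The example below runs on random stand-in data shaped like mobilenet_v3_large output (960 values per row). The function top_k_similar is hypothetical and not part of torch_image2vector; for very large collections a dedicated nearest-neighbour index would likely be a better fit.

```python
import numpy as np


def top_k_similar(embeddings, query, k: int = 5):
    """Return (indices, scores) of the k rows most cosine-similar to query."""
    emb = np.asarray(embeddings, dtype=np.float64)
    q = np.asarray(query, dtype=np.float64).ravel()
    # L2-normalise rows and query so that a dot product equals cosine similarity.
    emb_unit = emb / np.clip(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12, None)
    q_unit = q / max(np.linalg.norm(q), 1e-12)
    scores = emb_unit @ q_unit
    k = min(k, len(scores))
    # Sort by descending score and keep the first k indices.
    order = np.argsort(-scores)[:k]
    return order, scores[order]


# Stand-in for the stacked embeddings array built by the pipeline above.
rng = np.random.default_rng(0)
all_vectors = rng.normal(size=(1000, 960))

indices, scores = top_k_similar(all_vectors, all_vectors[42], k=3)
print(indices, scores)
```

Because the query is itself row 42 of the stand-in array, that row is returned first with a score close to 1.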
