Extract image embeddings with a variety of pre-trained PyTorch models
Project description
Installation:
pip install torch_image2vector
Install from GitHub:
pip install git+https://github.com/MajewskiLukasz/torch_image2vector.git
List of available models
ResNet Models
Model | Vector Length |
---|---|
resnet-18 | 512 |
resnet-34 | 512 |
resnet-50 | 2048 |
resnet-101 | 2048 |
resnet-152 | 2048 |
ResNeXt Models
Model | Vector Length |
---|---|
resnext50_32x4d | 2048 |
resnext101_32x8d | 2048 |
resnext101_32x8d_wsl | 2048 |
EfficientNet Models
Model | Vector Length |
---|---|
efficientnet_b0 | 1280 |
efficientnet_b1 | 1280 |
efficientnet_b2 | 1408 |
efficientnet_b3 | 1536 |
efficientnet_b4 | 1792 |
efficientnet_b5 | 2048 |
efficientnet_b6 | 2304 |
efficientnet_b7 | 2560 |
MobileNet Models
Model | Vector Length |
---|---|
mobilenet_v3_large | 960 |
ConvNext Models
Model | Vector Length |
---|---|
convnext_tiny | 768 |
convnext_small | 768 |
convnext_base | 1024 |
convnext_large | 1536 |
Note: If you are looking for a specific model not listed above, you can check the PyTorch Hub for more pre-trained models.
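Downstream code sometimes needs the embedding size in advance, for example to preallocate an output array or to configure a vector index. The sketch below mirrors the tables above in a plain dictionary. EMBEDDING_DIMS and embedding_dim are hypothetical names and are not part of torch_image2vector. The keys copy the model names exactly as they appear in the tables.

```python
# Hypothetical lookup table copied from the model tables above;
# not part of the torch_image2vector API.
EMBEDDING_DIMS = {
    # ResNet
    "resnet-18": 512,
    "resnet-34": 512,
    "resnet-50": 2048,
    "resnet-101": 2048,
    "resnet-152": 2048,
    # ResNeXt
    "resnext50_32x4d": 2048,
    "resnext101_32x8d": 2048,
    "resnext101_32x8d_wsl": 2048,
    # EfficientNet
    "efficientnet_b0": 1280,
    "efficientnet_b1": 1280,
    "efficientnet_b2": 1408,
    "efficientnet_b3": 1536,
    "efficientnet_b4": 1792,
    "efficientnet_b5": 2048,
    "efficientnet_b6": 2304,
    "efficientnet_b7": 2560,
    # MobileNet
    "mobilenet_v3_large": 960,
    # ConvNeXt
    "convnext_tiny": 768,
    "convnext_small": 768,
    "convnext_base": 1024,
    "convnext_large": 1536,
}


def embedding_dim(model_name: str) -> int:
    """Return the vector length listed for model_name, or raise ValueError."""
    try:
        return EMBEDDING_DIMS[model_name]
    except KeyError:
        known = ", ".join(sorted(EMBEDDING_DIMS))
        raise ValueError(f"unknown model {model_name!r}; known models: {known}") from None
```

As an example, embedding_dim("mobilenet_v3_large") gives the width needed to preallocate an (N, 960) array before batched inference.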
Recommended setup
For lightweight embedding extraction with good similarity-search performance, it is recommended to use model=mobilenet_v3_large with weights_version=mobilenet_v3_large-8738ca79. This combination also keeps results consistent with other projects.
Refer to requirements.txt for the recommended versions of torch and torchvision. These versions are not strictly required, but loading a specific weights_version requires matching (or newer) torch and torchvision versions.
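Loading a specific weights_version depends on torch and torchvision being at least as new as the recommended versions. Checking this before loading can prevent a confusing error later. The sketch below is a hypothetical helper and is not part of the package. It compares only the numeric release part of a version string (such as "2.1.0") and ignores local suffixes like "+cu118" and pre-release tags. The minimum versions you pass in should come from the project's requirements.txt.

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(ver: str) -> tuple:
    """Parse the leading numeric release part of a version string.

    '2.1.0+cu118' -> (2, 1, 0). Parsing stops at the first component with
    no leading digits (e.g. 'dev0'). Trailing letters in a component, such
    as the 'rc1' in '0rc1', are ignored.
    """
    public = ver.split("+", 1)[0]  # drop a local suffix such as '+cu118'
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """True if `installed` is at least `required`, comparing release numbers only."""
    a, b = release_tuple(installed), release_tuple(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.1' compares equal to '2.1.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_package(name: str, required: str) -> bool:
    """True if distribution `name` is installed at a version >= `required`."""
    try:
        return meets_minimum(version(name), required)
    except PackageNotFoundError:
        return False
```

A typical call is check_package("torchvision", required), where required is the minimum listed in requirements.txt. If the packaging library is available, packaging.version.Version provides full PEP 440 comparisons, including pre-releases.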
Test
bash run_test.sh
Usage
1. Inference on loaded images
from PIL import Image
from torch_image2vector import Image2Vector
# Initialize Image2Vector
image2vector = Image2Vector(model="resnet50")
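# Read in an image and make sure it is in RGB format
img = Image.open('sample.jpg').convert('RGB')
# Get a vector from image2vector
vec = image2vector.get_vec(img)
# [Alternative - batch processing] submit a list of images
# (sample2.jpg and sample3.jpg are placeholder file names)
img2 = Image.open('sample2.jpg').convert('RGB')
img3 = Image.open('sample3.jpg').convert('RGB')
vectors = image2vector.get_vec([img, img2, img3])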
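Because the embeddings are intended for similarity search, two images are typically compared with cosine similarity. The sketch below assumes that get_vec returns a 1-D NumPy array for a single image, since the return type is not documented above. cosine_similarity is an illustrative helper and is not part of torch_image2vector.

```python
import numpy as np


def cosine_similarity(a, b, eps: float = 1e-12) -> float:
    """Cosine of the angle between two embedding vectors (1.0 = same direction)."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    # Guard against division by zero for all-zero vectors.
    denom = max(np.linalg.norm(a) * np.linalg.norm(b), eps)
    return float(np.dot(a, b) / denom)
```

For example, cosine_similarity(vec, image2vector.get_vec(other_img)) scores how close a second RGB image is to the first. Values near 1.0 indicate visually similar content.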
2. Inference on a large number of images
To speed up inference on a large number of images, the following batched pipeline is recommended:
import numpy as np
from torch.utils.data import DataLoader
from torch_image2vector import Image2Vector
from torch_image2vector.data import ImgPathsDataset
# Choose a batch size that fits in your GPU memory
batch_size = 128
num_workers = 8
# List of image files we want to run inference on.
image_files = ['sample.png', ...]
# Initialize Image2Vector with CUDA for speed.
image2vector = Image2Vector(cuda=True, model="mobilenet_v3_large", weights_version="mobilenet_v3_large-8738ca79")
# Create a dataset from the image paths and a corresponding loader.
# shuffle=False keeps the output rows in the same order as image_files.
dataset = ImgPathsDataset(image_files, image2vector.transform)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
vec_list = []
for data in loader:
    # Move each batch to the model's device and collect its embeddings.
    vec_list.append(image2vector._loader_forward(data.to(image2vector.device)))
# (N, dim) array of embeddings.
vec = np.vstack(vec_list)
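With the (N, dim) matrix from the pipeline above, a brute-force nearest-neighbour search takes only a few lines of NumPy. First, L2-normalise the rows so that a dot product equals cosine similarity. Then take the highest-scoring rows. top_k_similar is a hypothetical helper and is not part of the package. This approach is only practical for moderate N. For millions of vectors, a dedicated approximate-nearest-neighbour index is the usual choice.

```python
import numpy as np


def top_k_similar(query, embeddings, k: int = 5):
    """Return (indices, scores) of the k rows of `embeddings` most cosine-similar to `query`."""
    emb = np.asarray(embeddings, dtype=np.float32)
    q = np.asarray(query, dtype=np.float32).ravel()
    # L2-normalise so the dot product below is the cosine similarity.
    emb = emb / np.maximum(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12)
    q = q / max(float(np.linalg.norm(q)), 1e-12)
    scores = emb @ q  # shape (N,)
    k = min(k, scores.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.float32)
    # argpartition finds the k best rows in O(N); only those k are then sorted.
    idx = np.argpartition(-scores, k - 1)[:k]
    idx = idx[np.argsort(-scores[idx])]  # descending by score
    return idx, scores[idx]
```

For example, idx, scores = top_k_similar(vec[0], vec, k=5) returns the five images closest to the first one. The first hit is the query image itself, with a score of about 1.0.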
Download files
Source Distribution
Built Distribution
File details
Details for the file torch_image2vector-1.0.0.tar.gz.
File metadata
- Download URL: torch_image2vector-1.0.0.tar.gz
- Upload date:
- Size: 7.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | eedf5d4ed561fac4863584cb55369c1a0de02808af4f704eb36e02a797f97408 |
MD5 | 6557df83e601699a76497c129c561f9f |
BLAKE2b-256 | 42ac2ac0f0cae69cb845e6c9930c0e115ae3432953fab75e0d23c7a20f7c462c |
File details
Details for the file torch_image2vector-1.0.0-py3-none-any.whl.
File metadata
- Download URL: torch_image2vector-1.0.0-py3-none-any.whl
- Upload date:
- Size: 8.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | cdb49cf81061bf66fe65b98e49f7b9d7c306c3cbee9d2f86981a87cfb55fa5d7 |
MD5 | a211ce321ec9d833eadd686bd4745e53 |
BLAKE2b-256 | 9124001ffcbc5cda1c74ae1dc667ba3049a7ab7a255cdc5159270ac68090c87f |