
Use pre-trained models in PyTorch to extract vector embeddings for any image


Image 2 Vec with PyTorch

Medium post on building the first version from scratch: https://becominghuman.ai/extract-a-feature-vector-for-any-image-with-pytorch-9717561d1d4c

Looking for a simpler image vector integration for your project? Check out our free API at https://latentvector.space

Applications of image embeddings:

  • Ranking for recommender systems
  • Clustering images into categories
  • Classification tasks
  • Image compression

Available models

  • ResNet-18 (CPU, GPU)
    • Returns a vector of length 512
  • AlexNet (CPU, GPU)
    • Returns a vector of length 4096
  • VGG-11 (CPU, GPU)
    • Returns a vector of length 4096
  • DenseNet (CPU, GPU)
    • Returns a vector of length 1024

Installation

Tested on Python 3.6

Requires PyTorch: http://pytorch.org/

pip install img2vec_pytorch

Run test

python -m img2vec_pytorch.test_img_to_vec

Using img2vec as a library

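from img2vec_pytorch import Img2Vec
from PIL import Image

# Initialize Img2Vec with GPU (needs a CUDA device; pass cuda=False to run on CPU)
img2vec = Img2Vec(cuda=True)

# Read in an image and make sure it is in RGB format
img = Image.open('test.jpg').convert('RGB')
# Get a vector from img2vec, returned as a torch FloatTensor because tensor=True
vec = img2vec.get_vec(img, tensor=True)
# Or submit a list of RGB PIL images (a one-image list here, for illustration)
list_of_PIL_images = [img]
vectors = img2vec.get_vec(list_of_PIL_images)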
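For large image collections you may not want to pass every image to get_vec in a single call. The sketch below splits the list into fixed-size batches. chunked, embed_in_batches and batch_size are hypothetical names, not part of img2vec_pytorch, and the sketch assumes that get_vec on a list returns one vector per image, in input order.

```python
from typing import Any, Iterable, List, Sequence


def chunked(items: Sequence[Any], size: int) -> Iterable[Sequence[Any]]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("batch size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def embed_in_batches(encoder: Any, images: Sequence[Any], batch_size: int = 32) -> List[Any]:
    """Hypothetical helper: call encoder.get_vec on one batch at a time.

    Assumes encoder.get_vec(list_of_images) returns one vector per image,
    in the same order (for example, the rows of a 2-D array).
    """
    vectors: List[Any] = []
    for batch in chunked(images, batch_size):
        vectors.extend(encoder.get_vec(list(batch)))
    return vectors


# With the library installed:
# vectors = embed_in_batches(img2vec, list_of_PIL_images, batch_size=16)
```

Smaller batches trade some throughput for lower peak memory, which matters most when cuda=True and GPU memory is limited.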

To run the example, you will additionally need:
  • Pillow: pip install Pillow
  • scikit-learn: pip install scikit-learn

Running the example

git clone https://github.com/christiansafka/img2vec.git

cd img2vec/example

python test_img_similarity.py

Expected output

Which filename would you like similarities for?
cat.jpg
0.72832 cat2.jpg
0.641478 catdog.jpg
0.575845 face.jpg
0.516689 face2.jpg

Which filename would you like similarities for?
face2.jpg
0.668525 face.jpg
0.516689 cat.jpg
0.50084 cat2.jpg
0.484863 catdog.jpg
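The scores above rank every other image by how close its vector is to the query image's vector. Assuming they are cosine similarities (the example installs scikit-learn for this step), the ranking can be sketched without any dependencies. The vectors below are made-up three-dimensional stand-ins, not real embeddings, and rank_by_similarity is a hypothetical name.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cannot compare a zero vector")
    return dot / (norm_a * norm_b)


def rank_by_similarity(query: str, vectors: Dict[str, Sequence[float]]) -> List[Tuple[float, str]]:
    """Score every other file against `query`, highest similarity first."""
    q = vectors[query]
    scores = [(cosine_similarity(q, v), name) for name, v in vectors.items() if name != query]
    return sorted(scores, reverse=True)


# Made-up 3-D vectors; real img2vec vectors have 512 or more dimensions.
toy_vectors = {
    "cat.jpg": [1.0, 0.0, 0.0],
    "cat2.jpg": [0.9, 0.1, 0.0],
    "face.jpg": [0.0, 1.0, 0.0],
}
for score, name in rank_by_similarity("cat.jpg", toy_vectors):
    print(f"{score:.6g} {name}")
```

Cosine similarity is symmetric, which is consistent with cat.jpg and face2.jpg reporting the same score (0.516689) in both queries above.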

Try adding your own photos!

Img2Vec Params

cuda = (True, False)   # Run on GPU?     default: False
model = ('resnet-18', 'alexnet', 'vgg', 'densenet')   # Which model to use?     default: 'resnet-18'
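The four model names and the vector lengths listed under "Available models" can be checked before constructing Img2Vec. The helper below is hypothetical, not part of the library; the commented constructor call uses only the cuda and model parameters documented above, plus torch.cuda.is_available() to pick the device.

```python
# Default vector lengths as documented under "Available models".
VECTOR_LENGTHS = {
    "resnet-18": 512,
    "alexnet": 4096,
    "vgg": 4096,
    "densenet": 1024,
}


def expected_vector_length(model: str = "resnet-18") -> int:
    """Hypothetical check: return the documented vector length for a model name."""
    try:
        return VECTOR_LENGTHS[model]
    except KeyError:
        valid = ", ".join(sorted(VECTOR_LENGTHS))
        raise ValueError(f"unknown model {model!r}; expected one of: {valid}") from None


model_name = "densenet"
dim = expected_vector_length(model_name)
# With the library and torch installed, the matching call would be:
# img2vec = Img2Vec(cuda=torch.cuda.is_available(), model=model_name)
print(model_name, dim)
```

These lengths apply to the default layer of each model; choosing a different layer (see the additional parameters below) changes the output size.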

Advanced users


Read-only file systems

If you use this library from an app running in a read-only environment (for example, a Docker container), specify a writable directory where the app can store the pre-trained models.

export TORCH_HOME=/tmp/torch
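If you cannot set the variable in the shell (in a notebook, for example), the same thing can be done from Python. This minimal sketch uses the system temporary directory (/tmp on most Linux systems, matching the shell example) and keeps any TORCH_HOME value the environment already provides. It should run before img2vec_pytorch is imported, so that any weight download goes to the writable directory.

```python
import os
import tempfile

# Use <tempdir>/torch unless TORCH_HOME is already set by the environment.
torch_home = os.environ.setdefault(
    "TORCH_HOME", os.path.join(tempfile.gettempdir(), "torch")
)
os.makedirs(torch_home, exist_ok=True)

# Only now import and construct Img2Vec:
# from img2vec_pytorch import Img2Vec
# img2vec = Img2Vec()
print("pre-trained weights will be cached under", torch_home)
```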

Additional Parameters

layer = 'layer_name' or int   # For advanced users, which layer of the model to extract the output from.   default: depends on the model ('avgpool' for resnet-18, see below)
layer_output_size = int   # Size of the output of your selected layer

ResNet-18

Defaults: (layer = 'avgpool', layer_output_size = 512)
The layer parameter must be a string naming one of the layers below

conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
bn1 = nn.BatchNorm2d(64)
relu = nn.ReLU(inplace=True)
maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
layer1 = self._make_layer(block, 64, layers[0])
layer2 = self._make_layer(block, 128, layers[1], stride=2)
layer3 = self._make_layer(block, 256, layers[2], stride=2)
layer4 = self._make_layer(block, 512, layers[3], stride=2)
avgpool = nn.AvgPool2d(7)
fc = nn.Linear(512 * block.expansion, num_classes)
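Why avgpool yields exactly 512 values: assuming the image is resized to 224 x 224 before the forward pass (the usual ImageNet input size; the resize step is an assumption here, not stated above), the standard output-size formula out = floor((in + 2*padding - kernel) / stride) + 1 shrinks the spatial size from 224 to 7 by layer4, and AvgPool2d(7) then reduces each of the 512 channels to a single number. The sketch traces that arithmetic; layer1 keeps the size, while layer2 to layer4 each downsample once with a 3x3, stride-2, padding-1 convolution.

```python
from typing import Dict, Tuple


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer (no dilation, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_shapes(size: int = 224) -> Dict[str, Tuple[int, int]]:
    """(channels, side length) after each named ResNet-18 layer for a square input."""
    shapes: Dict[str, Tuple[int, int]] = {}
    size = conv_out(size, kernel=7, stride=2, padding=3)    # conv1
    for name in ("conv1", "bn1", "relu"):                   # bn1 and relu keep the shape
        shapes[name] = (64, size)
    size = conv_out(size, kernel=3, stride=2, padding=1)    # maxpool
    shapes["maxpool"] = (64, size)
    shapes["layer1"] = (64, size)                           # stride 1: size unchanged
    for name, channels in (("layer2", 128), ("layer3", 256), ("layer4", 512)):
        # the first 3x3 conv of each of these stages has stride 2, padding 1
        size = conv_out(size, kernel=3, stride=2, padding=1)
        shapes[name] = (channels, size)
    size = conv_out(size, kernel=7, stride=7)               # AvgPool2d(7): stride defaults to 7
    shapes["avgpool"] = (512, size)
    return shapes


for name, (channels, side) in resnet18_shapes().items():
    print(f"{name:8s} {channels} x {side} x {side}")
```

For a 224 x 224 input, avgpool ends at 512 x 1 x 1, which flattens to the 512 values of layer_output_size; a layer such as layer4 instead produces a 512 x 7 x 7 feature map rather than a flat vector.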

AlexNet

Defaults: (layer = 2, layer_output_size = 4096)
The layer parameter must be an integer naming one of the layers below, counted from the end of the classifier (1 = last layer)

alexnet.classifier = nn.Sequential(
            7. nn.Dropout(),                  <- output_size = 9216
            6. nn.Linear(256 * 6 * 6, 4096),  <- output_size = 4096
            5. nn.ReLU(inplace=True),         <- output_size = 4096
            4. nn.Dropout(),                  <- output_size = 4096
            3. nn.Linear(4096, 4096),         <- output_size = 4096
            2. nn.ReLU(inplace=True),         <- output_size = 4096
            1. nn.Linear(4096, num_classes),  <- output_size = num_classes (1000 for ImageNet weights)
        )
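Because the numbering counts back from the end of alexnet.classifier, layer = 2 is the second-to-last module. Assuming the integer is applied as a negative index (classifier[-layer]), which matches the numbering shown, the sketch below models the classifier as plain (module, output size) pairs. It is an illustrative stand-in, not the torchvision module; select_layer is a hypothetical name, and num_classes is taken as 1000, the ImageNet class count.

```python
from typing import List, Tuple

# Stand-in for alexnet.classifier, first module to last: (module, output size).
ALEXNET_CLASSIFIER: List[Tuple[str, int]] = [
    ("Dropout", 9216),
    ("Linear(9216, 4096)", 4096),
    ("ReLU", 4096),
    ("Dropout", 4096),
    ("Linear(4096, 4096)", 4096),
    ("ReLU", 4096),
    ("Linear(4096, 1000)", 1000),  # num_classes = 1000 for ImageNet weights
]


def select_layer(classifier: List[Tuple[str, int]], layer: int) -> Tuple[str, int]:
    """Return the module numbered `layer` in the listing (1 = last module)."""
    if not 1 <= layer <= len(classifier):
        raise ValueError(f"layer must be between 1 and {len(classifier)}")
    return classifier[-layer]


module, layer_output_size = select_layer(ALEXNET_CLASSIFIER, 2)
print(module, layer_output_size)
```

Picking layer = 1 would return the 1000 class scores instead of a 4096-dimensional feature vector, which is why the default stops one module earlier.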

VGG

Defaults: (layer = 2, layer_output_size = 4096)

vgg.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

DenseNet

Defaults: (layer = 1, i.e. the last module of densenet.features, layer_output_size = 1024)

densenet.features = nn.Sequential(OrderedDict([
    ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                        padding=3, bias=False)),
    ('norm0', nn.BatchNorm2d(num_init_features)),
    ('relu0', nn.ReLU(inplace=True)),
    ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
# ...dense blocks and transition layers are appended after pool0;
# the final module, norm5 = nn.BatchNorm2d(num_features), outputs 1024 channels
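The DenseNet default has layer_output_size = 1024, but a convolutional feature layer produces a 1024-channel spatial map (1024 x 7 x 7 for a 224 x 224 input), not a flat vector. One common way to turn such a map into a 1024-length vector is global average pooling: average every channel over its height and width. The sketch shows that reduction on nested lists; whether img2vec_pytorch uses exactly this reduction is an assumption here, and the tiny two-channel input is made up.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # channels x height x width


def global_average_pool(feature_map: FeatureMap) -> List[float]:
    """Average each channel over its spatial grid, giving one value per channel."""
    vector = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        if not values:
            raise ValueError("each channel needs at least one spatial value")
        vector.append(sum(values) / len(values))
    return vector


# Made-up 2-channel, 2 x 2 map; a DenseNet-121 map would be 1024 x 7 x 7.
toy_map = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 8.0]],
]
print(global_average_pool(toy_map))
```

On a real batch tensor of shape N x C x H x W, the equivalent PyTorch call is tensor.mean(dim=(2, 3)), which returns an N x C result.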

To-do

  • Benchmark speed and accuracy
  • Add ability to fine-tune on input data
  • Export documentation to a normal place



Download files


Source Distribution

img2vec_pytorch-1.0.0.tar.gz (5.7 kB)

Uploaded Source

Built Distribution

img2vec_pytorch-1.0.0-py3-none-any.whl (6.4 kB)

Uploaded Python 3

File details

Details for the file img2vec_pytorch-1.0.0.tar.gz.

File metadata

  • Download URL: img2vec_pytorch-1.0.0.tar.gz
  • Upload date:
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.9

File hashes

Hashes for img2vec_pytorch-1.0.0.tar.gz
Algorithm Hash digest
SHA256 e7cde5bda488a2e1743ef68cc3064e6810de9aa0dd57bf68af86c91ed4a371e2
MD5 d4735b418be3ec2372eba0dc3d83081f
BLAKE2b-256 af6831aedb14b9654da0e67ecc6cf23cb805f1eece08dee0e82a65114e77e101


File details

Details for the file img2vec_pytorch-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: img2vec_pytorch-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 6.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.9

File hashes

Hashes for img2vec_pytorch-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 78a54570bde385e0cd7ec2fc8f61442aee8d7b33a389672785a57cd2c1b039e6
MD5 59abc8002ee7f9db842a0d4db57ed5fa
BLAKE2b-256 0426fee404cde9f37e0521e1db9d01b05b2251f39ba924b3e8a8e1c4b91eea48

