Image similarity, metric learning loss functions for TensorFlow 2+.
tf-metric-learning
Overview
Minimalistic open-source library for metric learning written in TensorFlow 2, TF-Addons, NumPy, OpenCV (cv2) and Annoy. This repository contains a TensorFlow 2+/tf.keras implementation of some of the loss functions and miners, and was inspired by pytorch-metric-learning.
Installation
Prerequisites:
```shell
pip install tensorflow
pip install tensorflow-addons
pip install annoy
pip install opencv-contrib-python
```
This library:
```shell
pip install tf-metric-learning
```
Features
- All the loss functions are implemented as tf.keras.layers.Layer
- Callbacks for computing Recall@K and visualizing embeddings in the TensorBoard Projector
- Simple Mining mechanism with Annoy
- Combine multiple loss functions/layers in one model
Open-source repos
This library contains code adapted and modified from the following great open-source repositories; without them it would not have been possible (THANK YOU):
TODO
- Discriminative layer optimizer (different learning rates) for losses with trainable weights (Proxy, SoftTriple, ...)
- Some Tests 😇
- Improve and add more miners
Examples
```python
import numpy as np
import tensorflow as tf

from tf_metric_learning.layers import SoftTripleLoss
from tf_metric_learning.utils.constants import EMBEDDINGS, LABELS

num_class, num_centers, embedding_size = 10, 2, 256

# Two named inputs: the embeddings to train on and their class labels.
# Note the trailing comma: `shape` must be a tuple.
inputs = tf.keras.Input(shape=(embedding_size,), name=EMBEDDINGS)
input_label = tf.keras.Input(shape=(1,), name=LABELS)

# The loss layer takes a dict keyed by EMBEDDINGS/LABELS.
output_tensor = SoftTripleLoss(num_class, num_centers, embedding_size)({EMBEDDINGS: inputs, LABELS: input_label})

model = tf.keras.Model(inputs=[inputs, input_label], outputs=output_tensor)
model.compile(optimizer="adam")  # no `loss=` argument: the layer supplies the loss

# Dummy data (all-zero embeddings, a single class) just to show the call pattern.
data = {
    EMBEDDINGS: np.zeros((1000, embedding_size), dtype=np.float32),
    LABELS: np.zeros(1000, dtype=np.float32),
}
model.fit(data, None, epochs=10, batch_size=10)
```
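The example above trains SoftTripleLoss directly on embeddings. For readers who want to see what such a loss computes, here is a NumPy sketch of the SoftTriple loss as described by Qian et al. (2019): each class owns several centers, a sample's similarity to a class is a softmax-weighted mix of its similarities to that class's centers, and a scaled, margin-adjusted softmax cross-entropy is taken over classes. It omits the paper's center regularization term, and the names and defaults `la`, `gamma` and `margin` are assumptions of this sketch, not the arguments of the library's layer.

```python
import numpy as np


def l2_normalize(x, axis):
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def soft_triple_loss(embeddings, labels, centers, num_class, num_centers,
                     la=20.0, gamma=0.1, margin=0.01):
    """Sketch of the SoftTriple loss WITHOUT the center regularizer.

    embeddings: (batch, dim); labels: (batch,) integer class ids;
    centers: (dim, num_class * num_centers), grouped class by class.
    """
    rows = np.arange(len(labels))
    x = l2_normalize(embeddings, axis=1)
    w = l2_normalize(centers, axis=0)
    # Cosine similarity of every sample to every center: (batch, num_class, num_centers).
    sim = (x @ w).reshape(-1, num_class, num_centers)
    # Soft assignment over each class's centers (softmax with temperature gamma).
    logits = sim / gamma
    logits = logits - logits.max(axis=2, keepdims=True)
    prob = np.exp(logits) / np.exp(logits).sum(axis=2, keepdims=True)
    # Relaxed per-class similarity: (batch, num_class).
    class_sim = (prob * sim).sum(axis=2)
    # Subtract the margin from the true-class similarity only.
    class_sim[rows, labels] -= margin
    # Scaled softmax cross-entropy over classes (numerically stabilized).
    scaled = la * class_sim
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    log_prob = scaled - np.log(np.exp(scaled).sum(axis=1, keepdims=True))
    return -log_prob[rows, labels].mean()
```

Embeddings that sit on their own class's center give a loss near zero, while assigning them to the wrong classes gives a large loss.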
More complex scenarios:
- Complex example with NPair Loss + Multi Similarity + Classification and Mining
- SoftTriple Training on CIFAR 10
- ProxyAnchor Loss using tf.data.Dataset
- Triplet Training with Mining
- Contrastive Training
- Classification baseline
Features
Loss functions
- MultiSimilarityLoss ✅
- ProxyAnchorLoss ✅
- SoftTripleLoss ✅
- NPairLoss ✅
- TripletLoss ✅
- ContrastiveLoss ✅
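ContrastiveLoss and TripletLoss both work directly on distances between embeddings. The NumPy sketch below shows textbook versions of the two: a contrastive loss that pulls same-label pairs together and pushes different-label pairs at least `margin` apart, and a triplet loss that asks each anchor to be closer to its positive than to its negative by `margin`. The margins, the Euclidean distance and the averaging over pairs are assumptions of this sketch; the library's layers may normalize embeddings or reduce the loss differently.

```python
import numpy as np


def pairwise_distances(x):
    """Euclidean distance matrix between all rows of x: (n, n)."""
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] - 2.0 * x @ x.T + sq[None, :]
    return np.sqrt(np.maximum(d2, 0.0))  # clip tiny negatives from rounding


def contrastive_loss(x, labels, margin=1.0):
    """Squared distance for positive pairs, squared hinge for negative pairs,
    averaged over all ordered pairs (i != j)."""
    d = pairwise_distances(x)
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)
    pos = d[same & off_diag] ** 2
    neg = np.maximum(margin - d[~same], 0.0) ** 2
    return (pos.sum() + neg.sum()) / off_diag.sum()


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Mean hinge of d(a, p) - d(a, n) + margin over rows of (anchor, positive, negative)."""
    d_ap = np.linalg.norm(anchor - positive, axis=1)
    d_an = np.linalg.norm(anchor - negative, axis=1)
    return np.maximum(d_ap - d_an + margin, 0.0).mean()
```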
Miners
- MaximumLossMiner [TODO]
- TripletAnnoyMiner ✅
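TripletAnnoyMiner uses Annoy's approximate nearest-neighbour search to find informative triplets. The sketch below illustrates the general idea of nearest-neighbour hard-negative mining, with an exact brute-force NumPy search standing in for the Annoy index. The function name `mine_triplets` and its selection rule (closest different-label sample among the `k` nearest neighbours as negative, farthest same-label sample as positive) are assumptions for illustration, not the miner's documented behaviour.

```python
import numpy as np


def mine_triplets(embeddings, labels, k=5):
    """Hypothetical miner: return (anchor, positive, negative) index triples."""
    n = len(labels)
    # Exact squared Euclidean distances; an Annoy index would give approximate neighbours.
    d = ((embeddings[:, None, :] - embeddings[None, :, :]) ** 2).sum(axis=-1)
    triplets = []
    for a in range(n):
        order = np.argsort(d[a])
        neighbours = [j for j in order if j != a][:k]  # k nearest, excluding the anchor
        negatives = [j for j in neighbours if labels[j] != labels[a]]
        positives = [j for j in range(n) if j != a and labels[j] == labels[a]]
        if not negatives or not positives:
            continue  # no hard negative nearby (or no positive): skip this anchor
        hardest_neg = negatives[0]  # neighbours are sorted, so the first is the closest
        hardest_pos = max(positives, key=lambda j: d[a, j])
        triplets.append((a, int(hardest_pos), int(hardest_neg)))
    return triplets
```

The mined index triples can then be gathered into anchor/positive/negative batches for a triplet loss.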
Evaluators
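- AnnoyEvaluatorCallback: evaluates Recall@K on held-out data; requires Spotify's Annoy library (`pip install annoy`).
```python
import tensorflow as tf
from tf_metric_learning.utils.recall import AnnoyEvaluatorCallback

# Assumed to be defined earlier: `base_network` (the model mapping images to
# embeddings), `test_images`/`test_labels`, `embedding_size`, and `divide`
# (the index that splits the test set into indexed and query images).
evaluator = AnnoyEvaluatorCallback(
    base_network,
    {"images": test_images[:divide], "labels": test_labels[:divide]},  # images stored in the index
    {"images": test_images[divide:], "labels": test_labels[divide:]},  # images used as queries
    normalize_fn=lambda images: images / 255.0,  # preprocessing applied before embedding
    normalize_eb=True,
    eb_size=embedding_size,
    freq=1,
)
# Pass it like any Keras callback: model.fit(..., callbacks=[evaluator])
```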
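Recall@K, as commonly defined in metric learning, is the fraction of query embeddings whose K nearest neighbours in the index set include at least one item with the same label. The NumPy sketch below computes it with an exact cosine-similarity search, mirroring the index/query split passed to the callback above. The function name `recall_at_k` is hypothetical; AnnoyEvaluatorCallback searches an approximate Annoy index, so its numbers may differ slightly, and its exact definition is not confirmed here.

```python
import numpy as np


def recall_at_k(index_emb, index_labels, query_emb, query_labels, k=1):
    """Fraction of queries whose k most similar index items contain the query's label."""
    # L2-normalize so that the dot product equals cosine similarity.
    index_emb = index_emb / np.linalg.norm(index_emb, axis=1, keepdims=True)
    query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    sims = query_emb @ index_emb.T            # (num_queries, num_index)
    top_k = np.argsort(-sims, axis=1)[:, :k]  # indices of the k most similar index items
    hits = (index_labels[top_k] == query_labels[:, None]).any(axis=1)
    return hits.mean()
```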
Visualizations
- TensorBoard Projector Callback
```python
import numpy as np
import tensorflow as tf
from tf_metric_learning.utils.projector import TBProjectorCallback


def normalize_images(images):
    return images / 255.0


(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
...  # build and train `base_model`, the network that produces embeddings

projector = TBProjectorCallback(
    base_model,
    "tb/projector",
    test_images,  # list of images
    np.squeeze(test_labels),  # CIFAR-10 labels have shape (N, 1); squeeze to (N,)
    normalize_eb=True,
    normalize_fn=normalize_images,
)
```
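TBProjectorCallback hands embeddings and their labels to TensorBoard's projector. As a rough, framework-free illustration of what the projector consumes, the sketch below writes the two tab-separated files that the standalone Embedding Projector (projector.tensorflow.org) can load: one row of vector components per point, plus a metadata file with one label per line (no header row, since there is a single column). The function name `export_for_projector` is hypothetical, the embeddings are assumed to be a NumPy array you already have (for example from `base_model.predict`), and this is not meant to reproduce what the callback writes.

```python
import os
import tempfile

import numpy as np


def export_for_projector(embeddings, labels, out_dir, normalize=True):
    """Write vectors.tsv and metadata.tsv for the standalone Embedding Projector."""
    if normalize:
        # L2-normalize each embedding, similar in spirit to `normalize_eb=True`.
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    os.makedirs(out_dir, exist_ok=True)
    np.savetxt(os.path.join(out_dir, "vectors.tsv"), embeddings, delimiter="\t", fmt="%.6f")
    # Single metadata column, so no header line.
    with open(os.path.join(out_dir, "metadata.tsv"), "w") as f:
        for label in labels:
            f.write(f"{label}\n")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        rng = np.random.default_rng(0)
        export_for_projector(rng.normal(size=(5, 8)), ["cat", "dog", "cat", "frog", "dog"], tmp)
        print(sorted(os.listdir(tmp)))  # ['metadata.tsv', 'vectors.tsv']
```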
Download files
Source distribution: tf-metric-learning-1.0.10.tar.gz (18.3 kB)

Algorithm | Hash digest
---|---
SHA256 | 6c05ce77ff1ffc24313b4a29b26255c55942b840f845278df66c9a19b0e08f61
MD5 | 43a9a972d6abc0e568352db3acfa834c
BLAKE2b-256 | 6a46a0c2de74b94305e87f4ba88bd4ffb99ff5d4959c86b0de86d0bd4022c510

Built distribution: tf_metric_learning-1.0.10-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 6f32dfe241c69c02a80a2cc5d625f74e6eefc39ad19cf2328ad7017e6413279a
MD5 | 915549a873f45d2e2cf58fd47d6ea316
BLAKE2b-256 | 781b15b536a8659cb361f52acbf210524c4c10158d05581acd5bd8b23e45a889