
Metric learning layers with tf.keras


Simple metric learning via tf.keras

This package provides only a few metric learning losses:

  • ArcFace
  • AdaCos
  • CircleLoss
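As background, these losses are usually formulated on the cosine similarity between an embedding and a learned weight vector per class, multiplied by a scale factor. The pure-Python sketch below shows only that shared step. The helper names are hypothetical and not part of this package. It also includes the fixed scale proposed in the AdaCos paper, sqrt(2)·ln(C − 1) for C classes, which lets AdaCos avoid a hand-tuned scale. Whether this package uses that fixed variant or the dynamic AdaCos scale is not stated here, so treat it as an assumption.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def cosine_logits(embedding, class_weights, scale):
    """Scaled cosine similarity between one embedding and each class weight vector.

    Hypothetical helper: `embedding` is a list of floats and `class_weights`
    is a list of per-class weight vectors of the same length.
    """
    e = l2_normalize(embedding)
    return [scale * sum(a * b for a, b in zip(e, l2_normalize(w))) for w in class_weights]

def adacos_fixed_scale(num_classes):
    """Fixed AdaCos scale sqrt(2) * ln(C - 1) for C classes (positive only when C > 2)."""
    return math.sqrt(2) * math.log(num_classes - 1)
```

For example, `cosine_logits([3.0, 4.0], [[1.0, 0.0], [0.0, 2.0]], scale=1.0)` gives 0.6 and 0.8, and `adacos_fixed_scale(196)` is about 7.46.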

I have been greatly inspired by PyTorch Metric Learning.

Installation

$ pip install tf-simple-metric-learning

Usage

The provided layers are implemented with the tf.keras.layers.Layer API, so they can be created and used like any other Keras layer:

from tf_simple_metric_learning.layers import ArcFace

NUM_CLASSES = 196  # example values; adjust for your dataset
MARGIN = 0.5
SCALE = 64
arcface = ArcFace(num_classes=NUM_CLASSES, margin=MARGIN, scale=SCALE)
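To illustrate what `margin` and `scale` control, here is a minimal pure-Python sketch of the ArcFace logit computation as described in the ArcFace paper: the angle θ between the embedding and the true class's weight vector is enlarged by the margin m before taking the cosine, and every logit is multiplied by the scale s. The function name and its list-based inputs are illustrative assumptions, not this package's API. Practical implementations usually add a fallback for θ + m > π, which this sketch omits.

```python
import math

def arcface_logits(cosines, label, margin=0.5, scale=64.0):
    """ArcFace logits for one sample (illustrative sketch).

    cosines: cos(theta_j) between the normalized embedding and each normalized
    class weight vector; label: index of the true class.
    """
    logits = []
    for j, cos_t in enumerate(cosines):
        if j == label:
            # Clamp before acos to guard against rounding slightly outside [-1, 1]
            theta = math.acos(max(-1.0, min(1.0, cos_t)))
            logits.append(scale * math.cos(theta + margin))  # penalized target logit
        else:
            logits.append(scale * cos_t)  # non-target logits are only scaled
    return logits
```

With margin 0.5 and scale 64, a perfectly aligned target (cos θ = 1) gets the logit 64·cos(0.5) ≈ 56.2 instead of 64, so training has to pull embeddings closer to their class center than a plain softmax would.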

An example notebook is in the examples directory. It builds a CircleLossCL (class-level label version) layer on top of EfficientNet and trains it on the Cars196 dataset:

import tensorflow as tf
from tf_simple_metric_learning.layers import ArcFace, AdaCos, CircleLossCL

IMAGE_SIZE = (224, 224)  # example input resolution (height, width)
num_classes = 196        # example value; set to the number of classes in your training labels

# Image input: raw uint8 pixels, cast to float and preprocessed for EfficientNet
inputs = tf.keras.layers.Input([*IMAGE_SIZE, 3], dtype=tf.uint8)
x = tf.cast(inputs, dtype=tf.float32)
x = tf.keras.applications.efficientnet.preprocess_input(x)

# Backbone without the classification head; average pooling yields one embedding per image
net = tf.keras.applications.EfficientNetB0(include_top=False, weights='imagenet', pooling='avg')
embeds = net(x)

# Integer class labels are a second model input, one-hot encoded for the metric layer
labels = tf.keras.layers.Input([], dtype=tf.int32)
labels_onehot = tf.one_hot(labels, depth=num_classes)

# Create metric learning layer
# metric_layer = ArcFace(num_classes=num_classes, margin=0.5, scale=64)
# metric_layer = AdaCos(num_classes=num_classes)
metric_layer = CircleLossCL(num_classes=num_classes, margin=0.25, scale=256)

logits = metric_layer([embeds, labels_onehot])

model = tf.keras.Model(inputs=[inputs, labels], outputs=logits)
model.summary()
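The CircleLossCL layer above is configured with margin=0.25 and scale=256. As a hedged sketch of the class-level formulation from the Circle Loss paper (not necessarily this package's exact code), the target-class similarity s_p and each non-target similarity s_n are reweighted by α_p = max(1 + m − s_p, 0) and α_n = max(s_n + m, 0), shifted by the margins Δ_p = 1 − m and Δ_n = m, and multiplied by the scale; softmax cross-entropy over these logits then gives the loss. The function names are hypothetical.

```python
import math

def circle_loss_cl_logits(cosines, label, margin=0.25, scale=256.0):
    """Class-level Circle Loss logits for one sample (sketch of the paper's formulation)."""
    delta_p, delta_n = 1.0 - margin, margin      # decision margins
    optimum_p, optimum_n = 1.0 + margin, -margin  # optimal similarity values
    logits = []
    for j, s in enumerate(cosines):
        if j == label:
            alpha = max(optimum_p - s, 0.0)  # weight grows as s_p falls below its optimum
            logits.append(scale * alpha * (s - delta_p))
        else:
            alpha = max(s - optimum_n, 0.0)  # weight grows as s_n rises above its optimum
            logits.append(scale * alpha * (s - delta_n))
    return logits

def softmax_cross_entropy(logits, label):
    """Numerically stable -log softmax(logits)[label]."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]
```

Because α_p and α_n shrink as a similarity approaches its optimum, well-separated similarities are penalized less strongly, whereas ArcFace applies the same fixed margin everywhere. In the paper the α weights are treated as constants during backpropagation.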

Note that during training you must feed the labels into the model as an input, because these layers need the labels in the forward pass.

For evaluation or prediction, the model above still expects both images and labels, but the metric learning layers ignore the labels. Because true labels are generally not available at that stage, you only need to pass dummy labels (which are ignored) alongside the target images.
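Why dummy labels are safe can be illustrated with a hypothetical ArcFace-style forward pass that, like many margin-based layers, applies the margin only when `training=True`. This pure-Python sketch is written under that assumption and is not this package's code: at inference it returns the plain scaled cosines and never reads the label, so any dummy value produces the same logits.

```python
import math

def metric_layer_logits(cosines, label, training, margin=0.5, scale=64.0):
    """Hypothetical margin layer: the label only matters when training=True."""
    if not training:
        # Inference path: scaled cosine similarities; `label` is never used
        return [scale * c for c in cosines]
    # Training path: add the angular margin to the true-class angle only
    return [
        scale * math.cos(math.acos(max(-1.0, min(1.0, c))) + margin) if j == label else scale * c
        for j, c in enumerate(cosines)
    ]

cosines = [0.8, 0.3, -0.1]
# Different dummy labels give identical inference logits
print(metric_layer_logits(cosines, 0, training=False) == metric_layer_logits(cosines, 2, training=False))  # True
```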


