Metric learning layers with tf.keras
Simple metric learning via tf.keras
This package provides only a few metric learning losses:
- ArcFace
- AdaCos
- CircleLoss
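Unlike ArcFace, AdaCos has no hand-tuned scale. The sketch below is a minimal pure-Python version of the two scale rules as we read them in the AdaCos paper. The fixed rule is s = sqrt(2) * ln(C - 1) for C classes. The dynamic rule recomputes s from batch statistics as ln(B_avg) / cos(min(pi/4, theta_med)). The function names `fixed_adacos_scale` and `dynamic_adacos_scale` are hypothetical. Whether this package's `AdaCos` layer uses exactly this update (for example, how it tracks the median angle) is an assumption.

```python
import math
from statistics import median


def fixed_adacos_scale(num_classes: int) -> float:
    """Fixed AdaCos scale: s = sqrt(2) * ln(C - 1) for C classes."""
    return math.sqrt(2.0) * math.log(num_classes - 1)


def dynamic_adacos_scale(cos_rows, labels, scale):
    """One dynamic AdaCos update computed from a batch (hypothetical helper).

    cos_rows: per-sample lists of cosine similarities to every class center.
    labels:   per-sample target class indices.
    scale:    the scale used in the previous step.
    """
    # B_avg: batch mean of the sum over non-target classes of exp(s * cos)
    b_avg = sum(
        sum(math.exp(scale * c) for k, c in enumerate(row) if k != y)
        for row, y in zip(cos_rows, labels)
    ) / len(labels)
    # Median angle between each sample and its own class center
    theta_med = median(
        math.acos(max(-1.0, min(1.0, row[y]))) for row, y in zip(cos_rows, labels)
    )
    # New scale: ln(B_avg) / cos(min(pi/4, theta_med))
    return math.log(b_avg) / math.cos(min(math.pi / 4, theta_med))
```

For example, `fixed_adacos_scale(196)` gives about 7.46. That is far smaller than the `scale=64` commonly used with ArcFace, which shows why a good scale depends on the number of classes.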
I have been greatly inspired by PyTorch Metric Learning.
Installation
$ pip install tf-simple-metric-learning
Usage
The provided layers are implemented with the tf.keras.layers.Layer API, so they can be created and used like any other Keras layer:
from tf_simple_metric_learning.layers import ArcFace
arcface = ArcFace(num_classes=NUM_CLASSES, margin=MARGIN, scale=SCALE)
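The `margin` and `scale` arguments above control ArcFace's additive angular margin. The sketch below is a minimal pure-Python version of the ArcFace logit computation for a single embedding. Both the embedding and each class weight are L2-normalized. The target class logit becomes s * cos(theta + m), and the other logits stay s * cos(theta). The function `arcface_logits` is hypothetical, not this package's implementation. It also omits the fallback that many implementations apply when theta + m exceeds pi.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def arcface_logits(embedding, class_weights, label, margin=0.5, scale=64.0):
    """Hypothetical ArcFace logits for one embedding and its integer label."""
    e = l2_normalize(embedding)
    logits = []
    for k, w in enumerate(class_weights):
        # Cosine similarity between the embedding and the class weight
        cos_t = sum(a * b for a, b in zip(e, l2_normalize(w)))
        cos_t = max(-1.0, min(1.0, cos_t))  # guard acos against rounding
        if k == label:
            # Additive angular margin on the target class only
            cos_t = math.cos(math.acos(cos_t) + margin)
        logits.append(scale * cos_t)
    return logits
```

Cosine decreases on [0, pi], so adding m lowers the target logit. To reach the same loss, the network must pull each embedding closer to its class weight.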
An example notebook is in the examples directory. It builds a CircleLossCL (class-level label version) layer on top of EfficientNet and trains it on the Cars196 dataset:
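import tensorflow as tf
from tf_simple_metric_learning.layers import ArcFace, AdaCos, CircleLossCL
IMAGE_SIZE = (224, 224)  # assumed input resolution (EfficientNetB0's default); adjust as needed
num_classes = 196  # assumed: all Cars196 classes; set to the number of classes in your training split
# Image input: raw uint8 pixels, cast to float and preprocessed for EfficientNet
inputs = tf.keras.layers.Input([*IMAGE_SIZE, 3], dtype=tf.uint8)
x = tf.cast(inputs, dtype=tf.float32)
x = tf.keras.applications.efficientnet.preprocess_input(x)
# Pretrained backbone; global average pooling yields one embedding vector per image
net = tf.keras.applications.EfficientNetB0(include_top=False, weights='imagenet', pooling='avg')
embeds = net(x)
# Label input: one integer class id per image, one-hot encoded for the metric layer
labels = tf.keras.layers.Input([], dtype=tf.int32)
labels_onehot = tf.one_hot(labels, depth=num_classes)
# Create metric learning layer (uncomment an alternative to swap losses)
# metric_layer = ArcFace(num_classes=num_classes, margin=0.5, scale=64)
# metric_layer = AdaCos(num_classes=num_classes)
metric_layer = CircleLossCL(num_classes=num_classes, margin=0.25, scale=256)
# The layer takes [embeddings, one-hot labels] and returns logits
logits = metric_layer([embeds, labels_onehot])
# Both images and labels are model inputs
model = tf.keras.Model(inputs=[inputs, labels], outputs=logits)
model.summary()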
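The `CircleLossCL` layer above is configured with `margin=0.25` and `scale=256`. The sketch below shows the class-level Circle Loss as we understand it from the Circle Loss paper, with `scale` playing the role of gamma. The target similarity s_p gets weight alpha_p = max(0, 1 + m - s_p) and offset 1 - m. Each non-target similarity s_n gets weight alpha_n = max(0, s_n + m) and offset m. The resulting logits then feed a softmax cross-entropy. The function `circle_loss_class_level` is hypothetical. That the package uses exactly this parameterization is an assumption.

```python
import math


def circle_loss_class_level(cos_sims, label, margin=0.25, scale=256.0):
    """Hypothetical class-level Circle Loss for one sample.

    cos_sims: cosine similarities between the embedding and every class center.
    label:    index of the target class.
    Returns (logits, loss).
    """
    logits = []
    for k, s in enumerate(cos_sims):
        if k == label:
            # Positive: self-paced weight alpha_p, decision offset 1 - m
            alpha_p = max(0.0, 1.0 + margin - s)
            logits.append(scale * alpha_p * (s - (1.0 - margin)))
        else:
            # Negative: self-paced weight alpha_n, decision offset m
            alpha_n = max(0.0, s + margin)
            logits.append(scale * alpha_n * (s - margin))
    # Softmax cross-entropy via a numerically stable log-sum-exp
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return logits, log_sum_exp - logits[label]
```

With a single positive per sample, this cross-entropy equals log(1 + sum_n exp(z_n) * exp(-z_p)), which is the unified Circle Loss form. Well-separated samples (high s_p, low s_n) get near-zero weight and contribute almost nothing.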
Note that you must feed labels into the model as an input during training, because these layers need the labels for the forward pass.
In evaluation or prediction, the model above still requires both images and labels, but the labels are ignored by the metric learning layers. Because true labels are not available at prediction time, you can simply pass dummy labels (which are ignored) together with the target images.
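To illustrate why dummy labels are harmless, the sketch below uses a hypothetical pure-Python stand-in for an ArcFace-style layer. It assumes the layer branches on a training flag, as the note above describes. In training it applies the margin to the labeled class. Otherwise it returns plain scaled cosine logits without ever reading the labels. `margin_layer` and its arguments are illustrative only and are not this package's API. In the Keras model, the dummy labels could be, for instance, an all-zeros int32 array with one entry per image.

```python
import math


def _normalize(v):
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def margin_layer(embedding, class_weights, labels_onehot, training,
                 margin=0.5, scale=64.0):
    """Hypothetical ArcFace-style layer: labels are used only when training."""
    e = _normalize(embedding)
    cos = [max(-1.0, min(1.0, sum(a * b for a, b in zip(e, _normalize(w)))))
           for w in class_weights]
    if not training:
        # Inference: labels_onehot is never read, so any dummy value works
        return [scale * c for c in cos]
    # Training: add the angular margin to the target class only
    return [scale * (math.cos(math.acos(c) + margin) if y == 1 else c)
            for c, y in zip(cos, labels_onehot)]


W = [[1.0, 0.0], [0.0, 1.0]]
emb = [0.2, 0.9]
pred_a = margin_layer(emb, W, [0, 0], training=False)  # dummy labels
pred_b = margin_layer(emb, W, [1, 0], training=False)  # different dummy labels
```

Both predictions are identical, and the highest logit still points at the class whose weight is closest to the embedding.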
Download files
File details
Details for the file tf-simple-metric-learning-0.1.2.tar.gz.
File metadata
- Download URL: tf-simple-metric-learning-0.1.2.tar.gz
- Upload date:
- Size: 4.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.0.10 CPython/3.8.5 Linux/5.3.0-1034-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a02d2e43570fcb635d44ee76e2098fa21bccb2b7772cab4a62f75ee2de64dcd5 |
| MD5 | fa98e50c24b2f128e8ccdb6adf142b78 |
| BLAKE2b-256 | b03066805b184c40d2623f3961c64b608cea45e3d71cdea8f13346f05b552a89 |
File details
Details for the file tf_simple_metric_learning-0.1.2-py3-none-any.whl.
File metadata
- Download URL: tf_simple_metric_learning-0.1.2-py3-none-any.whl
- Upload date:
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.0.10 CPython/3.8.5 Linux/5.3.0-1034-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3e6517ffa74fc9a4e6cb39f7943fe4ad14552a5fd641ee8bc6b470a41b2d0d45 |
| MD5 | 5bcde4e9d95f969f5c655ea20d34997d |
| BLAKE2b-256 | 3495115a941ecc3584ae778187af51486c286eed07c2cbf11c7b435b57fe0153 |