"Online mining triplet losses for Pytorch"
online_triplet_loss
PyTorch conversion of the excellent post on the same topic in TensorFlow. It is simply an implementation of a triplet loss with online mining of candidate triplets, as used in semi-supervised learning.
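The loss for one (anchor, positive, negative) triplet is the hinge max(d(a, p) - d(a, n) + margin, 0) from the FaceNet paper. It is zero once the negative is farther from the anchor than the positive by at least the margin. Below is a minimal pure-Python sketch. It assumes Euclidean distance and plain lists as vectors, and `triplet_loss` is a hypothetical helper, not part of this package:

```python
import math


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge loss for one triplet: max(d(a, p) - d(a, n) + margin, 0).

    Hypothetical helper for illustration only; uses Euclidean distance.
    """
    d_pos = math.dist(anchor, positive)  # distance to a same-label sample
    d_neg = math.dist(anchor, negative)  # distance to a different-label sample
    return max(d_pos - d_neg + margin, 0.0)


# The negative is already farther than the positive by more than the margin: no loss.
print(triplet_loss([0.0, 0.0], [1.0, 0.0], [5.0, 0.0], margin=1.0))  # 0.0
# The negative is closer than the positive: the triplet contributes to the loss.
print(triplet_loss([0.0, 0.0], [3.0, 0.0], [0.0, 2.0], margin=1.0))  # 2.0
```

In online mining, such triplets are built from each batch on the fly rather than precomputed offline. Judging by their names, `batch_hard_triplet_loss` and `batch_all_triplet_loss` follow the "batch hard" and "batch all" strategies of the TensorFlow post.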
Install
pip install online_triplet_loss
Then import with:
from online_triplet_loss.losses import batch_hard_triplet_loss, batch_all_triplet_loss
How to use
In these examples I use a really large margin, since the embedding space is so small. A more realistic margin seems to be between 0.1 and 2.0.
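import torch
from torch import nn
from online_triplet_loss.losses import batch_hard_triplet_loss
model = nn.Embedding(10, 10)  # toy "model": a lookup table from 10 labels to 10-d vectors
labels = torch.randint(high=10, size=(5,))  # our five labels
embeddings = model(labels)  # samples that share a label get identical embeddings
print('Labels:', labels)
print('Embeddings:', embeddings)
loss = batch_hard_triplet_loss(labels, embeddings, margin=100)
print('Loss:', loss)
loss.backward()  # gradients flow back into the embedding table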
Labels: tensor([0, 7, 7, 5, 5])
Embeddings: tensor([[ 1.7146, -0.3138, 0.1500, -1.3602, 0.6112, 1.9415, -0.0872, -0.5365,
-0.6287, -1.2523],
[ 0.3933, -1.9714, 1.7608, -0.4584, 0.9668, -1.4512, -0.2314, 1.8080,
0.4513, -0.3509],
[ 0.3933, -1.9714, 1.7608, -0.4584, 0.9668, -1.4512, -0.2314, 1.8080,
0.4513, -0.3509],
[-1.3622, -1.2098, -0.4699, 1.3565, 1.4588, 0.7476, 0.1563, 2.0376,
0.7811, -0.0996],
[-1.3622, -1.2098, -0.4699, 1.3565, 1.4588, 0.7476, 0.1563, 2.0376,
0.7811, -0.0996]], grad_fn=<EmbeddingBackward>)
Loss: tensor(95.6246, grad_fn=<MeanBackward1>)
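The "batch hard" strategy from the TensorFlow post, which `batch_hard_triplet_loss` appears to follow, keeps only the farthest positive and the closest negative for each anchor. It then averages the resulting hinge losses over the batch. The plain-Python sketch below illustrates that mining step without tensors. `pairwise_distances` and `batch_hard_loss` are hypothetical helpers, not this package's code. How the sketch handles anchors that lack a positive or a negative is an assumption. Its `squared` flag switches between squared and plain Euclidean distances.

```python
import math


def pairwise_distances(embeddings, squared=False):
    """Matrix of Euclidean distances between all embeddings.

    With squared=True the squared distances are returned (no square root).
    """
    n = len(embeddings)
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
            dist[i][j] = d2 if squared else math.sqrt(d2)
    return dist


def batch_hard_loss(labels, embeddings, margin, squared=False):
    """Hypothetical batch-hard triplet loss, averaged over all anchors."""
    dist = pairwise_distances(embeddings, squared)
    losses = []
    for a, label_a in enumerate(labels):
        # Hardest positive: the farthest sample sharing the anchor's label.
        # Assumption: an anchor with no positive in the batch gets 0.0.
        hardest_pos = max(
            (dist[a][p] for p, label_p in enumerate(labels) if label_p == label_a and p != a),
            default=0.0,
        )
        # Hardest negative: the closest sample with a different label.
        # Assumption: with no negative, use infinity so the anchor adds zero loss.
        hardest_neg = min(
            (dist[a][n] for n, label_n in enumerate(labels) if label_n != label_a),
            default=math.inf,
        )
        losses.append(max(hardest_pos - hardest_neg + margin, 0.0))
    return sum(losses) / len(losses)


# Same label pattern as the example above; equal labels share one embedding,
# just like the nn.Embedding lookup does.
labels = [0, 7, 7, 5, 5]
table = {0: [0.0, 0.0], 7: [3.0, 4.0], 5: [6.0, 8.0]}
embeddings = [table[label] for label in labels]

print(batch_hard_loss(labels, embeddings, margin=100))  # 95.0
print(batch_hard_loss(labels, embeddings, margin=1.0))  # 0.0
print(batch_hard_loss(labels, embeddings, margin=100, squared=True))  # 75.0
```

Because same-label embeddings are identical here, every hardest-positive distance is 0. With `margin=100`, each anchor then contributes the margin minus its closest negative distance. This is consistent with the value just below 100 in the output above. With a margin of 1.0, this toy batch is already separated and the loss drops to 0.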
from online_triplet_loss.losses import batch_all_triplet_loss
# Recompute the embeddings so the new loss gets its own autograd graph
embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
loss, fraction_pos = batch_all_triplet_loss(labels, embeddings, squared=False, margin=100)
print(loss, fraction_pos)
print('Loss:', loss)
loss.backward()
Labels: tensor([0, 7, 7, 5, 5])
Embeddings: tensor([[ 1.7146, -0.3138, 0.1500, -1.3602, 0.6112, 1.9415, -0.0872, -0.5365,
-0.6287, -1.2523],
[ 0.3933, -1.9714, 1.7608, -0.4584, 0.9668, -1.4512, -0.2314, 1.8080,
0.4513, -0.3509],
[ 0.3933, -1.9714, 1.7608, -0.4584, 0.9668, -1.4512, -0.2314, 1.8080,
0.4513, -0.3509],
[-1.3622, -1.2098, -0.4699, 1.3565, 1.4588, 0.7476, 0.1563, 2.0376,
0.7811, -0.0996],
[-1.3622, -1.2098, -0.4699, 1.3565, 1.4588, 0.7476, 0.1563, 2.0376,
0.7811, -0.0996]], grad_fn=<EmbeddingBackward>)
tensor(95.4382, grad_fn=<DivBackward0>) tensor(1.)
Loss: tensor(95.4382, grad_fn=<DivBackward0>)
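The "batch all" strategy from the TensorFlow post instead uses every valid triplet in the batch, meaning triplets where the anchor and positive share a label and the negative does not. It averages the loss only over the triplets that are still positive. The post also returns the fraction of valid triplets with a positive loss, which is presumably what `fraction_pos` holds. Below is a minimal plain-Python sketch of that strategy, assuming Euclidean distances. `batch_all_loss` is a hypothetical helper, not this package's function:

```python
import math


def batch_all_loss(labels, embeddings, margin, squared=False):
    """Hypothetical batch-all triplet loss.

    Returns (mean loss over positive triplets, fraction of valid triplets
    with a positive loss).
    """
    def dist(i, j):
        d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
        return d2 if squared else math.sqrt(d2)

    losses = []
    n = len(labels)
    for a in range(n):
        for p in range(n):
            # Valid positive: same label as the anchor, but not the anchor itself.
            if p == a or labels[p] != labels[a]:
                continue
            for neg in range(n):
                # Valid negative: any sample whose label differs from the anchor's.
                if labels[neg] == labels[a]:
                    continue
                losses.append(max(dist(a, p) - dist(a, neg) + margin, 0.0))

    positive = [loss for loss in losses if loss > 0.0]
    if not positive:  # no valid triplets, or none that violate the margin
        return 0.0, 0.0
    # Average only over the triplets that still violate the margin.
    return sum(positive) / len(positive), len(positive) / len(losses)


labels = [0, 7, 7, 5, 5]
table = {0: [0.0, 0.0], 7: [3.0, 4.0], 5: [6.0, 8.0]}
embeddings = [table[label] for label in labels]

loss, fraction = batch_all_loss(labels, embeddings, margin=100)
print(round(loss, 4), fraction)  # 94.1667 1.0
loss, fraction = batch_all_loss(labels, embeddings, margin=6)
print(loss, round(fraction, 4))  # 1.0 0.8333
```

On this toy batch with `margin=100`, all 12 valid triplets are positive, so the fraction is 1.0. This is in line with the `tensor(1.)` printed above. Averaging only over positive triplets keeps the loss from being diluted by easy triplets that contribute zero, as happens for 2 of the 12 triplets with `margin=6`.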
References
- Triplet Loss and Online Triplet Mining in TensorFlow
- FaceNet paper
- adambielski's nice implementation (which, unfortunately, switches context between CPU and GPU)
Download files
Hashes for online_triplet_loss-0.0.1.tar.gz (source distribution)
Algorithm | Hash digest
---|---
SHA256 | 46cd6c998feba3d6c68adb4c155531eaa2a056d6d9ab56357998337fdaa13c03
MD5 | 6626e05833dfa8238b9d64fe73be033a
BLAKE2b-256 | 4a6104f6ad5a3edd7e34f8f799f871d53741d7714f67b705a8969e388be9f08b
Hashes for online_triplet_loss-0.0.1-py3-none-any.whl (built distribution)
Algorithm | Hash digest
---|---
SHA256 | 3822e1214adf97e5a811a611cc1cda631c60ef5ea7bd75d1ca7b7c4274251fb8
MD5 | e5e89761eb1962fdbd5d8e5d805e4194
BLAKE2b-256 | 694c72292ddea3e837035895309044e6f35ebce37140843f9f8e81f767dc3620