Online mining triplet losses for PyTorch
Project description
online_triplet_loss
A PyTorch conversion of the excellent post on the same topic in TensorFlow (see References). It is simply an implementation of a triplet loss with online mining of candidate triplets, as used in semi-supervised learning.
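Online mining means the triplets are built from each batch on the fly, not precomputed for the whole dataset. The sketch below illustrates the two ingredients in plain Python, assuming Euclidean distance. The first is a pairwise distance matrix with a `squared` switch, similar in spirit to the library's `squared` argument. The second is the enumeration of valid (anchor, positive, negative) index triplets from the labels. The helper names `pairwise_distances` and `valid_triplets` are hypothetical and are not part of `online_triplet_loss`.

```python
import math
from itertools import permutations


def pairwise_distances(embeddings, squared=False):
    """Return the n x n matrix of (squared) Euclidean distances between rows."""
    n = len(embeddings)
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            d2 = sum((a - b) ** 2 for a, b in zip(embeddings[i], embeddings[j]))
            dist[i][j] = d2 if squared else math.sqrt(d2)
    return dist


def valid_triplets(labels):
    """Yield (anchor, positive, negative) indices where anchor != positive,
    labels[anchor] == labels[positive] and labels[negative] != labels[anchor]."""
    n = len(labels)
    for a, p in permutations(range(n), 2):  # ordered pairs with a != p
        if labels[a] != labels[p]:
            continue
        for neg in range(n):
            if labels[neg] != labels[a]:
                yield a, p, neg


labels = [2, 0, 8, 8, 6]
print(list(valid_triplets(labels)))
```

With the labels `[2, 0, 8, 8, 6]` used in the examples, only the two items labelled 8 can form an anchor/positive pair, so the batch contains six valid triplets.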
Install
pip install online_triplet_loss
Then import with:
from online_triplet_loss.losses import *
How to use
In these examples I use a really large margin, since the embedding space is so small. A more realistic margin seems to be between 0.1 and 2.0.
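For a single triplet, the loss takes the standard hinge form used in FaceNet-style triplet losses: max(d(a, p) - d(a, n) + margin, 0). The tiny worked sketch below uses made-up distances, not values from the examples. It shows the effect of the margin. With a realistic margin, an easy triplet contributes nothing. With margin=100, every triplet is active and the loss is roughly the margin minus the distance gap. The function name `triplet_hinge` is hypothetical.

```python
def triplet_hinge(d_ap, d_an, margin):
    """Hinge triplet loss for one triplet: max(d_ap - d_an + margin, 0)."""
    return max(d_ap - d_an + margin, 0.0)


# Made-up distances: the negative is 1.5 units farther from the anchor than the positive.
d_ap, d_an = 1.0, 2.5
for margin in (0.1, 2.0, 100.0):
    print(margin, triplet_hinge(d_ap, d_an, margin))
```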
import torch
from torch import nn
from online_triplet_loss.losses import *

model = nn.Embedding(10, 10)  # toy "model": a lookup table of 10-dim vectors for 10 label ids
labels = torch.randint(high=10, size=(5,))  # our five labels
embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
loss = batch_hard_triplet_loss(labels, embeddings, margin=100)
print('Loss:', loss)
loss.backward()
Labels: tensor([2, 0, 8, 8, 6])
Embeddings: tensor([[-2.1302, -1.2041, -0.1863, 0.0326, -1.4517, -1.1717, 0.5536, 0.0911,
0.6340, 0.9799],
[ 0.3097, -0.9580, -1.3407, 0.9034, -1.6870, -2.3679, -0.1658, 1.6462,
0.7228, 1.1883],
[-0.4034, -1.0991, 0.2177, -0.4273, -0.7944, 1.0775, 0.9443, 1.3429,
0.1356, -0.9966],
[-0.4034, -1.0991, 0.2177, -0.4273, -0.7944, 1.0775, 0.9443, 1.3429,
0.1356, -0.9966],
[ 0.5503, -2.1766, 0.9868, 0.5284, 0.8618, 1.6336, -0.7546, -0.7549,
0.1552, -0.5286]], grad_fn=<EmbeddingBackward>)
Loss: tensor(96.3256, grad_fn=<MeanBackward1>)
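Batch-hard mining, as described in the TensorFlow post, picks each anchor's farthest positive and closest negative within the batch. The plain-Python sketch below illustrates that idea. It is not the library's code, and the function names are hypothetical. Following the post, it assumes that an anchor with no positive in the batch gets a hardest-positive distance of 0 and still counts in the mean. Skipping anchors that have no negative is an arbitrary choice made only for this sketch.

```python
import math


def euclidean(u, v, squared=False):
    """(Squared) Euclidean distance between two equal-length vectors."""
    d2 = sum((a - b) ** 2 for a, b in zip(u, v))
    return d2 if squared else math.sqrt(d2)


def batch_hard_loss_sketch(labels, embeddings, margin, squared=False):
    """Batch-hard mining: for each anchor, apply the hinge to its hardest
    positive and hardest negative, then average over anchors."""
    n = len(labels)
    dist = [[euclidean(embeddings[i], embeddings[j], squared) for j in range(n)]
            for i in range(n)]
    losses = []
    for a in range(n):
        # Distances to same-label items (excluding the anchor itself) and to other labels.
        pos = [dist[a][p] for p in range(n) if p != a and labels[p] == labels[a]]
        neg = [dist[a][k] for k in range(n) if labels[k] != labels[a]]
        if not neg:
            continue  # arbitrary choice for this sketch: anchors without negatives are skipped
        # Assumption (as in the TensorFlow post): no positive -> hardest-positive distance 0.
        hardest_pos = max(pos, default=0.0)
        hardest_neg = min(neg)
        losses.append(max(hardest_pos - hardest_neg + margin, 0.0))
    return sum(losses) / len(losses) if losses else 0.0


# Toy batch: two items of class 0 and one of class 1.
labels = [0, 0, 1]
embeddings = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]]
print(batch_hard_loss_sketch(labels, embeddings, margin=5.0))
```

Suppose the library follows the same convention. Then every term in the example above is the margin minus the distance to the anchor's nearest negative. The two items labelled 8 share one embedding row, so their positive distance is 0, and the other three anchors have no positive at all. That would explain why the loss sits a little below 100.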
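# reuses `model` and `labels` from the previous example
embeddings = model(labels)  # fresh forward pass, so this loss gets its own autograd graph
print('Labels:', labels)
print('Embeddings:', embeddings)
loss, fraction_pos = batch_all_triplet_loss(labels, embeddings, squared=False, margin=100)
print(loss, fraction_pos)  # also returns the fraction of valid triplets with a positive loss
print('Loss:', loss)
loss.backward()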
Labels: tensor([2, 0, 8, 8, 6])
Embeddings: tensor([[-2.1302, -1.2041, -0.1863, 0.0326, -1.4517, -1.1717, 0.5536, 0.0911,
0.6340, 0.9799],
[ 0.3097, -0.9580, -1.3407, 0.9034, -1.6870, -2.3679, -0.1658, 1.6462,
0.7228, 1.1883],
[-0.4034, -1.0991, 0.2177, -0.4273, -0.7944, 1.0775, 0.9443, 1.3429,
0.1356, -0.9966],
[-0.4034, -1.0991, 0.2177, -0.4273, -0.7944, 1.0775, 0.9443, 1.3429,
0.1356, -0.9966],
[ 0.5503, -2.1766, 0.9868, 0.5284, 0.8618, 1.6336, -0.7546, -0.7549,
0.1552, -0.5286]], grad_fn=<EmbeddingBackward>)
tensor(95.8399, grad_fn=<DivBackward0>) tensor(1.)
Loss: tensor(95.8399, grad_fn=<DivBackward0>)
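Batch-all mining instead applies the hinge to every valid triplet. It averages over the triplets whose loss is positive and also reports the fraction of valid triplets that are positive. That fraction appears to be what the second return value, `fraction_pos`, holds. The plain-Python sketch below follows the approach of the TensorFlow post. The function names and the small `eps` guard against division by zero are assumptions of this sketch, and the library's exact implementation may differ.

```python
import math
from itertools import permutations


def euclidean(u, v, squared=False):
    """(Squared) Euclidean distance between two equal-length vectors."""
    d2 = sum((a - b) ** 2 for a, b in zip(u, v))
    return d2 if squared else math.sqrt(d2)


def batch_all_loss_sketch(labels, embeddings, margin, squared=False, eps=1e-16):
    """Batch-all mining: hinge on every valid triplet, averaged over positive ones.

    Returns (loss, fraction of valid triplets whose loss is positive)."""
    n = len(labels)
    losses = []
    for a, p in permutations(range(n), 2):  # ordered pairs with a != p
        if labels[a] != labels[p]:
            continue
        d_ap = euclidean(embeddings[a], embeddings[p], squared)
        for k in range(n):
            if labels[k] != labels[a]:  # k is a valid negative for anchor a
                d_an = euclidean(embeddings[a], embeddings[k], squared)
                losses.append(max(d_ap - d_an + margin, 0.0))
    positive = [x for x in losses if x > eps]
    # eps avoids division by zero when there are no (positive) triplets.
    loss = sum(positive) / (len(positive) + eps)
    fraction_positive = len(positive) / (len(losses) + eps)
    return loss, fraction_positive


labels = [0, 0, 1]
embeddings = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]]
print(batch_all_loss_sketch(labels, embeddings, margin=2.1))
```

With margin=100, every valid triplet in the example has a positive loss. This is consistent with the `tensor(1.)` printed above.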
References
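- Triplet Loss and Online Triplet Mining in TensorFlow (Olivier Moindrot)
- FaceNet: A Unified Embedding for Face Recognition and Clustering (Schroff, Kalenichenko and Philbin, 2015)
- adambielski's nice implementation (which, unfortunately, switches context between CPU and GPU)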
Download files
Source Distribution
Built Distribution
Hashes for online_triplet_loss-0.0.2.tar.gz
Algorithm | Hash digest
---|---
SHA256 | d232cf33da603fc9d8f16a2d75746f5ab0f7fd8ceb5c9a402f1a02aa6a52d6c1
MD5 | 3bf1505d1bc22773e267b12f99da6f0e
BLAKE2b-256 | 10f85f14caec53fce6b2cea047f405c6ac591fc23823159f61cd40a50d903009
Hashes for online_triplet_loss-0.0.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 93f294c7161fb9744f08a12612f5f8d3213ff3af7902b2702884d3ed226df2aa
MD5 | b829a7f1f279ce59635074f914a9689f
BLAKE2b-256 | 2b0147b6108e9cb61e7934f5dea1542d42d9dd556d0c65f24ec173e6bc4cb133
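A short standard-library sketch can check a downloaded file against the SHA256 digests listed above. It assumes the files sit in the current directory under their published names. The names `EXPECTED_SHA256`, `sha256_of_file` and `matches_published_digest` are hypothetical. pip can also enforce this itself in hash-checking mode, for example with a requirements line such as `online_triplet_loss==0.0.2 --hash=sha256:<digest>`.

```python
import hashlib
import os

# Digests copied from the SHA256 rows of the hash tables for version 0.0.2.
EXPECTED_SHA256 = {
    "online_triplet_loss-0.0.2.tar.gz":
        "d232cf33da603fc9d8f16a2d75746f5ab0f7fd8ceb5c9a402f1a02aa6a52d6c1",
    "online_triplet_loss-0.0.2-py3-none-any.whl":
        "93f294c7161fb9744f08a12612f5f8d3213ff3af7902b2702884d3ed226df2aa",
}


def sha256_of_file(path, chunk_size=65536):
    """Hex SHA256 digest of a file, read in chunks to keep memory use flat."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path):
    """True if the file's SHA256 equals the published digest for its filename."""
    expected = EXPECTED_SHA256.get(os.path.basename(path))
    return expected is not None and sha256_of_file(path) == expected


if __name__ == "__main__":
    for name in EXPECTED_SHA256:
        if os.path.exists(name):  # assumes downloads are in the current directory
            print(name, "OK" if matches_published_digest(name) else "MISMATCH")
```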