"Online mining triplet losses for Pytorch"
online_triplet_loss
A PyTorch port of the excellent blog post on the same topic for TensorFlow (see References). It is simply an implementation of triplet loss with online mining of candidate triplets, as used in semi-supervised learning.
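For reference, the loss for a single triplet (an anchor a, a positive p with the same label, and a negative n with a different label) is max(d(a, p) - d(a, n) + margin, 0). It is zero once the negative sits at least `margin` farther from the anchor than the positive. The sketch below computes it in plain PyTorch for explicit, pre-chosen triplets. `triplet_loss_explicit` is a hypothetical helper, not part of this package, and it assumes plain Euclidean distance (the FaceNet paper uses squared distances). Online mining changes only where the triplets come from: they are selected inside each batch instead of being fixed in advance.

```python
import torch


def triplet_loss_explicit(anchor, positive, negative, margin=1.0):
    """Mean triplet loss over explicit triplets (hypothetical helper).

    anchor, positive, negative: tensors of shape (batch, dim).
    Assumes plain (not squared) Euclidean distance.
    """
    d_ap = (anchor - positive).norm(dim=1)  # anchor-positive distances
    d_an = (anchor - negative).norm(dim=1)  # anchor-negative distances
    # Hinge: a triplet contributes only while the negative is not `margin` farther away
    return torch.clamp(d_ap - d_an + margin, min=0.0).mean()


a = torch.tensor([[0.0, 0.0]])
p = torch.tensor([[0.0, 1.0]])  # d(a, p) = 1
n = torch.tensor([[3.0, 0.0]])  # d(a, n) = 3
print(triplet_loss_explicit(a, p, n, margin=1.0))    # 1 - 3 + 1 < 0, so the hinge gives 0
print(triplet_loss_explicit(a, p, n, margin=100.0))  # 1 - 3 + 100 = 98
```

The second call illustrates the margin advice under How to use: when the margin dwarfs the distances, every triplet stays active and the loss ends up close to the margin itself.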
Install
pip install online_triplet_loss
Then import with:
from online_triplet_loss.losses import *
Note: requires PyTorch 1.1.0 or later.
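If you want to fail early on an older installation, a small check against `torch.__version__` could look like this. `meets_min_version` is a hypothetical helper, not part of this package; it compares only the leading numeric release part, so local tags such as `+cu117` and pre-release markers are ignored.

```python
import re


def meets_min_version(version, minimum=(1, 1, 0)):
    """Return True if a version string such as '1.13.1+cu117' is at least `minimum`."""
    # Keep only the leading MAJOR.MINOR[.PATCH] part, e.g. '1.13.1' from '1.13.1+cu117'
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    parts = tuple(int(x) if x is not None else 0 for x in match.groups())
    return parts >= minimum


# In practice, pass torch.__version__ here.
print(meets_min_version("1.0.1"))         # False
print(meets_min_version("1.1.0"))         # True
print(meets_min_version("1.13.1+cu117"))  # True
```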
How to use
In these examples I use a very large margin because the embedding space is so small. A more realistic margin seems to be between 0.1 and 2.0.
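import torch
from torch import nn
from online_triplet_loss.losses import *

model = nn.Embedding(10, 10)  # toy "model": maps each of 10 labels to a 10-dimensional embedding
labels = torch.randint(high=10, size=(5,))  # a batch of five random labels
embeddings = model(labels)  # shape (5, 10); identical labels get identical embeddings
print('Labels:', labels)
print('Embeddings:', embeddings)
# Batch-hard mining: for each anchor, use its hardest positive and hardest negative in the batch
loss = batch_hard_triplet_loss(labels, embeddings, margin=100)
print('Loss:', loss)
loss.backward()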
Labels: tensor([6, 1, 3, 6, 6])
Embeddings: tensor([[-1.1335, 0.3364, -3.0174, -0.8732, -0.9301, 1.3619, 0.3746, 0.0457,
0.0180, -0.4500],
[ 1.0757, -0.8420, -0.7630, -0.0746, 1.1545, 0.4017, 0.5587, 1.7947,
0.1992, -2.2288],
[ 0.2646, 1.2383, 0.1949, 0.5743, -0.8460, -0.9929, -2.0350, 0.2095,
0.2129, -0.4855],
[-1.1335, 0.3364, -3.0174, -0.8732, -0.9301, 1.3619, 0.3746, 0.0457,
0.0180, -0.4500],
[-1.1335, 0.3364, -3.0174, -0.8732, -0.9301, 1.3619, 0.3746, 0.0457,
0.0180, -0.4500]], grad_fn=<EmbeddingBackward>)
Loss: tensor(95.1271, grad_fn=<MeanBackward0>)
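The batch-hard strategy, as described in the TensorFlow post listed under References, takes for every anchor the farthest embedding with the same label (hardest positive) and the closest embedding with a different label (hardest negative). Below is a minimal from-scratch sketch of that idea in plain PyTorch. `pairwise_distances` and `batch_hard_sketch` are hypothetical names; this is not this package's source, and its numbers are not guaranteed to match `batch_hard_triplet_loss` exactly. Anchors with no positive in the batch get a hardest-positive distance of 0 here, which is one possible convention.

```python
import torch


def pairwise_distances(embeddings, squared=False):
    """Return the (B, B) matrix of (squared) Euclidean distances between rows."""
    dot = embeddings @ embeddings.t()
    sq_norms = dot.diagonal()
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2; clamp small negatives caused by rounding
    dist = (sq_norms.unsqueeze(0) - 2.0 * dot + sq_norms.unsqueeze(1)).clamp(min=0.0)
    if not squared:
        # Nudge exact zeros before sqrt so the gradient stays finite, then zero them again
        zero_mask = (dist == 0.0).float()
        dist = torch.sqrt(dist + zero_mask * 1e-16) * (1.0 - zero_mask)
    return dist


def batch_hard_sketch(labels, embeddings, margin, squared=False):
    """Hypothetical batch-hard loss: mean over anchors of max(hardest_pos - hardest_neg + margin, 0)."""
    dist = pairwise_distances(embeddings, squared=squared)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B) same-label mask
    eye = torch.eye(labels.shape[0], device=labels.device).bool()
    pos_mask = (same & ~eye).float()  # positives: same label, excluding the anchor itself
    # Hardest positive = largest same-label distance (0 for anchors with no positive)
    hardest_pos = (dist * pos_mask).max(dim=1).values
    # Hardest negative = smallest different-label distance; same-label entries
    # (including the anchor) are pushed out of the running by adding the row maximum
    row_max = dist.max(dim=1, keepdim=True).values
    hardest_neg = (dist + row_max * same.float()).min(dim=1).values
    return torch.clamp(hardest_pos - hardest_neg + margin, min=0.0).mean()


torch.manual_seed(0)
toy_labels = torch.tensor([6, 1, 3, 6, 6])  # same labels as in the output above
toy_embeddings = torch.randn(5, 10, requires_grad=True)
toy_loss = batch_hard_sketch(toy_labels, toy_embeddings, margin=100.0)
toy_loss.backward()
print('Loss:', toy_loss)
```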
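# Batch-all mining, reusing model and labels from the example above
embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
# squared=False uses plain (not squared) Euclidean distances; fraction_pos is the
# fraction of valid triplets whose loss is positive
loss, fraction_pos = batch_all_triplet_loss(labels, embeddings, squared=False, margin=100)
print('Loss:', loss)
loss.backward()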
Labels: tensor([6, 1, 3, 6, 6])
Embeddings: tensor([[-1.1335, 0.3364, -3.0174, -0.8732, -0.9301, 1.3619, 0.3746, 0.0457,
0.0180, -0.4500],
[ 1.0757, -0.8420, -0.7630, -0.0746, 1.1545, 0.4017, 0.5587, 1.7947,
0.1992, -2.2288],
[ 0.2646, 1.2383, 0.1949, 0.5743, -0.8460, -0.9929, -2.0350, 0.2095,
0.2129, -0.4855],
[-1.1335, 0.3364, -3.0174, -0.8732, -0.9301, 1.3619, 0.3746, 0.0457,
0.0180, -0.4500],
[-1.1335, 0.3364, -3.0174, -0.8732, -0.9301, 1.3619, 0.3746, 0.0457,
0.0180, -0.4500]], grad_fn=<EmbeddingBackward>)
tensor(94.9947, grad_fn=<DivBackward0>) tensor(1.)
Loss: tensor(94.9947, grad_fn=<DivBackward0>)
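The batch-all strategy instead scores every valid triplet (i, j, k) in the batch, meaning i != j, labels[i] == labels[j] and labels[k] != labels[i]. In the TensorFlow post this package is based on, the loss is averaged only over triplets with a positive loss, and the second return value is the fraction of valid triplets that are positive. A minimal sketch of that computation, assuming plain Euclidean distances via `torch.cdist`; `batch_all_sketch` is a hypothetical name and this is not the package's own code.

```python
import torch


def batch_all_sketch(labels, embeddings, margin, squared=False):
    """Hypothetical batch-all triplet loss.

    Returns (mean loss over valid triplets with positive loss,
             fraction of valid triplets whose loss is positive).
    """
    dist = torch.cdist(embeddings, embeddings)  # (B, B) Euclidean distances
    if squared:
        dist = dist ** 2
    # loss[i, j, k] = d(i, j) - d(i, k) + margin for anchor i, positive j, negative k
    loss = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B) same-label mask
    not_eye = ~torch.eye(labels.shape[0], device=labels.device).bool()
    # Valid triplet: i != j, labels[i] == labels[j], labels[i] != labels[k]
    valid = (same & not_eye).unsqueeze(2) & (~same).unsqueeze(1)
    loss = torch.clamp(loss * valid.float(), min=0.0)  # invalid triplets contribute 0
    num_positive = (loss > 1e-16).sum().float()
    num_valid = valid.sum().float()
    fraction_positive = num_positive / (num_valid + 1e-16)
    # Average only over the triplets that still violate the margin
    return loss.sum() / (num_positive + 1e-16), fraction_positive


torch.manual_seed(0)
toy_labels = torch.tensor([6, 1, 3, 6, 6])
toy_embeddings = torch.randn(5, 10)
toy_loss, toy_fraction = batch_all_sketch(toy_labels, toy_embeddings, margin=100.0)
# toy_fraction is 1.0 here: a margin of 100 dwarfs these distances
print('Loss:', toy_loss, 'Fraction positive:', toy_fraction)
```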
References
- Triplet Loss and Online Triplet Mining in TensorFlow
- FaceNet paper
- adambielski's nice implementation (which unfortunately switches context between CPU and GPU)
Download files
File details
Details for the file online_triplet_loss-0.0.6.tar.gz.
File metadata
- Download URL: online_triplet_loss-0.0.6.tar.gz
- Upload date:
- Size: 5.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.1.0.post20200710 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2849bea4470ef07302fd52d1ddb0bbb488a5e999d7508cece18212cbc7543357
MD5 | 1975239c7e63200938e3aa5d5cd4a085
BLAKE2b-256 | 9ae8ef6a783743a63286ea1d3ac08f250b3d3a9e2ef25b0fee31d94af24997e2
File details
Details for the file online_triplet_loss-0.0.6-py3-none-any.whl.
File metadata
- Download URL: online_triplet_loss-0.0.6-py3-none-any.whl
- Upload date:
- Size: 6.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.1.0.post20200710 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | d0226188833930ddfc0c27fe4f463fd0da91e2ab37b93fa334ae3bb5b4180e5f
MD5 | 4c6ea5a42c0e03ac70d530cf49699d41
BLAKE2b-256 | b175501a4cd77f518d0988e641831fa8a225a48513b6217a731aac202bc56ab4