Triplet Loss Utils for the PyTorch Library.
TripletTorch
TripletTorch is a small PyTorch utility for triplet loss projects. It provides a simple way to create custom triplet datasets and implements common triplet mining loss techniques.
Install
Install the module using the pip utility (this may require running it with sudo).
pip3 install triplettorch
Usage
Triplet Dataset
from triplettorch import TripletDataset
# Create a triplet dataset given:
# * labels  : array of labels ( classes ), one per sample of the dataset
# * data_fn : function used to access the data for a given index in the dataset
# * size    : number of samples in the dataset
# * n_sample: number of samples per draw ( increases the probability that
#             a batch contains valid triplets )
# When used with a DataLoader, the dataset yields ( batch_size, n_sample, ... )
# tensors for labels and data, so do not forget to merge the batch and
# sample dimensions before computing the loss
dataset = TripletDataset( labels, data_fn, size, n_sample )
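To make the batch/sample dimension merge concrete, here is a minimal sketch that does not depend on TripletTorch itself. `ToyTripletDataset` and `merge_batch_and_sample_dims` are hypothetical names for illustration only: the toy dataset mimics the `( labels, data_fn, size, n_sample )` arguments above but simply draws `n_sample` random indices per item, which is an assumption and not how TripletDataset necessarily samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTripletDataset(Dataset):
    """Hypothetical stand-in for TripletDataset: each draw returns n_sample samples."""

    def __init__(self, labels, data_fn, size, n_sample):
        self.labels = labels      # (N,) tensor with one class id per sample
        self.data_fn = data_fn    # callable: index -> data tensor for that sample
        self.size = size          # number of samples in the dataset
        self.n_sample = n_sample  # samples returned per draw

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Simplification (assumption): pick n_sample random indices, no class-aware sampling.
        picks = torch.randint(self.size, (self.n_sample,))
        data = torch.stack([self.data_fn(int(i)) for i in picks])
        return self.labels[picks], data


def merge_batch_and_sample_dims(labels, data):
    """Reshape (batch_size, n_sample, ...) tensors to (batch_size * n_sample, ...)."""
    return labels.reshape(-1), data.reshape(-1, *data.shape[2:])


features = torch.randn(100, 32)          # toy data: 100 samples, 32 features each
labels = torch.randint(0, 10, (100,))    # toy labels: 10 classes
dataset = ToyTripletDataset(labels, lambda i: features[i], size=100, n_sample=8)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_labels, batch_data = next(iter(loader))
# The default collate stacks items: batch_labels is (4, 8), batch_data is (4, 8, 32)
flat_labels, flat_data = merge_batch_and_sample_dims(batch_labels, batch_data)
# flat_labels is (32,), flat_data is (32, 32): one row per sample, ready for a miner
```

Using `reshape` rather than `view` keeps the helper working even if a tensor in the batch is not contiguous.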
Triplet Mining
from triplettorch import AllTripletMiner, HardNegativeTripletMiner
# Define the triplet mining loss given:
# * margin: the margin ( float ) from the triplet loss definition
# Either mine all valid triplets of the batch...
miner = AllTripletMiner( .5 ).cuda( )
# ...or mine hard negatives only
miner = HardNegativeTripletMiner( .5 ).cuda( )
# Use the loss in training given:
# * labels    : array of labels ( classes ), one per sample of the batch
# * embeddings: output of the neural network for each sample of the batch
# Returns two values:
# * loss    : triplet loss value
# * frac_pos: fraction of positive triplets ( None for
#             HardNegativeTripletMiner )
loss, frac_pos = miner( labels, embeddings )
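The two mining strategies can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not TripletTorch's implementation: `BatchAllTripletLoss` and `HardNegativeTripletLoss` are hypothetical names, distances are assumed to be squared Euclidean, `frac_pos` is assumed to mean the fraction of valid triplets with a loss above zero, and "hard negative" is interpreted as the closest negative for each anchor. Each valid triplet contributes max(d(a, p) - d(a, n) + margin, 0), and, as in the usage above, the hard-negative variant returns `None` as its second value.

```python
import torch
import torch.nn as nn


def _masks(labels):
    """Boolean (B, B) masks of valid anchor-positive and anchor-negative pairs."""
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    return same & ~eye, ~same


def _sq_dists(embeddings):
    """Squared Euclidean distance matrix, shape (B, B) (assumed metric)."""
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
    return diff.pow(2).sum(-1)


class BatchAllTripletLoss(nn.Module):
    """Hypothetical sketch: average loss over all valid triplets with positive loss."""

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, labels, embeddings):
        dist = _sq_dists(embeddings)
        pos, neg = _masks(labels)
        # loss[a, p, n] = d(a, p) - d(a, n) + margin
        loss = dist.unsqueeze(2) - dist.unsqueeze(1) + self.margin
        valid = pos.unsqueeze(2) & neg.unsqueeze(1)
        loss = loss.clamp(min=0) * valid
        n_positive = (loss > 0).sum()
        frac_pos = n_positive.float() / valid.sum().clamp(min=1).float()
        return loss.sum() / n_positive.clamp(min=1), frac_pos


class HardNegativeTripletLoss(nn.Module):
    """Hypothetical sketch: each anchor-positive pair uses the anchor's closest negative."""

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, labels, embeddings):
        dist = _sq_dists(embeddings)
        pos, neg = _masks(labels)
        # Non-negatives are set to +inf so min() picks each anchor's closest negative.
        hardest_neg = dist.masked_fill(~neg, float("inf")).min(dim=1, keepdim=True).values
        loss = (dist - hardest_neg + self.margin).clamp(min=0)[pos]
        if loss.numel() == 0:
            return dist.sum() * 0, None  # no anchor-positive pair in the batch
        return loss.mean(), None


labels = torch.tensor([0, 0, 1, 1])
embeddings = torch.tensor([[0.0], [1.0], [3.0], [0.5]], requires_grad=True)

loss, frac_pos = BatchAllTripletLoss(0.5)(labels, embeddings)
loss.backward()  # the loss is differentiable w.r.t. the embeddings
print(loss.item(), frac_pos.item())  # ~3.65 and 0.625: 5 of the 8 valid triplets are active

hard_loss, none = HardNegativeTripletLoss(0.5)(labels, embeddings)
print(hard_loss.item(), none)  # 2.9375 None
```

Batch-all averages only over triplets that still violate the margin, so easy triplets do not dilute the gradient; the hard-negative variant looks at fewer, more informative triplets per anchor.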
Example
The repository provides an example application with the MNIST dataset.
Download files
Source distribution: TripletTorch-0.1.2.tar.gz (5.5 kB, tag: Source)
Algorithm | Hash digest
---|---
SHA256 | f8a73c3eae348e3a56c2fafe3d7b20d962392b77604ee6ba3a9691eff87d7fd6
MD5 | 61357944c38f8321cc52a92fce184941
BLAKE2b-256 | 094253d62e2c287baf452e8d3f05814dfaaf745e76e20e13582cde77ccc08bd5
Built distribution: TripletTorch-0.1.2-py3-none-any.whl (6.1 kB, tag: Python 3)
Algorithm | Hash digest
---|---
SHA256 | 1825094af01e067d7afc2eed46fa58e3ceb51ff9f32febe31486152be9d8e991
MD5 | 5424f3076bc867253b923c65e1bb110b
BLAKE2b-256 | 07d31c9f8221e173ed491c604472876e533beccb62ca2e20fd841e8c7d78b203