Triplet loss utilities for the PyTorch library.

TripletTorch

TripletTorch is a small PyTorch utility for triplet loss projects. It provides a simple way to create custom triplet datasets and implements common triplet mining techniques for the loss.

Install

Install the package with pip (this may require running as sudo):

pip3 install triplettorch

Usage

Triplet Dataset

from triplettorch import TripletDataset

# Create a triplet dataset given:
#   * labels  : array of labels ( classes ), one per sample of the dataset
#   * data_fn : function that returns the data for a given dataset index
#   * size    : number of samples in the dataset
#   * n_sample: number of samples per draw ( increases the probability that
#               a batch contains valid triplets )
# Do not forget to merge the batch and sample dimensions when used with a
# DataLoader, as batches come out with shape ( batch_size, n_sample, ... )
# for both labels and data
dataset = TripletDataset( labels, data_fn, size, n_sample )
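The dimension merge mentioned in the dataset comments can be illustrated with a small self-contained sketch. `ToyTripletDataset` is a hypothetical stand-in that only mimics the documented output shape; it is not TripletTorch's implementation. The random-draw strategy, the fake MNIST-sized tensors and the use of `size` as the dataset length are all assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTripletDataset(Dataset):
    """Hypothetical stand-in for TripletDataset: each item holds n_sample
    randomly drawn (label, data) pairs. Not the library's implementation."""

    def __init__(self, labels, data_fn, size, n_sample):
        self.labels = torch.as_tensor(labels)
        self.data_fn = data_fn
        self.size = size
        self.n_sample = n_sample

    def __len__(self):
        # Assumption: size is used as the number of items per epoch.
        return self.size

    def __getitem__(self, idx):
        # Assumption: idx is ignored and n_sample indices are drawn at random.
        picks = torch.randint(0, len(self.labels), (self.n_sample,))
        labels = self.labels[picks]                                # (n_sample,)
        data = torch.stack([self.data_fn(int(i)) for i in picks])  # (n_sample, 1, 28, 28)
        return labels, data


# Fake data: 100 single-channel 28x28 "images" spread over 10 classes.
images = torch.randn(100, 1, 28, 28)
targets = torch.randint(0, 10, (100,))
dataset = ToyTripletDataset(targets, lambda i: images[i], size=100, n_sample=4)
loader = DataLoader(dataset, batch_size=8)

# The default collate function stacks items: (8, 4) and (8, 4, 1, 28, 28).
batch_labels, batch_data = next(iter(loader))

# Merge the batch and sample dimensions before the forward pass and the miner.
flat_labels = batch_labels.flatten(0, 1)  # (32,)
flat_data = batch_data.flatten(0, 1)      # (32, 1, 28, 28)
```

`Tensor.flatten(0, 1)` is equivalent to reshaping with `view( batch_size * n_sample, ... )`, which gives the network and the miner one flat batch of labelled samples.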

Triplet Mining

from triplettorch import AllTripletMiner, HardNegativeTripletMiner

# Define the triplet mining loss given:
#   * margin: the margin float value from the triplet loss definition
# Choose one of the two miners:
miner          = AllTripletMiner( .5 ).cuda( )
# or
miner          = HardNegativeTripletMiner( .5 ).cuda( )

# Use the loss in training given:
#   * labels    : labels ( classes ), one per sample of the batch
#   * embeddings: output of the neural network for each sample of the batch
# Returns two values:
#   * loss    : triplet loss value
#   * frac_pos: fraction of positive triplets ( AllTripletMiner only;
#               None for HardNegativeTripletMiner )
loss, frac_pos = miner( labels, embeddings )
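For intuition about what the two miners compute, here is a minimal sketch of two common batch-level strategies for the triplet loss max(d(a, p) - d(a, n) + margin, 0): batch-all, which averages over the valid triplets whose loss is positive and reports the fraction of such triplets, and batch-hard, which keeps only the hardest positive and hardest negative for each anchor. It assumes Euclidean distances and that `AllTripletMiner` and `HardNegativeTripletMiner` follow these strategies; the function names are hypothetical, and the library's actual code may differ in details such as the distance or the averaging.

```python
import torch


def pairwise_distances(embeddings: torch.Tensor) -> torch.Tensor:
    """Euclidean distance matrix (B, B); the epsilon keeps sqrt differentiable at 0."""
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)  # (B, B, D)
    return (diff.pow(2).sum(dim=-1) + 1e-12).sqrt()


def pair_masks(labels: torch.Tensor):
    """Boolean (B, B) masks for anchor-positive and anchor-negative pairs."""
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(len(labels), device=labels.device).bool()
    return same & not_self, ~same


def batch_all_triplet_loss(labels, embeddings, margin):
    """Hypothetical batch-all loss: mean over positive-loss triplets, plus their fraction."""
    dist = pairwise_distances(embeddings)
    pos, neg = pair_masks(labels)
    # triplet[a, p, n] = d(a, p) - d(a, n) + margin
    triplet = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    valid = (pos.unsqueeze(2) & neg.unsqueeze(1)).float()  # label-valid (a, p, n)
    triplet = triplet.clamp(min=0) * valid
    num_positive = (triplet > 1e-16).sum().float()
    frac_pos = num_positive / (valid.sum() + 1e-16)
    loss = triplet.sum() / (num_positive + 1e-16)
    return loss, frac_pos


def batch_hard_triplet_loss(labels, embeddings, margin):
    """Hypothetical batch-hard loss: one (hardest) triplet per anchor, no fraction."""
    dist = pairwise_distances(embeddings)
    pos, neg = pair_masks(labels)
    # Largest distance to a same-label sample (non-positive pairs are zeroed).
    hardest_pos = (dist * pos.float()).max(dim=1).values
    # Smallest distance to a different-label sample: non-negative pairs are
    # shifted up by the row maximum so they cannot win the min.
    row_max = dist.max(dim=1, keepdim=True).values
    hardest_neg = (dist + row_max * (~neg).float()).min(dim=1).values
    loss = (hardest_pos - hardest_neg + margin).clamp(min=0).mean()
    return loss, None


# Toy batch: two classes placed on a 1-D line.
labels = torch.tensor([0, 0, 1, 1])
embeddings = torch.tensor([[0.0], [1.0], [3.0], [4.0]], requires_grad=True)

loss, frac_pos = batch_all_triplet_loss(labels, embeddings, margin=2.5)
# loss = 5/6 (about 0.8333), frac_pos = 0.75 (6 of the 8 valid triplets)
loss.backward()

hard_loss, _ = batch_hard_triplet_loss(labels, embeddings, margin=2.5)
# hard_loss = 1.0
```

The batch-all sketch divides by the number of positive-loss triplets rather than by all valid triplets, so triplets that already satisfy the margin do not dilute the average; the batch-hard sketch uses exactly one triplet per anchor and therefore has no fraction to report, matching the `None` returned by `HardNegativeTripletMiner`.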

Example

The repository provides an example application with the MNIST dataset.

Download files

Files for TripletTorch, version 0.1.2:

Filename                              Size    File type  Python version
TripletTorch-0.1.2-py3-none-any.whl   6.1 kB  Wheel      py3
TripletTorch-0.1.2.tar.gz             5.5 kB  Source     None
