Skip to main content

Triplet Loss Utils for Pytorch Library.

Project description

Triplet Loss Utility for Pytorch Library.

TripletTorch

TripletTorch is a small pytorch utility for triplet loss projects. It provides simple way to create custom triplet datasets and common triplet mining loss techniques.

Install

Install the module using the pip utility ( may require to run as sudo ).

pip3 install triplettorch

Usage

Triplet Dataset

from triplettorch import TripletDataset

# Create a triplet dataset given:
#   * labels  : array of label ( class ) for each sample of the dataset
#   * data_fn : method to access data for a given index in the dataset
#   * size    : number of samples in the dataset
#   * n_sample: number of sample per draw ( to increase probability to
#               contain valid triplets in a batch )
# Do not forget to concatenate batch dimension and sample dimension
# when used with a DataLoader as TripletDataset[ idx ] returns a
# ( batch_size, n_sample, ... ) dimension tensor for labels and data
dataset = TripletDataset( labels, data_fn, size, n_sample )

Triplet Mining

from triplettorch import AllTripletMiner, HardNegativeTripletMiner

# Define the triplet mining loss given:
#   * margin: the margin float value from the triplet loss definition
miner          = AllTripletMiner( .5 ).cuda( )
miner          = HardNegativeTripletMiner( .5 ).cuda( )

# Use the loss in training given:
#   * labels    : array of label ( class ) for each sample of the batch
#   * embeddings: output of the neural network for each sample of the batch
# Returns two values:
#   * loss    : triplet loss value
#   * frac_pos: fraction of positive triplets
#               None ( None HardNegativeTripletMiner )
loss, frac_pos = miner( labels, embeddings )

Example

The repository provides an example application with the MNIST dataset.

 MNIST

References

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

TripletTorch-0.1.3.tar.gz (5.5 kB view hashes)

Uploaded Source

Built Distribution

TripletTorch-0.1.3-py3-none-any.whl (6.1 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page