
Triplet Loss Utils for PyTorch Library.


TripletTorch

TripletTorch is a small PyTorch utility for triplet loss projects. It provides a simple way to create custom triplet datasets and implements common triplet mining strategies for the triplet loss.
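For reference, the triplet loss that these mining strategies build on is usually written as follows, for an anchor $a$, a positive $p$ (same class as $a$), a negative $n$ (different class) and a distance $d$ between embeddings, such as the Euclidean distance. The margin is the value passed to the miners below.

```latex
\mathcal{L}(a, p, n) = \max\bigl( d(a, p) - d(a, n) + \text{margin},\; 0 \bigr)
```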

Install

Install the module with pip (this may require running as sudo).

pip3 install triplettorch

Usage

Triplet Dataset

from triplettorch import TripletDataset

# Create a triplet dataset given:
#   * labels  : array of labels ( classes ), one per sample of the dataset
#   * data_fn : function returning the data for a given index in the dataset
#   * size    : number of samples in the dataset
#   * n_sample: number of samples drawn per item ( increases the probability
#               that a batch contains valid triplets )
# Each TripletDataset[ idx ] holds n_sample samples, so a DataLoader yields
# ( batch_size, n_sample, ... ) tensors for labels and data. Do not forget
# to merge the batch and sample dimensions before computing the loss.
dataset = TripletDataset( labels, data_fn, size, n_sample )
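The shapes involved in the batching note can be illustrated without the library. The sketch below is a NumPy stand-in that assumes a hypothetical in-memory dataset (`images`, `labels`), a hypothetical `data_fn` and a hypothetical helper `merge_batch_and_sample_dims`; it fakes the loader output directly and does not reproduce TripletTorch's internals.

```python
import numpy as np

# Hypothetical in-memory dataset: 100 random 28x28 "images" with 10 classes.
rng = np.random.default_rng(0)
images = rng.random((100, 28, 28), dtype=np.float32)
labels = rng.integers(0, 10, size=100)
size = len(labels)

def data_fn(index):
    """Hypothetical accessor: return the sample stored at `index`."""
    return images[index]

def merge_batch_and_sample_dims(batch_labels, batch_data):
    """Reshape (batch_size, n_sample, ...) arrays into (batch_size * n_sample, ...)."""
    flat_labels = batch_labels.reshape(-1, *batch_labels.shape[2:])
    flat_data = batch_data.reshape(-1, *batch_data.shape[2:])
    return flat_labels, flat_data

# Fake loader output: batch_size = 4 items, each holding n_sample = 8 samples.
batch_labels = labels[:32].reshape(4, 8)
batch_data = np.stack([data_fn(i) for i in range(32)]).reshape(4, 8, 28, 28)
flat_labels, flat_data = merge_batch_and_sample_dims(batch_labels, batch_data)
print(flat_labels.shape, flat_data.shape)  # (32,) (32, 28, 28)
```

With PyTorch tensors coming out of a DataLoader, the same merge can be written as `tensor.flatten( 0, 1 )`.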

Triplet Mining

from triplettorch import AllTripletMiner, HardNegativeTripletMiner

# Define the triplet mining loss given:
#   * margin: the margin ( float ) from the triplet loss definition
# Pick one of the two miners:
miner          = AllTripletMiner( .5 ).cuda( )
# miner        = HardNegativeTripletMiner( .5 ).cuda( )

# Use the loss in training given:
#   * labels    : array of labels ( classes ), one per sample of the batch
#   * embeddings: output of the neural network for each sample of the batch
# Returns two values:
#   * loss    : triplet loss value
#   * frac_pos: fraction of positive ( non-zero loss ) triplets for
#               AllTripletMiner, None for HardNegativeTripletMiner
loss, frac_pos = miner( labels, embeddings )
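To make the two mining modes more concrete, here is a NumPy sketch of two common strategies: batch-all (average the loss over every valid triplet with a non-zero loss and report the fraction of such triplets) and hardest-negative mining (for each anchor-positive pair, use only the closest negative of the anchor). It is an assumption that `AllTripletMiner` and `HardNegativeTripletMiner` follow these definitions; the library may differ, for instance in its distance function or averaging. All function names below are hypothetical.

```python
import numpy as np

def pairwise_distances(embeddings):
    """Euclidean distance matrix of shape (n, n) between all embeddings."""
    sq = np.sum(embeddings ** 2, axis=1)
    d2 = sq[:, None] - 2.0 * embeddings @ embeddings.T + sq[None, :]
    return np.sqrt(np.maximum(d2, 0.0))  # clamp tiny negatives from rounding

def batch_all_triplet_loss(labels, embeddings, margin):
    """Mean loss over valid triplets with a positive loss, and their fraction."""
    dist = pairwise_distances(embeddings)
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    # valid[a, p, n]: a != p, labels[a] == labels[p], labels[a] != labels[n]
    valid = (same & not_self)[:, :, None] & (~same)[:, None, :]
    # loss[a, p, n] = d(a, p) - d(a, n) + margin, zeroed for invalid triplets
    loss = dist[:, :, None] - dist[:, None, :] + margin
    loss = np.maximum(loss, 0.0) * valid
    num_valid = valid.sum()
    num_positive = (loss > 1e-16).sum()
    frac_pos = num_positive / max(num_valid, 1)
    return loss.sum() / max(num_positive, 1), frac_pos

def hard_negative_triplet_loss(labels, embeddings, margin):
    """For every (anchor, positive) pair, use the anchor's closest negative."""
    dist = pairwise_distances(embeddings)
    same = labels[:, None] == labels[None, :]
    pos_pairs = same & ~np.eye(len(labels), dtype=bool)
    hardest_neg = np.where(~same, dist, np.inf).min(axis=1)  # shape (n,)
    loss = np.maximum(dist - hardest_neg[:, None] + margin, 0.0)
    return loss[pos_pairs].mean() if pos_pairs.any() else 0.0
```

Batch-all averages only over triplets that still violate the margin, so the loss does not shrink toward zero as easy triplets accumulate; `frac_pos` then tracks how many triplets remain hard.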

Example

The repository provides an example application with the MNIST dataset.





Download files


Source Distribution

TripletTorch-0.1.3.tar.gz (5.5 kB)

Built Distribution

TripletTorch-0.1.3-py3-none-any.whl (6.1 kB)

File details

Details for the file TripletTorch-0.1.3.tar.gz.

File metadata

  • Download URL: TripletTorch-0.1.3.tar.gz
  • Upload date:
  • Size: 5.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.6.9

File hashes

Hashes for TripletTorch-0.1.3.tar.gz
Algorithm Hash digest
SHA256 4330ff2348f25ab185175d5b2b1bf27fe86b4ee0868e6ae416f0965b32ac248d
MD5 77c20c423dec7c1a2bebcb5f61c4be45
BLAKE2b-256 e0a169d4667e0d5d2b5939447bbf4944d3e586add97ee67042f357b55d01edff


File details

Details for the file TripletTorch-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: TripletTorch-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 6.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.6.9

File hashes

Hashes for TripletTorch-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 d919ac977dfaa58a63fd108adacaa7e34b3e895be2a855fbc2447a8fb983f0a7
MD5 ad8975146d20ab56674f7cc04ee754a9
BLAKE2b-256 94bfff93cde75e4a6d07c52a7ffb49d8cb8d65926ff65ec149a63135be5a1fdf

