Triplet loss utilities for the PyTorch library.

TripletTorch

TripletTorch is a small PyTorch utility for triplet loss projects. It provides a simple way to create custom triplet datasets and implements common triplet mining loss techniques.

Install

Install the module with pip ( this may require running as sudo ).

pip3 install triplettorch

Usage

Triplet Dataset

from triplettorch import TripletDataset

# Create a triplet dataset given:
#   * labels  : array of labels ( classes ), one per sample of the dataset
#   * data_fn : function that returns the data for a given index in the dataset
#   * size    : number of samples in the dataset
#   * n_sample: number of samples per draw ( increases the probability that
#               a batch contains valid triplets )
# When used with a DataLoader, each batch of labels and data has shape
# ( batch_size, n_sample, ... ), so remember to merge the batch and sample
# dimensions before feeding the model
dataset = TripletDataset( labels, data_fn, size, n_sample )
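
The dimension merge mentioned in the comments above can be sketched as follows. This is only an illustration on random stand-in tensors: the batch size, the number of samples per draw and the MNIST-like image shape are assumptions, and TripletDataset itself is not used here.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, n_sample = 4, 3

# Stand-ins for one DataLoader batch ( random values, MNIST-like shape ).
labels = torch.randint(0, 10, (batch_size, n_sample))
data = torch.randn(batch_size, n_sample, 1, 28, 28)

# Merge the batch and sample dimensions before the forward pass.
flat_labels = labels.reshape(-1)               # ( batch_size * n_sample, )
flat_data = data.reshape(-1, *data.shape[2:])  # ( batch_size * n_sample, 1, 28, 28 )

print(flat_labels.shape, flat_data.shape)
```

The flattened labels and the embeddings computed from the flattened data then line up one to one, which is the layout the miners expect.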

Triplet Mining

from triplettorch import AllTripletMiner, HardNegativeTripletMiner

# Define the triplet mining loss given:
#   * margin: the margin ( float ) from the triplet loss definition
# Pick one of the two miners; .cuda( ) requires a CUDA-capable GPU
miner          = AllTripletMiner( .5 ).cuda( )
# miner        = HardNegativeTripletMiner( .5 ).cuda( )

# Use the loss in training given:
#   * labels    : array of labels ( classes ), one per sample of the batch
#   * embeddings: output of the neural network for each sample of the batch
# Returns two values:
#   * loss    : triplet loss value
#   * frac_pos: fraction of positive triplets
#               ( None for HardNegativeTripletMiner )
loss, frac_pos = miner( labels, embeddings )
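
To make the two returned values concrete, here is a minimal sketch of a "batch all" triplet loss written directly in PyTorch. It is not TripletTorch's implementation: the function name batch_all_triplet_loss is hypothetical, and averaging over the triplets with non-zero loss ( rather than over all valid triplets ) is an assumption based on a common convention.

```python
import torch


def batch_all_triplet_loss(labels, embeddings, margin=0.5):
    """Illustrative sketch: hinge loss over all valid ( a, p, n ) triplets."""
    # Pairwise Euclidean distances, shape ( B, B ).
    dist = torch.cdist(embeddings, embeddings)

    # same[ i, j ] is True when samples i and j share a label.
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    # valid[ a, p, n ]: a and p share a label ( a != p ), n has another label.
    valid = (same & not_self).unsqueeze(2) & (~same).unsqueeze(1)

    # loss[ a, p, n ] = max( d( a, p ) - d( a, n ) + margin, 0 )
    loss = torch.relu(dist.unsqueeze(2) - dist.unsqueeze(1) + margin)
    loss = loss * valid

    n_valid = valid.sum()
    n_pos = (loss > 0).sum()

    # Assumed convention: average over triplets that still violate the margin.
    frac_pos = n_pos.float() / n_valid.clamp(min=1)
    return loss.sum() / n_pos.clamp(min=1), frac_pos


# Two well separated classes: no triplet violates a margin of 0.5.
labels = torch.tensor([0, 0, 1, 1])
embeddings = torch.tensor([[0., 0.], [0., 1.], [5., 0.], [5., 1.]])
loss, frac_pos = batch_all_triplet_loss(labels, embeddings, margin=0.5)
print(loss.item(), frac_pos.item())
```

Under this convention, a shrinking frac_pos during training means fewer triplets still violate the margin.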

Example

The repository provides an example application with the MNIST dataset.
