
The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.



News

January 12: v0.9.96 greatly increases the flexibility of the testers and AccuracyCalculator. See the release notes

December 10: v0.9.95 includes a new tuple miner, BatchEasyHardMiner. See the release notes

November 6: v0.9.94 has minor bug fixes and improvements. Release notes

Documentation

Google Colab Examples

See the examples folder for notebooks you can download or run on Google Colab.

PyTorch Metric Learning Overview

This library contains 9 modules, each of which can be used independently within your existing codebase, or combined for a complete train/test workflow.

[Figure: high-level overview of the library's modules]

How loss functions work

Using losses and miners in your training loop

Let’s initialize a plain TripletMarginLoss:

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss()

To compute the loss in your training loop, pass in the embeddings computed by your model, and the corresponding labels. The embeddings should have size (N, embedding_size), and the labels should have size (N), where N is the batch size.

# your training loop
for i, (data, labels) in enumerate(dataloader):
	optimizer.zero_grad()
	embeddings = model(data)
	loss = loss_func(embeddings, labels)
	loss.backward()
	optimizer.step()

The TripletMarginLoss computes all possible triplets within the batch, based on the labels you pass into it. Anchor-positive pairs are formed by embeddings that share the same label, and anchor-negative pairs are formed by embeddings that have different labels.
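
For example, here is a self-contained sketch using toy data (the batch size and embedding size are arbitrary) that shows how the labels drive triplet formation:

import torch
from pytorch_metric_learning import losses

loss_func = losses.TripletMarginLoss()

# Toy batch: 4 embeddings from 2 classes. Embeddings 0 and 1 are positives of
# each other, as are 2 and 3; every cross-class pair is a negative.
embeddings = torch.randn(4, 128, requires_grad=True)
labels = torch.tensor([0, 0, 1, 1])

loss = loss_func(embeddings, labels)  # built from triplets such as (anchor=0, positive=1, negative=2)
loss.backward()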

Sometimes it can help to add a mining function:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()

# your training loop
for i, (data, labels) in enumerate(dataloader):
	optimizer.zero_grad()
	embeddings = model(data)
	hard_pairs = miner(embeddings, labels)
	loss = loss_func(embeddings, labels, hard_pairs)
	loss.backward()
	optimizer.step()

In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
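
The conversion works in the other direction too. As a hedged sketch (the margin and mining mode below are only illustrative), a triplet miner's output can be passed to a pair-based loss such as ContrastiveLoss:

from pytorch_metric_learning import miners, losses

# Illustrative settings: mine hard triplets, then let the library convert them
# into pairs for the pair-based ContrastiveLoss.
miner = miners.TripletMarginMiner(margin=0.2, type_of_triplets="hard")
loss_func = losses.ContrastiveLoss()

# inside your training loop:
# hard_triplets = miner(embeddings, labels)
# loss = loss_func(embeddings, labels, hard_triplets)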

Customizing loss functions

Loss functions can be customized using distances, reducers, and regularizers. In the diagram below, a miner finds the indices of hard pairs within a batch. These are used to index into the distance matrix, computed by the distance object. For this diagram, the loss function is pair-based, so it computes a loss per pair. In addition, a regularizer has been supplied, so a regularization loss is computed for each embedding in the batch. The per-pair and per-element losses are passed to the reducer, which (in this diagram) only keeps losses with a high value. The averages are computed for the high-valued pair and element losses, and are then added together to obtain the final loss.

[Figure: high-level overview of a customized loss function (miner, distance, reducer, and regularizer working together)]

Now here's an example of a customized TripletMarginLoss:

from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(distance=CosineSimilarity(),
                                     reducer=ThresholdReducer(high=0.3),
                                     embedding_regularizer=LpRegularizer())

This customized triplet loss has the following properties:

  • The loss will be computed using cosine similarity instead of Euclidean distance.
  • All triplet losses that are higher than 0.3 will be discarded.
  • The embeddings will be L2 regularized.

Using loss functions for unsupervised / self-supervised learning

The TripletMarginLoss is an embedding-based or tuple-based loss. This means that internally, there is no real notion of "classes". Tuples (pairs or triplets) are formed at each iteration, based on the labels the loss function receives. The labels don't have to represent classes; they simply need to indicate the positive and negative relationships between the embeddings. This makes it easy to use these loss functions for unsupervised or self-supervised learning.

For example, the code below is a simplified version of the augmentation strategy commonly used in self-supervision. The dataset does not come with any labels. Instead, the labels are created in the training loop, solely to indicate which embeddings are positive pairs.

# your training for-loop
for i, data in enumerate(dataloader):
	optimizer.zero_grad()
	embeddings = your_model(data)
	augmented = your_model(your_augmentation(data))
	labels = torch.arange(embeddings.size(0))

	embeddings = torch.cat([embeddings, augmented], dim=0)
	labels = torch.cat([labels, labels], dim=0)

	loss = loss_func(embeddings, labels)
	loss.backward()
	optimizer.step()

If you're interested in MoCo-style self-supervision, take a look at the MoCo on CIFAR10 notebook. It uses CrossBatchMemory to implement the momentum encoder queue, which means you can use any tuple loss, and any tuple miner to extract hard samples from the queue.
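
As a rough sketch of that idea (the notebook contains the full recipe; the embedding size, queue length, and temperature below are placeholders), CrossBatchMemory wraps a tuple loss and keeps a queue of embeddings from previous iterations:

from pytorch_metric_learning import losses, miners

# Placeholder values: wrap a tuple loss so each batch is also compared against
# a queue of past embeddings, and mine hard tuples from that queue.
loss_func = losses.CrossBatchMemory(
    loss=losses.NTXentLoss(temperature=0.07),
    embedding_size=128,
    memory_size=4096,
    miner=miners.MultiSimilarityMiner(),
)

# inside the training loop it is called like any other loss:
# loss = loss_func(embeddings, labels)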

Highlights of the rest of the library

  • For a convenient way to train your model, take a look at the trainers.
  • Want to test your model's accuracy on a dataset? Try the testers.
  • To compute the accuracy of an embedding space directly, use AccuracyCalculator (see the sketch below).
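
Here is a minimal sketch of AccuracyCalculator, using random NumPy arrays as stand-ins for real query and reference embeddings (exact argument names may differ slightly between versions):

import numpy as np
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Stand-in data: 100 query and 1000 reference embeddings of size 128, 10 classes.
query = np.random.randn(100, 128).astype(np.float32)
reference = np.random.randn(1000, 128).astype(np.float32)
query_labels = np.random.randint(0, 10, size=100)
reference_labels = np.random.randint(0, 10, size=1000)

calculator = AccuracyCalculator(include=("precision_at_1", "mean_average_precision_at_r"))
accuracies = calculator.get_accuracy(query, reference, query_labels, reference_labels, False)
print(accuracies)  # a dict mapping metric names to values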

If you're short of time and want a complete train/test workflow, check out the example Google Colab notebooks.

To learn more about all of the above, see the documentation.

Installation

Required PyTorch version

  • pytorch-metric-learning >= v0.9.90 requires torch >= 1.6
  • pytorch-metric-learning < v0.9.90 doesn't have a version requirement, but was tested with torch >= 1.2

Pip

pip install pytorch-metric-learning

To get the latest dev version:

pip install pytorch-metric-learning --pre

To install on Windows:

pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning

To install with evaluation and logging capabilities (This will install the unofficial pypi version of faiss-gpu):

pip install pytorch-metric-learning[with-hooks]

To install with evaluation and logging capabilities (CPU) (This will install the unofficial pypi version of faiss-cpu):

pip install pytorch-metric-learning[with-hooks-cpu]

Conda

conda install pytorch-metric-learning -c metric-learning -c pytorch

To use the testing module, you'll need faiss, which can be installed via conda as well. See the installation instructions for faiss.
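
For reference, faiss's conda packages are published on the pytorch channel; treat the exact commands below as an assumption and check faiss's own instructions for the current ones:

conda install -c pytorch faiss-cpu
# or, for GPU support:
conda install -c pytorch faiss-gpu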

Library contents

Distances

Name Reference Papers
CosineSimilarity
DotProductSimilarity
LpDistance
SNRDistance Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning

Losses

Name Reference Papers
AngularLoss Deep Metric Learning with Angular Loss
ArcFaceLoss ArcFace: Additive Angular Margin Loss for Deep Face Recognition
CircleLoss Circle Loss: A Unified Perspective of Pair Similarity Optimization
ContrastiveLoss Dimensionality Reduction by Learning an Invariant Mapping
CosFaceLoss - CosFace: Large Margin Cosine Loss for Deep Face Recognition
- Additive Margin Softmax for Face Verification
FastAPLoss Deep Metric Learning to Rank
GeneralizedLiftedStructureLoss In Defense of the Triplet Loss for Person Re-Identification
IntraPairVarianceLoss Deep Metric Learning with Tuplet Margin Loss
LargeMarginSoftmaxLoss Large-Margin Softmax Loss for Convolutional Neural Networks
LiftedStructureLoss Deep Metric Learning via Lifted Structured Feature Embedding
MarginLoss Sampling Matters in Deep Embedding Learning
MultiSimilarityLoss Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning
NCALoss Neighbourhood Components Analysis
NormalizedSoftmaxLoss - NormFace: L2 Hypersphere Embedding for Face Verification
- Classification is a Strong Baseline for Deep Metric Learning
NPairsLoss Improved Deep Metric Learning with Multi-class N-pair Loss Objective
NTXentLoss - Representation Learning with Contrastive Predictive Coding
- Momentum Contrast for Unsupervised Visual Representation Learning
- A Simple Framework for Contrastive Learning of Visual Representations
ProxyAnchorLoss Proxy Anchor Loss for Deep Metric Learning
ProxyNCALoss No Fuss Distance Metric Learning using Proxies
SignalToNoiseRatioContrastiveLoss Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning
SoftTripleLoss SoftTriple Loss: Deep Metric Learning Without Triplet Sampling
SphereFaceLoss SphereFace: Deep Hypersphere Embedding for Face Recognition
TripletMarginLoss Distance Metric Learning for Large Margin Nearest Neighbor Classification
TupletMarginLoss Deep Metric Learning with Tuplet Margin Loss

Miners

Name Reference Papers
AngularMiner
BatchEasyHardMiner Improved Embeddings with Easy Positive Triplet Mining
BatchHardMiner In Defense of the Triplet Loss for Person Re-Identification
DistanceWeightedMiner Sampling Matters in Deep Embedding Learning
EmbeddingsAlreadyPackagedAsTriplets
HDCMiner Hard-Aware Deeply Cascaded Embedding
MaximumLossMiner
MultiSimilarityMiner Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning
PairMarginMiner
TripletMarginMiner FaceNet: A Unified Embedding for Face Recognition and Clustering

Reducers

Name Reference Papers
AvgNonZeroReducer
ClassWeightedReducer
DivisorReducer
DoNothingReducer
MeanReducer
ThresholdReducer

Regularizers

Name Reference Papers
CenterInvariantRegularizer Deep Face Recognition with Center Invariant Loss
LpRegularizer
RegularFaceRegularizer RegularFace: Deep Face Recognition via Exclusive Regularization
SparseCentersRegularizer SoftTriple Loss: Deep Metric Learning Without Triplet Sampling
ZeroMeanRegularizer Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning

Samplers

Name Reference Papers
MPerClassSampler
TuplesToWeightsSampler
FixedSetOfTriplets

Trainers

Name Reference Papers
MetricLossOnly
TrainWithClassifier
CascadedEmbeddings Hard-Aware Deeply Cascaded Embedding
DeepAdversarialMetricLearning Deep Adversarial Metric Learning
UnsupervisedEmbeddingsUsingAugmentations
TwoStreamMetricLoss

Testers

Name Reference Papers
GlobalEmbeddingSpaceTester
WithSameParentLabelTester
GlobalTwoStreamEmbeddingSpaceTester

Utils

Name Reference Papers
AccuracyCalculator
HookContainer
InferenceModel
TorchInitWrapper
DistributedLossWrapper
DistributedMinerWrapper
LogitGetter

Base Classes, Mixins, and Wrappers

Name Reference Papers
CrossBatchMemory Cross-Batch Memory for Embedding Learning
GenericPairLoss
MultipleLosses
MultipleReducers
EmbeddingRegularizerMixin
WeightMixin
WeightRegularizerMixin
BaseDistance
BaseMetricLossFunction
BaseMiner
BaseTupleMiner
BaseSubsetBatchMiner
BaseReducer
BaseRegularizer
BaseTrainer
BaseTester

Benchmark results

See powerful-benchmarker to view benchmark results and to use the benchmarking tool.

Development

Unit tests can be run with the default unittest library:

python -m unittest discover

You can specify the test datatypes and test device as environment variables. For example, to test using float32 and float64 on the CPU:

TEST_DTYPES=float32,float64 TEST_DEVICE=cpu python -m unittest discover

To run a single test file instead of the entire test suite, specify the file name:

python -m unittest tests/losses/test_angular_loss.py

Code is formatted using black and isort:

pip install black isort
./format_code.sh

Acknowledgements

Contributors

Thanks to the contributors who made pull requests!

Algorithm implementations + useful features

Example notebooks

General improvements and bug fixes

Facebook AI

Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.

Open-source repos

This library contains code that has been adapted and modified from several great open-source repos.

Logo

Thanks to Jeff Musgrave for designing the logo.

Citing this library

If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:

@misc{musgrave2020pytorch,
    title={PyTorch Metric Learning},
    author={Kevin Musgrave and Serge Belongie and Ser-Nam Lim},
    year={2020},
    eprint={2008.09164},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
