The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
PyTorch Metric Learning
Documentation
Google Colab Example
See this notebook for an example of a complete training and testing workflow. Other examples are in the examples folder.
Benefits of this library
- Ease of use
  - Add metric learning to your application with just 2 lines of code in your training loop.
  - Mine pairs and triplets with a single function call.
- Flexibility
  - Mix and match losses, miners, and trainers in ways that other libraries don't allow.
Installation
Pip:
pip install pytorch-metric-learning
To get the latest dev version:
pip install pytorch-metric-learning==0.9.85.dev0
To install on Windows:
pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning
Conda:
conda install pytorch-metric-learning -c metric-learning
We have recently noticed some sporadic issues with the conda installation, so we recommend installing with pip. You can use pip inside of conda:
conda install pip
pip install pytorch-metric-learning
If you run into problems during installation, please post in this issue.
Benchmark results
See powerful-benchmarker to view benchmark results and to use the benchmarking tool.
Library contents
Losses:
- AngularLoss (Deep Metric Learning with Angular Loss)
- ArcFaceLoss (ArcFace: Additive Angular Margin Loss for Deep Face Recognition)
- CircleLoss (Circle Loss: A Unified Perspective of Pair Similarity Optimization)
- ContrastiveLoss (Dimensionality Reduction by Learning an Invariant Mapping)
- CosFaceLoss (CosFace: Large Margin Cosine Loss for Deep Face Recognition)
- FastAPLoss (Deep Metric Learning to Rank)
- GeneralizedLiftedStructureLoss (Deep Metric Learning via Lifted Structured Feature Embedding)
- IntraPairVarianceLoss (Deep Metric Learning with Tuplet Margin Loss)
- LargeMarginSoftmaxLoss (Large-Margin Softmax Loss for Convolutional Neural Networks)
- MarginLoss (Sampling Matters in Deep Embedding Learning)
- MultiSimilarityLoss (Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning)
- NCALoss (Neighbourhood Components Analysis)
- NormalizedSoftmaxLoss (Classification is a Strong Baseline for Deep Metric Learning)
- NPairsLoss (Improved Deep Metric Learning with Multi-class N-pair Loss Objective)
- NTXentLoss (A Simple Framework for Contrastive Learning of Visual Representations)
- ProxyAnchorLoss (Proxy Anchor Loss for Deep Metric Learning)
- ProxyNCALoss (No Fuss Distance Metric Learning using Proxies)
- SignalToNoiseRatioContrastiveLoss (Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning)
- SoftTripleLoss (SoftTriple Loss: Deep Metric Learning Without Triplet Sampling)
- SphereFaceLoss (SphereFace: Deep Hypersphere Embedding for Face Recognition)
- TripletMarginLoss (Distance Metric Learning for Large Margin Nearest Neighbor Classification)
- TupletMarginLoss (Deep Metric Learning with Tuplet Margin Loss)
Miners:
- AngularMiner
- BatchHardMiner (In Defense of the Triplet Loss for Person Re-Identification)
- DistanceWeightedMiner (Sampling Matters in Deep Embedding Learning)
- EmbeddingsAlreadyPackagedAsTriplets
- HDCMiner (Hard-Aware Deeply Cascaded Embedding)
- MaximumLossMiner
- MultiSimilarityMiner (Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning)
- PairMarginMiner
- TripletMarginMiner (FaceNet: A Unified Embedding for Face Recognition and Clustering)
Regularizers:
- CenterInvariantRegularizer (Deep Face Recognition with Center Invariant Loss)
- RegularFaceRegularizer (RegularFace: Deep Face Recognition via Exclusive Regularization)
Samplers:
Trainers:
- MetricLossOnly
- TrainWithClassifier
- CascadedEmbeddings (Hard-Aware Deeply Cascaded Embedding)
- DeepAdversarialMetricLearning (Deep Adversarial Metric Learning)
- UnsupervisedEmbeddingsUsingAugmentations
Testers:
Utils:
Base Classes, Mixins, and Wrappers:
- BaseMetricLossFunction
- BaseMiner
- BaseTupleMiner
- BaseSubsetBatchMiner
- BaseWeightRegularizer
- BaseTrainer
- BaseTester
- CrossBatchMemory (Cross-Batch Memory for Embedding Learning)
- GenericPairLoss
- MultipleLosses
- WeightRegularizerMixin
Overview
Let’s try the vanilla triplet margin loss. In all examples, embeddings is assumed to be of size (N, embedding_size), and labels is of size (N).
from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(margin=0.1)
loss = loss_func(embeddings, labels)
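As a rough illustration of what this loss computes, the sketch below evaluates the standard triplet margin hinge, max(0, d(a, p) - d(a, n) + margin), over a list of (anchor, positive, negative) index triplets. It is written in plain Python for readability (the library itself operates on PyTorch tensors), assumes Euclidean distance, and averages over all supplied triplets. The library's default distance and reduction may differ, and the helper names are hypothetical.

```python
import math


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(embeddings, triplets, margin=0.1):
    """Hypothetical helper: mean of max(0, d(a, p) - d(a, n) + margin) over triplets."""
    if not triplets:
        return 0.0
    total = 0.0
    for a, p, n in triplets:
        d_ap = euclidean(embeddings[a], embeddings[p])  # anchor-positive distance
        d_an = euclidean(embeddings[a], embeddings[n])  # anchor-negative distance
        total += max(0.0, d_ap - d_an + margin)         # zero once the negative is far enough
    return total / len(triplets)


# Anchor 0 and positive 1 are 5 apart; negative 2 is also 5 away (violates the margin),
# while negative 3 is 10 away (satisfies it).
embeddings = [[0.0, 0.0], [3.0, 4.0], [0.0, 5.0], [6.0, 8.0]]
print(triplet_margin_loss(embeddings, [(0, 1, 2), (0, 1, 3)]))  # 0.05
```

Only triplets whose negative is not yet at least `margin` farther away than the positive contribute to the loss.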
Loss functions typically come with a variety of parameters. For example, with the TripletMarginLoss, you can control how many triplets per sample to use in each batch. You can also use all possible triplets within each batch:
loss_func = losses.TripletMarginLoss(triplets_per_anchor="all")
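Conceptually, "all" means every valid (anchor, positive, negative) combination in the batch: the positive is a different sample with the anchor's label, and the negative has a different label. The plain-Python sketch below enumerates these from the labels alone; the function name is hypothetical and not part of the library. It also shows why this option gets expensive: the number of triplets can grow with the cube of the batch size.

```python
def all_triplets(labels):
    """Hypothetical helper: every (anchor, positive, negative) index triplet in a batch."""
    triplets = []
    for a, label_a in enumerate(labels):
        for p, label_p in enumerate(labels):
            if p == a or label_p != label_a:
                continue  # the positive must be a different sample with the same label
            for n, label_n in enumerate(labels):
                if label_n != label_a:  # the negative must have a different label
                    triplets.append((a, p, n))
    return triplets


# Two classes with two samples each: 4 anchors x 1 positive x 2 negatives = 8 triplets.
print(len(all_triplets([0, 0, 1, 1])))  # 8
```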
Sometimes it can help to add a mining function:
from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(margin=0.1)
hard_pairs = miner(embeddings, labels)
loss = loss_func(embeddings, labels, hard_pairs)
In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
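A plain-Python sketch of that conversion follows. It assumes, for illustration, that pairs are represented as four aligned index lists (a1, p, a2, n), where (a1[i], p[i]) are positive pairs and (a2[j], n[j]) are negative pairs, and that triplets are three aligned lists (anchors, positives, negatives). The function names are hypothetical, and the library's internal helpers may differ. A positive pair and a negative pair that share an anchor form one triplet; going the other way, each triplet splits into one positive pair and one negative pair.

```python
def pairs_to_triplets(a1, p, a2, n):
    """Hypothetical helper: join positive and negative pairs that share an anchor."""
    anchors, positives, negatives = [], [], []
    for anchor_pos, positive in zip(a1, p):
        for anchor_neg, negative in zip(a2, n):
            if anchor_pos == anchor_neg:  # same anchor: one (anchor, positive, negative) triplet
                anchors.append(anchor_pos)
                positives.append(positive)
                negatives.append(negative)
    return anchors, positives, negatives


def triplets_to_pairs(anchors, positives, negatives):
    """Hypothetical helper: each triplet yields one positive pair and one negative pair."""
    return list(anchors), list(positives), list(anchors), list(negatives)


# Positive pairs (0, 1) and (2, 3); negative pairs (0, 2), (0, 3) and (2, 0).
print(pairs_to_triplets([0, 2], [1, 3], [0, 0, 2], [2, 3, 0]))
# ([0, 0, 2], [1, 1, 3], [2, 3, 0])
```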
In general, all loss functions take in embeddings and labels, with an optional indices_tuple argument (i.e. the output of a miner):
# From BaseMetricLossFunction
def forward(self, embeddings, labels, indices_tuple=None)
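To illustrate that contract, here is a toy loss in plain Python with the same forward(embeddings, labels, indices_tuple=None) signature. It is hypothetical and does not subclass BaseMetricLossFunction: when no indices_tuple is given it falls back to every pair in the batch, and otherwise it scores only the mined pairs, assumed here to be (a1, p, a2, n) index lists. The contrastive-style formula (positive pairs pulled toward distance 0, negative pairs pushed beyond neg_margin) is an illustrative choice.

```python
import math
from itertools import combinations


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def all_pairs(labels):
    """Every unordered pair in the batch, split into positive and negative pairs (a1, p, a2, n)."""
    a1, p, a2, n = [], [], [], []
    for i, j in combinations(range(len(labels)), 2):
        if labels[i] == labels[j]:
            a1.append(i)
            p.append(j)
        else:
            a2.append(i)
            n.append(j)
    return a1, p, a2, n


class ToyContrastiveLoss:
    """Hypothetical loss following forward(embeddings, labels, indices_tuple=None)."""

    def __init__(self, neg_margin=1.0):
        self.neg_margin = neg_margin

    def __call__(self, embeddings, labels, indices_tuple=None):
        return self.forward(embeddings, labels, indices_tuple)

    def forward(self, embeddings, labels, indices_tuple=None):
        if indices_tuple is None:  # no miner output: use every pair in the batch
            indices_tuple = all_pairs(labels)
        a1, p, a2, n = indices_tuple
        # Positive pairs are penalized by their distance...
        terms = [euclidean(embeddings[i], embeddings[j]) for i, j in zip(a1, p)]
        # ...negative pairs only when they are closer than neg_margin.
        terms += [max(0.0, self.neg_margin - euclidean(embeddings[i], embeddings[j]))
                  for i, j in zip(a2, n)]
        return sum(terms) / len(terms) if terms else 0.0


embeddings = [[0.0, 0.0], [0.0, 0.5], [3.0, 4.0], [3.0, 4.0]]
labels = [0, 0, 1, 1]
loss_func = ToyContrastiveLoss()
print(loss_func(embeddings, labels))                      # averaged over all 6 pairs
print(loss_func(embeddings, labels, ([0], [1], [], [])))  # only the mined positive pair (0, 1)
```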
Similarly, (almost) all mining functions take in embeddings and labels:
# From BaseMiner
def forward(self, embeddings, labels)
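As an example of the forward(embeddings, labels) contract, the plain-Python sketch below mimics the pair-mining rule described in the Multi-Similarity paper, after which MultiSimilarityMiner is named. For each anchor, using cosine similarity, a negative is kept if it is more similar than the least similar positive minus epsilon, and a positive is kept if it is less similar than the most similar negative plus epsilon. The class is hypothetical and does not subclass BaseMiner; the library's MultiSimilarityMiner may differ in details such as the similarity measure and output format.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


class ToyMultiSimilarityMiner:
    """Hypothetical miner: forward(embeddings, labels) -> (a1, p, a2, n) index lists."""

    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon

    def __call__(self, embeddings, labels):
        return self.forward(embeddings, labels)

    def forward(self, embeddings, labels):
        size = len(labels)
        sim = [[cosine(embeddings[i], embeddings[j]) for j in range(size)] for i in range(size)]
        a1, p, a2, n = [], [], [], []
        for i in range(size):
            pos = [j for j in range(size) if j != i and labels[j] == labels[i]]
            neg = [j for j in range(size) if labels[j] != labels[i]]
            if not pos or not neg:
                continue  # this anchor needs at least one positive and one negative
            hardest_pos = min(sim[i][j] for j in pos)  # least similar positive
            hardest_neg = max(sim[i][j] for j in neg)  # most similar negative
            for j in neg:
                if sim[i][j] > hardest_pos - self.epsilon:
                    a2.append(i)
                    n.append(j)
            for j in pos:
                if sim[i][j] < hardest_neg + self.epsilon:
                    a1.append(i)
                    p.append(j)
        return a1, p, a2, n


embeddings = [[1.0, 0.0], [0.8, 0.6], [0.6, 0.8], [0.0, 1.0]]
labels = [0, 0, 1, 1]
miner = ToyMultiSimilarityMiner(epsilon=0.1)
print(miner(embeddings, labels))  # ([1, 2], [0, 3], [1, 2], [2, 1])
```

Only samples 1 and 2, which sit close together despite having different labels, produce hard pairs here.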
For more complex approaches, like deep adversarial metric learning, use one of the trainers.
To check the accuracy of your model, use one of the testers. Which tester should you use? Almost certainly GlobalEmbeddingSpaceTester, which finds nearest neighbors by looking at all points in the embedding space, as most metric-learning papers do.
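The sketch below illustrates the idea behind evaluating in the global embedding space: every sample serves as a query against all other samples, and precision@1 is the fraction of queries whose nearest neighbor shares their label. It is plain Python with Euclidean distance and a hypothetical function name; it is not GlobalEmbeddingSpaceTester's implementation, which reports its own set of accuracy metrics (see the documentation).

```python
import math


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def precision_at_1(embeddings, labels):
    """Hypothetical helper: leave-one-out nearest-neighbor accuracy over the whole set."""
    correct = 0
    for i, query in enumerate(embeddings):
        # Search all other points, excluding the query itself.
        nearest = min((j for j in range(len(embeddings)) if j != i),
                      key=lambda j: euclidean(query, embeddings[j]))
        correct += labels[nearest] == labels[i]
    return correct / len(embeddings)


embeddings = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0], [0.0, 0.4]]
labels = [0, 0, 1, 1, 1]
print(precision_at_1(embeddings, labels))  # 0.4: point 4 sits inside the class-0 cluster
```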
Also check out the example Google Colab notebooks.
To learn more about all of the above, see the documentation.
Development
To run the unit tests:
pip install -e .[dev]
pytest tests
The first command may fail on Windows. In that case, install torch by following the official guide, then run pip install -e .[dev] again.
Acknowledgements
Facebook AI
Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI, where I received valuable feedback from Ser-Nam and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.
Open-source repos
This library contains code that has been adapted and modified from the following great open-source repos:
- https://github.com/bnu-wangxun/Deep_Metric
- https://github.com/chaoyuaw/incubator-mxnet/blob/master/example/gluon/embedding_learning
- https://github.com/facebookresearch/deepcluster
- https://github.com/geonm/proxy-anchor-loss
- https://github.com/idstcv/SoftTriple
- https://github.com/kunhe/FastAP-metric-learning
- https://github.com/ronekko/deep_metric_learning
- https://github.com/tjddus9597/Proxy-Anchor-CVPR2020
- http://kaizhao.net/regularface
Contributors
Thanks to the contributors who made pull requests!
Algorithm implementations
Example notebooks
General improvements and bug fixes
Citing this library
If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:
@misc{Musgrave2019,
author = {Musgrave, Kevin and Lim, Ser-Nam and Belongie, Serge},
title = {PyTorch Metric Learning},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/KevinMusgrave/pytorch-metric-learning}},
}