The easiest way to use metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.

PyTorch Metric Learning

Documentation

View the documentation here

Benefits of this library

  1. Ease of use
    • Add metric learning to your application with just 2 lines of code in your training loop.
    • Mine pairs and triplets with a single function call.
  2. Flexibility
    • Mix and match losses, miners, and trainers in ways that other libraries don't allow.

Installation:

Conda:

conda install pytorch-metric-learning -c metric-learning

Pip:

pip install pytorch-metric-learning

Benchmark results

See powerful-benchmarker to view benchmark results and to use the benchmarking tool.

Currently implemented classes:

The library provides components in each of the following categories (see the documentation for the complete list in each one):

  • Loss functions
  • Mining functions
  • Regularizers
  • Samplers
  • Training methods
  • Testing methods

Overview

Let’s try the vanilla triplet margin loss. In all examples, embeddings is assumed to be of size (N, embedding_size), and labels is of size (N).

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(margin=0.1)
loss = loss_func(embeddings, labels)

Loss functions typically come with a variety of parameters. For example, with the TripletMarginLoss, you can control how many triplets per anchor to use in each batch. You can also use all possible triplets within each batch:

loss_func = losses.TripletMarginLoss(triplets_per_anchor="all")
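To make concrete what the loss computes over all triplets, here is a framework-free sketch in plain Python (no PyTorch). The function names `euclidean` and `triplet_margin_loss` are illustrative, and the sketch simply averages over every valid triplet; the library's actual reduction and distance options may differ:

```python
import math

def euclidean(x, y):
    """Euclidean distance between two embedding vectors (as plain lists)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))

def triplet_margin_loss(embeddings, labels, margin=0.1):
    """Average hinge loss, max(d(a,p) - d(a,n) + margin, 0), over all
    valid (anchor, positive, negative) triplets in the batch."""
    losses = []
    n = len(embeddings)
    for a in range(n):
        for p in range(n):
            # A positive shares the anchor's label and is not the anchor itself.
            if p == a or labels[p] != labels[a]:
                continue
            for neg in range(n):
                # A negative has a different label than the anchor.
                if labels[neg] == labels[a]:
                    continue
                d_ap = euclidean(embeddings[a], embeddings[p])
                d_an = euclidean(embeddings[a], embeddings[neg])
                losses.append(max(d_ap - d_an + margin, 0.0))
    return sum(losses) / len(losses) if losses else 0.0
```

When the classes are already well separated (positives much closer than negatives), every triplet's hinge term is zero and the loss vanishes, which is the behavior the margin is designed to enforce.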

Sometimes it can help to add a mining function:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(margin=0.1)
hard_pairs = miner(embeddings, labels)
loss = loss_func(embeddings, labels, hard_pairs)

In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
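To illustrate the idea of mining "particularly difficult" pairs, here is a simplified framework-free sketch. This is not MultiSimilarityMiner's exact selection rule (which compares pair similarities against a margin); it only keeps, for each anchor, the farthest positive and the closest negative, which is the classic hardest-pair heuristic. All names here are illustrative:

```python
import math

def euclidean(x, y):
    """Euclidean distance between two embedding vectors (as plain lists)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))

def mine_hard_pairs(embeddings, labels):
    """For each anchor, return index pairs (anchor, hardest_positive)
    and (anchor, hardest_negative)."""
    pos_pairs, neg_pairs = [], []
    n = len(embeddings)
    for a in range(n):
        positives = [i for i in range(n) if i != a and labels[i] == labels[a]]
        negatives = [i for i in range(n) if labels[i] != labels[a]]
        if positives:
            # Hardest positive: same label, but farthest away.
            hardest_pos = max(positives, key=lambda i: euclidean(embeddings[a], embeddings[i]))
            pos_pairs.append((a, hardest_pos))
        if negatives:
            # Hardest negative: different label, but closest.
            hardest_neg = min(negatives, key=lambda i: euclidean(embeddings[a], embeddings[i]))
            neg_pairs.append((a, hardest_neg))
    return pos_pairs, neg_pairs
```

The loss function then concentrates its gradient on exactly these violating pairs instead of averaging over many easy, already-satisfied ones.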

In general, all loss functions take in embeddings and labels, with an optional indices_tuple argument (i.e. the output of a miner):

# From BaseMetricLossFunction
def forward(self, embeddings, labels, indices_tuple=None)

And (almost) all mining functions take in embeddings and labels:

# From BaseMiner
def forward(self, embeddings, labels)

For more complex approaches, like deep adversarial metric learning, use one of the trainers.

To check the accuracy of your model, use one of the testers. Which tester should you use? Almost definitely GlobalEmbeddingSpaceTester, because it does what most metric-learning papers do.
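One of the standard metrics such a tester reports is precision@1: the fraction of samples whose nearest neighbor in the embedding space shares their label. The following framework-free sketch (illustrative names, plain Python, brute-force nearest neighbor rather than the tester's actual implementation) shows the idea:

```python
import math

def euclidean(x, y):
    """Euclidean distance between two embedding vectors (as plain lists)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))

def precision_at_1(embeddings, labels):
    """Fraction of samples whose nearest neighbor (excluding the sample
    itself) has the same label."""
    n = len(embeddings)
    correct = 0
    for i in range(n):
        nearest = min((j for j in range(n) if j != i),
                      key=lambda j: euclidean(embeddings[i], embeddings[j]))
        correct += labels[nearest] == labels[i]
    return correct / n
```

A well-trained embedding space pulls same-label points together, so this score approaches 1.0; it is a direct, label-space-free measure of how useful the embeddings are for retrieval.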

Also check out the example scripts. Each one shows how to set up models, optimizers, losses etc for a particular trainer.

To learn more about all of the above, see the documentation.

Acknowledgements

Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists.

Citing this library

If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:

@misc{Musgrave2019,
  author = {Musgrave, Kevin and Lim, Ser-Nam and Belongie, Serge},
  title = {PyTorch Metric Learning},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/KevinMusgrave/pytorch-metric-learning}},
}
