The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.

Project description

PyTorch Metric Learning


News

June 20: v0.9.87 comes with some major changes that may cause your existing code to break. See the release notes for details.

Documentation

View the documentation here

Google Colab Examples

See the examples folder for notebooks that show entire train/test workflows with logging and model saving.

Benefits of this library

  1. Ease of use
    • Add metric learning to your application with just 2 lines of code in your training loop.
    • Mine pairs and triplets with a single function call.
  2. Flexibility
    • Mix and match losses, miners, and trainers in ways that other libraries don't allow.

Installation

Pip

pip install pytorch-metric-learning

To get the latest dev version (for example, the 0.9.89.dev0 pre-release listed below):

pip install pytorch-metric-learning --pre

To install on Windows:

pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning

To install with evaluation and logging capabilities (this installs the unofficial PyPI version of faiss-gpu):

pip install pytorch-metric-learning[with-hooks]

To install with evaluation and logging capabilities on CPU (this installs the unofficial PyPI version of faiss-cpu):

pip install pytorch-metric-learning[with-hooks-cpu]

Conda

conda install pytorch-metric-learning -c metric-learning -c pytorch

To use the testing module, you'll need faiss, which can be installed via conda as well. See the installation instructions for faiss.
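
For example, at the time of writing faiss is published on the pytorch conda channel; treat the exact channel and package names as assumptions and confirm them against the faiss instructions:

conda install faiss-cpu -c pytorch
conda install faiss-gpu -c pytorch  # requires CUDA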

Overview

Let’s try the vanilla triplet margin loss. In all examples, embeddings is assumed to be of size (N, embedding_size), and labels is of size (N).

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(margin=0.1)
loss = loss_func(embeddings, labels) # in your training loop
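
Here is the same call as a minimal, self-contained sketch; the random embeddings and labels are stand-ins for your model's output, not part of the library's API:

import torch
from pytorch_metric_learning import losses

embeddings = torch.randn(32, 128)     # (N, embedding_size), stand-in for model output
labels = torch.randint(0, 10, (32,))  # (N,), integer class labels
loss_func = losses.TripletMarginLoss(margin=0.1)
loss = loss_func(embeddings, labels)  # scalar tensor; call loss.backward() in training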

Loss functions typically come with a variety of parameters. For example, with the TripletMarginLoss you can control how many triplets per anchor are used in each batch, or use all possible triplets within each batch:

loss_func = losses.TripletMarginLoss(triplets_per_anchor="all")

Sometimes it can help to add a mining function:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(margin=0.1)
hard_pairs = miner(embeddings, labels) # in your training loop
loss = loss_func(embeddings, labels, hard_pairs)

In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
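
If you're curious about what the miner actually returns, it's a tuple of index tensors into the batch. A small sketch (the unpacking below reflects pair-based miners such as MultiSimilarityMiner; treat it as illustrative and check the miners documentation):

a1, p, a2, n = miner(embeddings, labels)
# (embeddings[a1[i]], embeddings[p[i]]) is a hard positive pair
# (embeddings[a2[j]], embeddings[n[j]]) is a hard negative pair
# the loss indexes into the batch with these tensors, converting
# pairs to triplets (or vice versa) when needed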

Here's what the above examples look like in a typical training loop:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(margin=0.1)

# borrowed from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    embeddings = net(inputs)
    hard_pairs = miner(embeddings, labels)
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()
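
The loop assumes that net, optimizer, and trainloader already exist. A minimal setup sketch, purely illustrative (the toy linear embedder and CIFAR10 loader below are assumptions, not requirements of the library):

import torch
import torchvision
import torchvision.transforms as transforms

# any model that maps inputs to (N, embedding_size) will do
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 128))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,
                                        transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)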

For more complex approaches, like deep adversarial metric learning, use one of the trainers.

To check the accuracy of your model, use one of the testers. Which tester should you use? Almost definitely GlobalEmbeddingSpaceTester, because it evaluates accuracy over the entire embedding space, which is what most metric-learning papers do.
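
A hedged sketch of the tester pattern (the dataset_dict keys and the exact test arguments are assumptions based on the documentation for this release; check the testers docs before relying on them):

from pytorch_metric_learning import testers

tester = testers.GlobalEmbeddingSpaceTester()
dataset_dict = {"train": train_dataset, "val": val_dataset}  # your torch Datasets
# computes embeddings for each split and reports accuracy metrics
tester.test(dataset_dict, epoch, model)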

Also check out the example Google Colab notebooks.

To learn more about all of the above, see the documentation.

Library contents

The library provides the following categories of components; see the documentation for the full contents of each:

  • Losses
  • Miners
  • Reducers
  • Regularizers
  • Samplers
  • Trainers
  • Testers
  • Utils
  • Base classes, mixins, and wrappers

Benchmark results

See powerful-benchmarker to view benchmark results and to use the benchmarking tool.

Development

To run the unit tests:

pip install -e .[dev]
pytest tests

The first command may initially fail on Windows. If that happens, install torch by following the official PyTorch guide, then run pip install -e .[dev] again.

Acknowledgements

Contributors

Thanks to the contributors who made pull requests, including algorithm implementations, example notebooks, and general improvements and bug fixes!

Facebook AI

Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.

Open-source repos

This library contains code that has been adapted and modified from several great open-source repos.

Citing this library

If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:

@misc{Musgrave2019,
  author = {Musgrave, Kevin and Lim, Ser-Nam and Belongie, Serge},
  title = {PyTorch Metric Learning},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/KevinMusgrave/pytorch-metric-learning}},
}


Download files

Download the file for your platform.

Source Distribution

pytorch-metric-learning-0.9.89.dev0.tar.gz (55.0 kB)


Built Distribution


pytorch_metric_learning-0.9.89.dev0-py3-none-any.whl (83.7 kB)


File details

Details for the file pytorch-metric-learning-0.9.89.dev0.tar.gz.

File metadata

  • Download URL: pytorch-metric-learning-0.9.89.dev0.tar.gz
  • Upload date:
  • Size: 55.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0.post20191030 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for pytorch-metric-learning-0.9.89.dev0.tar.gz

  • SHA256: 670953d12046084eb5c1e8560fb7846162daaa6be4410a57b1b765d497aa1937
  • MD5: 37d5153f0e17f0846f3aedad09af2a32
  • BLAKE2b-256: 008959415a58605d9b1ffb8cc2c08a5295ebeb53c7077e66b61028ce71456b81


File details

Details for the file pytorch_metric_learning-0.9.89.dev0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_metric_learning-0.9.89.dev0-py3-none-any.whl
  • Upload date:
  • Size: 83.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0.post20191030 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for pytorch_metric_learning-0.9.89.dev0-py3-none-any.whl

  • SHA256: 8beea8a7d7bbace43faa196d21193e9bf3acc129b3a20866f4f34d1762552e5a
  • MD5: 90001bdb0ac962c25cf585082580d87f
  • BLAKE2b-256: 2ddbe0c888c71d3e879290a6bd59d43cb88269f85c0d8234c236fc34608ce6fe

