
A simple PyTorch package that includes the most common metric learning layers


MLL - Metric Learning Layers

MLL is a simple PyTorch package that includes the most common metric learning layers. MLL only includes layers that do not depend on negative-sample mining and can therefore serve as drop-in replacements for the final linear layer of a classification model. All layers aim to maximize inter-class variance while minimizing intra-class variance. Moreover, all MLL layers can be used in conjunction with soft targets (e.g. with Mixup).
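To illustrate what "soft targets" means here, the following is a minimal plain-PyTorch sketch of Mixup producing probability targets that are fed to cross-entropy. It does not use MLL itself: `mixup_batch` is a hypothetical helper, `alpha=0.2` is an arbitrary choice, and the random `logits` tensor merely stands in for the output of an MLL layer with `out_features=10`. How MLL's label-dependent layers (such as ArcFace) consume soft targets is not shown. Passing class probabilities to `F.cross_entropy` requires PyTorch 1.10 or later.

```python
import torch
import torch.nn.functional as F

def mixup_batch(x, labels, num_classes, alpha=0.2):
    """Hypothetical helper: mix a batch with a shuffled copy of itself."""
    # Sample the mixing coefficient lambda from Beta(alpha, alpha)
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    mixed_x = lam * x + (1.0 - lam) * x[perm]
    # Mix the one-hot labels with the same coefficient -> soft targets
    one_hot = F.one_hot(labels, num_classes).float()
    soft_targets = lam * one_hot + (1.0 - lam) * one_hot[perm]
    return mixed_x, soft_targets

x = torch.randn(32, 128)
labels = torch.randint(0, 10, (32,))
mixed_x, soft_targets = mixup_batch(x, labels, num_classes=10)

# Stand-in for the logits of a metric learning layer with out_features=10
logits = torch.randn(32, 10, requires_grad=True)

# cross_entropy accepts per-class probabilities as targets (PyTorch >= 1.10)
loss = F.cross_entropy(logits, soft_targets)
loss.backward()
```

Because each row of `soft_targets` is a convex combination of two one-hot vectors, it still sums to 1 and is a valid probability distribution.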

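The basis of all these layers is the scaled cosine similarity $y = s \cdot xW$ between the $d$-dimensional input vector (features) $x \in \mathbb{R}^{1 \times d}$ and the $c$ class weights (prototypes, embeddings) $W \in \mathbb{R}^{d \times c}$, where $\|x\| = 1$, $\|W_{*,j}\| = 1 \;\, \forall j = 1, \dots, c$, and $s \in \mathbb{R}^+$ is a scale factor. Because $x$ and every column of $W$ have unit norm, each entry $y_j$ equals $s$ times the cosine of the angle between $x$ and the prototype $W_{*,j}$.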
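The formula above can be sketched in a few lines of plain PyTorch. This is an illustration of the math, not MLL's implementation: `scaled_cosine_logits` is a hypothetical function name, and the scale `s=10.0` is an arbitrary example value.

```python
import torch
import torch.nn.functional as F

def scaled_cosine_logits(x, W, s):
    """Hypothetical helper: y = s * x_hat @ W_hat.

    x: (batch, d) input features, W: (d, c) class prototypes, s: positive scale.
    """
    x_hat = F.normalize(x, p=2, dim=1)  # unit norm for every row of x
    W_hat = F.normalize(W, p=2, dim=0)  # unit norm for every column W[:, j]
    return s * (x_hat @ W_hat)          # (batch, c) scaled cosine similarities

d, c = 128, 10
x = torch.randn(32, d)
W = torch.randn(d, c)
y = scaled_cosine_logits(x, W, s=10.0)
```

Since every cosine lies in $[-1, 1]$, every logit lies in $[-s, s]$; the scale $s$ controls how peaked the subsequent softmax can become.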

Supported Layers

We currently support the following layers:

Multiple sub-centers can be used with every layer except DeepNCM. If no scale is specified, MLL falls back to the fixed-scale heuristic from AdaCos, $s = \sqrt{2} \cdot \log(c - 1)$.
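As a rough sketch of both ideas, assuming the natural logarithm in the AdaCos heuristic and assuming sub-centers are aggregated by taking the maximum cosine per class (the approach proposed for Sub-center ArcFace; MLL's exact aggregation is not stated here), one might write the following. `adacos_default_scale` and `subcenter_cosine` are hypothetical names, and the column layout of `W` (all sub-centers of class 0, then class 1, and so on) is an assumption.

```python
import math
import torch
import torch.nn.functional as F

def adacos_default_scale(num_classes):
    """AdaCos fixed-scale heuristic: s = sqrt(2) * ln(c - 1)."""
    return math.sqrt(2.0) * math.log(num_classes - 1)

def subcenter_cosine(x, W, num_classes, num_sub_centers):
    """Cosine similarity to the closest of k sub-centers of each class.

    W has shape (d, c * k); columns are assumed grouped class by class.
    """
    x_hat = F.normalize(x, dim=1)
    W_hat = F.normalize(W, dim=0)
    cos = x_hat @ W_hat                                # (batch, c * k)
    cos = cos.view(-1, num_classes, num_sub_centers)   # (batch, c, k)
    return cos.max(dim=2).values                       # (batch, c)

c, k, d = 10, 3, 128
s = adacos_default_scale(c)  # roughly 3.107 for c = 10
x = torch.randn(32, d)
W = torch.randn(d, c * k)
logits = s * subcenter_cosine(x, W, c, k)
```

Note that the heuristic is only meaningful for $c > 2$: for $c = 2$ it yields $s = 0$, and for $c = 1$ the logarithm is undefined. With `num_sub_centers=1` the sub-center version reduces to the plain scaled cosine similarity.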

Install MLL

Simply run:

pip install metric_learning_layers

Example

import torch
import metric_learning_layers as mll

rnd_batch  = torch.randn(32, 128)
rnd_labels = torch.randint(low=0, high=10, size=(32, ))

arcface = mll.ArcFace(in_features=128, 
                      out_features=10, 
                      num_sub_centers=1, 
                      scale=None, # defaults to AdaCos heuristic
                      trainable_scale=False
                      )

af_out = arcface(rnd_batch, rnd_labels)  # ArcFace requires labels (used to apply the margin)
# af_out: torch.Size([32, 10])

adacos = mll.AdaCos(in_features=128, 
                    out_features=10, 
                    num_sub_centers=1 
                    )

ac_out = adacos(rnd_batch)  # AdaCos does not require labels
# ac_out: torch.Size([32, 10])
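The `arcface` call above needs labels because ArcFace adds an angular margin to the logit of each sample's ground-truth class. The plain-PyTorch sketch below shows that mechanism in isolation; it is not MLL's implementation. `arcface_logits` is a hypothetical name, `s=64.0` and `m=0.5` are commonly used ArcFace settings rather than MLL's defaults, and the sketch omits the extra handling many implementations apply when $\theta + m$ exceeds $\pi$.

```python
import torch
import torch.nn.functional as F

def arcface_logits(x, W, labels, s=64.0, m=0.5):
    """Hypothetical helper: additive angular margin (ArcFace).

    Replaces cos(theta_y) with cos(theta_y + m) for the target class y,
    then scales all logits by s.
    """
    cos = F.normalize(x, dim=1) @ F.normalize(W, dim=0)  # (batch, c)
    # Clamp away from +-1 so that acos stays numerically stable
    theta = torch.acos(cos.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    target = F.one_hot(labels, num_classes=W.size(1)).bool()
    # Apply the margin only to the ground-truth class angle
    logits = torch.where(target, torch.cos(theta + m), cos)
    return s * logits

torch.manual_seed(0)
x = torch.randn(32, 128)
W = torch.randn(128, 10, requires_grad=True)
labels = torch.randint(0, 10, (32,))
logits = arcface_logits(x, W, labels)
loss = F.cross_entropy(logits, labels)
loss.backward()
```

Penalizing only the target logit forces features to be closer (in angle) to their own prototype than a plain cosine classifier would require, which is what yields the tighter intra-class clusters.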

Download files


Source Distribution

Metric Learning Layers-0.1.5.tar.gz (5.7 kB)

Uploaded Source

Built Distribution

Metric_Learning_Layers-0.1.5-py3-none-any.whl (6.5 kB)

Uploaded Python 3

File details

Details for the file Metric Learning Layers-0.1.5.tar.gz.

File metadata

  • Download URL: Metric Learning Layers-0.1.5.tar.gz
  • Upload date:
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.0

File hashes

Hashes for Metric Learning Layers-0.1.5.tar.gz
Algorithm Hash digest
SHA256 5995ddc3ff229205f139031f95077e4cad92d4c463d2d526ba933e9d1156c6da
MD5 1b0fddaca782a7ebe0a1196089d749aa
BLAKE2b-256 6b0e8029917321ae1833164f345db0b97ac1af3084353d51b49f3f7b5beecc1b


File details

Details for the file Metric_Learning_Layers-0.1.5-py3-none-any.whl.

File metadata

File hashes

Hashes for Metric_Learning_Layers-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 bf68223af4463d6c64c9b29fe5dabdfc1805142b55b09204acd75f1a0daa6302
MD5 8fa0de87f6de8307d119274fc33af5d8
BLAKE2b-256 6beb4dc9ce691a8a96f6c13eb32336bf0caf8e6410acb1cc056d09dd8c04933b
