A simple PyTorch package that includes the most common metric learning layers.

MLL - Metric Learning Layers

MLL is a simple PyTorch package that includes the most common metric learning layers. MLL only includes layers that do not depend on negative sample mining and are therefore drop-in replacements for the final linear layer used in classification problems. All layers aim to increase inter-class variance and minimize intra-class variance. Moreover, all MLL layers can be used in conjunction with soft targets (e.g. with Mixup).

The basis of all these layers is the scaled cosine similarity $$y = s \cdot xW$$ between the $d$-dimensional input vectors (features) $x \in \mathbb{R}^{1 \times d}$ and the $c$ class weights (prototypes, embeddings) $W \in \mathbb{R}^{d \times c}$, where $\|x\| = 1$, $\|W_{*, j}\| = 1 \;\, \forall j = 1, \dots, c$, and $s \in \mathbb{R}^+$.
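The formula above can be sketched in a few lines of plain PyTorch. This is only an illustration of the scaled cosine similarity, not MLL's internal code; the batch size, the scale value `s = 3.0`, and the variable names are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d, c, s = 128, 10, 3.0           # feature dim, number of classes, scale (illustrative values)
x = torch.randn(32, d)           # batch of 32 feature vectors
W = torch.randn(d, c)            # one prototype per class, stored column-wise

x_norm = F.normalize(x, dim=1)   # unit-length rows:    ||x|| = 1
W_norm = F.normalize(W, dim=0)   # unit-length columns: ||W_{*,j}|| = 1

y = s * (x_norm @ W_norm)        # scaled cosine similarities, shape (32, 10)
```

Because both operands are normalized, every entry of `y` lies in the interval $[-s, s]$, which is why the scale $s$ controls how peaked the resulting softmax can become.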

Supported Layers

We currently support the following layers:

You can use multiple sub-centers for all layers except DeepNCM. If you do not specify a scale, MLL uses the heuristic from AdaCos: $s = \sqrt{2} \cdot \log(c-1)$.
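As a quick worked example, the heuristic can be evaluated with the standard library alone. The helper name `adacos_scale` is hypothetical (it is not part of MLL), and the sketch assumes that $\log$ denotes the natural logarithm.

```python
import math


def adacos_scale(num_classes: int) -> float:
    """Fixed scale s = sqrt(2) * log(c - 1), assuming the natural logarithm."""
    if num_classes < 3:
        # c = 2 gives log(1) = 0 and c < 2 is undefined, so s would not be positive
        raise ValueError("need at least 3 classes for a positive scale")
    return math.sqrt(2) * math.log(num_classes - 1)


s = adacos_scale(10)
print(round(s, 3))  # 3.107
```

For the 10-class setting used in the example further down, this gives a scale of roughly 3.1.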

Install MLL

Simply run:

pip install metric-learning-layers

Example

import torch
import metric_learning_layers as mll


rnd_batch  = torch.randn(32, 128)
rnd_labels = torch.randint(low=0, high=10, size=(32, ))

arcface = mll.ArcFace(in_features=128, 
                      out_features=10, 
                      num_sub_centers=1, 
                      scale=None, # defaults to AdaCos heuristic
                      trainable_scale=False
                      )

af_out = arcface(rnd_batch, rnd_labels)  # ArcFace requires labels (used to apply the margin)
# af_out: torch.Size([32, 10])

adacos = mll.AdaCos(in_features=128, 
                    out_features=10, 
                    num_sub_centers=1 
                    )

ac_out = adacos(rnd_batch)  # AdaCos does not require labels
# ac_out: torch.Size([32, 10])
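To illustrate why ArcFace needs the labels while a plain cosine layer does not, here is a minimal sketch of the additive angular margin from the ArcFace formulation: the target-class logit becomes $s \cdot \cos(\theta + m)$, while all other logits stay at $s \cdot \cos(\theta)$. The function `arcface_logits` and the values `s=30.0` and `m=0.5` are assumptions made for this example; MLL's own implementation and defaults may differ.

```python
import torch
import torch.nn.functional as F


def arcface_logits(cos, labels, s=30.0, m=0.5):
    """cos: (batch, classes) cosine similarities; labels: (batch,) class indices."""
    # Clamp before acos to stay inside its numerically safe domain
    theta = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
    # 1 at the target class of each row, 0 elsewhere
    one_hot = F.one_hot(labels, num_classes=cos.size(1)).to(cos.dtype)
    # Add the margin m to the target-class angle only, then rescale
    return s * torch.cos(theta + m * one_hot)


cos = torch.tensor([[0.8, 0.1], [0.2, 0.6]])
labels = torch.tensor([0, 1])
logits = arcface_logits(cos, labels)  # shape (2, 2)
```

The margin lowers only the target-class logit, so the network has to produce a smaller angle to the correct prototype to reach the same loss.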
