A simple PyTorch package that includes the most common metric learning layers.
MLL - Metric Learning Layers
MLL is a simple PyTorch package that includes the most common metric learning layers. MLL only includes layers that do not depend on negative-sample mining and are therefore drop-in replacements for the final linear layer used in classification problems. All layers aim to increase inter-class variance while minimizing intra-class variance. Moreover, all MLL layers can be used in conjunction with soft targets (e.g. with Mixup).
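Because every layer outputs ordinary class logits, a soft-target loss applies unchanged. The sketch below illustrates this in plain Python so it runs without PyTorch: it mixes two one-hot labels the way Mixup does and computes cross-entropy against the resulting probability vector. The helpers `one_hot`, `mixup_targets` and `soft_cross_entropy` are hypothetical names for illustration, not part of MLL. In PyTorch itself, `torch.nn.functional.cross_entropy` also accepts class-probability targets (PyTorch 1.10 and later).

```python
import math
from typing import List, Sequence


def one_hot(label: int, num_classes: int) -> List[float]:
    """Return a one-hot target vector for `label` (hypothetical helper)."""
    return [1.0 if j == label else 0.0 for j in range(num_classes)]


def mixup_targets(label_a: int, label_b: int, lam: float, num_classes: int) -> List[float]:
    """Mixup-style soft target: lam * onehot(a) + (1 - lam) * onehot(b)."""
    a = one_hot(label_a, num_classes)
    b = one_hot(label_b, num_classes)
    return [lam * ai + (1.0 - lam) * bi for ai, bi in zip(a, b)]


def soft_cross_entropy(logits: Sequence[float], target: Sequence[float]) -> float:
    """Cross-entropy between softmax(logits) and a probability vector `target`."""
    # Log-sum-exp with max subtraction for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    # -sum_j t_j * log softmax(logits)_j
    return -sum(t * (z - log_z) for z, t in zip(logits, target))


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]  # e.g. one row of a layer's (scaled cosine) output
    target = mixup_targets(0, 1, lam=0.7, num_classes=3)
    print(target, soft_cross_entropy(logits, target))
```

Since the loss is linear in the target, the Mixup loss equals `lam` times the loss for label a plus `1 - lam` times the loss for label b, which is why no layer-specific handling is needed.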
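The basis of all these layers is the scaled cosine similarity $$y = s \cdot xW$$ between the $d$-dimensional input vector (features) $x \in \mathbb{R}^{1 \times d}$ and the $c$ class weights (prototypes, embeddings) $W \in \mathbb{R}^{d \times c}$, where $\lVert x \rVert = 1$, $\lVert W_{*,j} \rVert = 1 \;\; \forall j = 1, \dots, c$, and $s \in \mathbb{R}^+$. Because $x$ and every column of $W$ have unit norm, the $j$-th entry of $xW$ is the cosine of the angle between $x$ and the prototype of class $j$, so every entry of $y \in \mathbb{R}^{1 \times c}$ lies in $[-s, s]$.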
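The formula above can be sketched for a single input as follows. This is a dependency-free illustration of $y = s \cdot xW$ with unit-normalized $x$ and unit-normalized columns of $W$, not MLL's implementation; `l2_normalize` and `scaled_cosine_logits` are hypothetical names, and `W` is stored as a list of $d$ rows with $c$ columns.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]  # d rows x c columns; column j is the prototype of class j


def l2_normalize(v: Sequence[float], eps: float = 1e-12) -> Vector:
    """Scale v to unit Euclidean norm (eps avoids division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def scaled_cosine_logits(x: Sequence[float], W: Matrix, s: float) -> Vector:
    """Compute y = s * x_hat W_hat for one 1 x d input x."""
    d, c = len(W), len(W[0])
    x_hat = l2_normalize(x)
    # Normalize each column W[:, j] so that ||W_{*,j}|| = 1.
    cols = [l2_normalize([W[i][j] for i in range(d)]) for j in range(c)]
    # Each dot product is now a cosine similarity in [-1, 1], scaled by s.
    return [s * sum(xi * wi for xi, wi in zip(x_hat, col)) for col in cols]


if __name__ == "__main__":
    x = [3.0, 4.0]
    W = [[1.0, 0.0, -1.0],
         [0.0, 2.0, 0.0]]
    print(scaled_cosine_logits(x, W, s=10.0))
```

Note that the second prototype has length 2 before normalization; only its direction affects the logit.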
Supported Layers
We currently support the following layers:
- NormalizedLinear and ScaledNormalizedLinear
- CosFace
- ArcFace
- AdaCos and FixedAdaCos
- DeepNCM
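As a rough guide to how the margin-based layers differ, the sketch below applies the margins as formulated in the CosFace and ArcFace papers to a vector of cosine similarities: CosFace uses $s(\cos\theta_y - m)$ and ArcFace uses $s\cos(\theta_y + m)$ for the target class $y$, while all other classes keep $s\cos\theta_j$. This also explains why ArcFace needs labels in the example. The function names and the default margins ($m = 0.35$ and $m = 0.5$, values commonly used in those papers) are assumptions; MLL's own defaults and edge-case handling may differ.

```python
import math
from typing import List, Sequence


def cosface_logits(cosines: Sequence[float], label: int, s: float, m: float = 0.35) -> List[float]:
    """CosFace-style: subtract an additive cosine margin m from the target class only."""
    return [s * (c - m) if j == label else s * c for j, c in enumerate(cosines)]


def arcface_logits(cosines: Sequence[float], label: int, s: float, m: float = 0.5) -> List[float]:
    """ArcFace-style: add an additive angular margin m to the target angle only."""
    out = []
    for j, c in enumerate(cosines):
        if j == label:
            # Clamp guards against rounding slightly outside [-1, 1] before acos.
            theta = math.acos(max(-1.0, min(1.0, c)))
            # Simplified: many implementations add a fallback when theta + m > pi.
            out.append(s * math.cos(theta + m))
        else:
            out.append(s * c)
    return out


if __name__ == "__main__":
    cos = [0.8, 0.1, -0.3]
    print(cosface_logits(cos, label=0, s=10.0))
    print(arcface_logits(cos, label=0, s=10.0))
```

In both cases the target logit is made harder to reach during training, which pushes features closer to their class prototype.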
You can use multiple sub-centers for all layers except DeepNCM. If you do not specify a scale, MLL uses the heuristic from AdaCos, $s = \sqrt{2} \cdot \log(c - 1)$, where $c$ is the number of classes.
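A minimal sketch of both ideas in this paragraph follows. The scale uses the natural logarithm, as in the AdaCos paper. For sub-centers it assumes the sub-center ArcFace convention, where each class has $K$ prototypes and the class cosine is the maximum over them; whether MLL pools exactly this way is an assumption. Both function names are hypothetical.

```python
import math
from typing import List, Sequence


def adacos_default_scale(num_classes: int) -> float:
    """Heuristic scale s = sqrt(2) * ln(c - 1)."""
    # This sketch rejects c < 3, where ln(c - 1) <= 0 would give a non-positive scale.
    if num_classes < 3:
        raise ValueError("heuristic needs at least 3 classes")
    return math.sqrt(2.0) * math.log(num_classes - 1)


def pool_sub_centers(sub_cosines: Sequence[Sequence[float]]) -> List[float]:
    """Reduce the K sub-center cosines of each class to one class cosine via max."""
    return [max(per_class) for per_class in sub_cosines]


if __name__ == "__main__":
    print(adacos_default_scale(10))
    # Two classes with K = 3 sub-centers each.
    print(pool_sub_centers([[0.1, 0.7, 0.3], [-0.2, 0.05, 0.0]]))
```

For $c = 10$, as in the example, this gives $s = \sqrt{2} \cdot \ln 9 \approx 3.107$.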
Install MLL
Simply run:
```shell
pip install metric-learning-layers
```
Example
```python
import torch
import metric_learning_layers as mll

rnd_batch = torch.randn(32, 128)
rnd_labels = torch.randint(low=0, high=10, size=(32,))

arcface = mll.ArcFace(in_features=128,
                      out_features=10,
                      num_sub_centers=1,
                      scale=None,  # defaults to AdaCos heuristic
                      trainable_scale=False)
af_out = arcface(rnd_batch, rnd_labels)  # ArcFace requires labels (used to apply the margin)
# af_out: torch.Size([32, 10])

adacos = mll.AdaCos(in_features=128,
                    out_features=10,
                    num_sub_centers=1)
ac_out = adacos(rnd_batch)  # AdaCos does not require labels
# ac_out: torch.Size([32, 10])
```
Download files
Source Distribution
Built Distribution
File details
Details for the file Metric Learning Layers-0.1.6.tar.gz
File metadata
- Download URL: Metric Learning Layers-0.1.6.tar.gz
- Upload date:
- Size: 5.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | e5055427b4670ca3600bfa05f3257d4c5e3a2a89fbd85bf10545c59cf245301a
MD5 | 62cf11487717b9da420e47698bbe6657
BLAKE2b-256 | b149cb8bb937322c6cf78de1ac49437e7094c07fd4cffc0315d22c82847d530f
File details
Details for the file Metric_Learning_Layers-0.1.6-py3-none-any.whl
File metadata
- Download URL: Metric_Learning_Layers-0.1.6-py3-none-any.whl
- Upload date:
- Size: 6.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | d197165a4e4aaf7173a7be8b660dfe94ac5e6be62f436055c43125283ebb4cd0
MD5 | b036d51358e27c3af6c0ab67b225dbbb
BLAKE2b-256 | 8938e4488ec216f01fd0d8afe930d93c2d4cb0e50e79834faf80a43cd04e574a
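To check a manually downloaded file against the SHA256 digests listed above, a small sketch using Python's standard `hashlib` could look like this. The helper names `sha256_of_file` and `verify` are hypothetical, and the file is assumed to sit in the current directory. pip can also enforce digests itself in its hash-checking mode (`--require-hashes`).

```python
import hashlib
from pathlib import Path


def sha256_of_file(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path, expected_hex: str) -> bool:
    """True if the file's SHA256 digest matches the expected hex digest."""
    return sha256_of_file(path) == expected_hex.lower()


if __name__ == "__main__":
    wheel = Path("Metric_Learning_Layers-0.1.6-py3-none-any.whl")  # assumed location
    if wheel.exists():
        print(verify(wheel, "d197165a4e4aaf7173a7be8b660dfe94ac5e6be62f436055c43125283ebb4cd0"))
```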