pytorch_metric_learning

A flexible and extensible metric learning library, written in PyTorch.
See this Google Spreadsheet for benchmark results (in progress), and see powerful_benchmarker to use the benchmarking tool.
Loss functions implemented:
- angular
- contrastive
- lifted structure
- margin
- multi similarity
- n pairs
- nca
- proxy nca
- triplet margin
- more to be added
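For intuition, the triplet margin loss listed above can be sketched in a few lines of NumPy. This is an illustrative reimplementation of the underlying formula, not the library's actual code: the loss is the hinge max(d(a, p) - d(a, n) + margin, 0), which is zero once the negative is at least `margin` farther from the anchor than the positive.

```python
import numpy as np

def triplet_margin_loss(anchor, positive, negative, margin=0.1):
    """Illustrative triplet margin loss: push the negative at least
    `margin` farther from the anchor than the positive (Euclidean)."""
    d_pos = np.linalg.norm(anchor - positive)  # anchor-positive distance
    d_neg = np.linalg.norm(anchor - negative)  # anchor-negative distance
    return max(d_pos - d_neg + margin, 0.0)    # hinge on the margin

anchor   = np.array([0.0, 0.0])
positive = np.array([0.1, 0.0])  # close to the anchor -> small d_pos
negative = np.array([1.0, 0.0])  # far from the anchor -> large d_neg
print(triplet_margin_loss(anchor, positive, negative))  # 0.0: margin satisfied
```

Swapping the positive and negative in this example yields a loss of 1.0, since the "positive" is now 0.9 farther away than the "negative", plus the 0.1 margin.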
Mining functions implemented:
- distance weighted sampling
- hard aware cascaded mining
- maximum loss miner
- multi similarity miner
- pair margin miner
- more to be added
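To illustrate what a miner does, here is a minimal sketch of pair margin mining in NumPy. It is a hypothetical reimplementation (the function name, margins, and return format are assumptions, not the library's API): it keeps only the informative pairs, i.e. same-class pairs that are too far apart and different-class pairs that are too close.

```python
import numpy as np
from itertools import combinations

def pair_margin_mine(embeddings, labels, pos_margin=0.2, neg_margin=0.8):
    """Illustrative pair margin mining: return positive pairs whose
    distance exceeds pos_margin and negative pairs whose distance
    falls below neg_margin -- the pairs that still produce gradient."""
    pos_pairs, neg_pairs = [], []
    for i, j in combinations(range(len(labels)), 2):
        dist = np.linalg.norm(embeddings[i] - embeddings[j])
        if labels[i] == labels[j] and dist > pos_margin:
            pos_pairs.append((i, j))   # same class, but far apart
        elif labels[i] != labels[j] and dist < neg_margin:
            neg_pairs.append((i, j))   # different class, but close
    return pos_pairs, neg_pairs

embeddings = np.array([[0.0, 0.0], [0.5, 0.0], [0.6, 0.0], [5.0, 0.0]])
labels = [0, 0, 1, 1]
print(pair_margin_mine(embeddings, labels))
# ([(0, 1), (2, 3)], [(0, 2), (1, 2)])
```

The well-separated pair (0, 3) is discarded: it is already far apart, so it contributes nothing to a margin-based loss.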
Training methods implemented:
- metric loss only
- training with classifier
- cascaded embeddings
- deep adversarial metric learning
- more to be added
Installation:
```
pip install pytorch_metric_learning
```
Overview
Use a loss function by itself:
```python
from pytorch_metric_learning import losses

loss_func = losses.TripletMarginLoss(normalize_embeddings=False, margin=0.1)
loss = loss_func(embeddings, labels)
```
Or combine miners and loss functions, regardless of whether they mine or compute loss using pairs or triplets. Pairs are converted to triplets when necessary, and vice versa.
```python
from pytorch_metric_learning import miners, losses

miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(normalize_embeddings=False, margin=0.1)
hard_pairs = miner(embeddings, labels)
loss = loss_func(embeddings, labels, hard_pairs)
```
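The pair-to-triplet conversion mentioned above can be sketched as follows. This is a hypothetical helper, not the library's internal routine: each mined positive pair (a, p) is joined with every mined negative pair (a, n) that shares the same anchor, producing (anchor, positive, negative) triplets.

```python
def pairs_to_triplets(pos_pairs, neg_pairs):
    """Illustrative conversion of mined pairs into triplets: combine
    each positive pair (a, p) with every negative pair (a, n) that
    has the same anchor index a."""
    triplets = []
    for a, p in pos_pairs:
        for a2, n in neg_pairs:
            if a == a2:
                triplets.append((a, p, n))
    return triplets

print(pairs_to_triplets([(0, 1)], [(0, 2), (0, 3), (1, 2)]))
# [(0, 1, 2), (0, 1, 3)]
```

The reverse direction (triplets to pairs) simply splits each (a, p, n) into a positive pair (a, p) and a negative pair (a, n).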
Train using more advanced approaches, such as deep adversarial metric learning:
```python
from pytorch_metric_learning import losses, trainers

# Set up your models, optimizers, loss functions, etc.
models = {"trunk": your_trunk_model,
          "embedder": your_embedder_model,
          "generator": your_negative_generator}

optimizers = {"trunk_optimizer": your_trunk_optimizer,
              "embedder_optimizer": your_embedder_optimizer,
              "generator_optimizer": your_negative_generator_optimizer}

loss_funcs = {"metric_loss": losses.AngularNPairs(alpha=35),
              "synth_loss": losses.Angular(alpha=35),
              "g_adv_loss": losses.Angular(alpha=35)}

mining_funcs = {}

loss_weights = {"metric_loss": 1,
                "classifier_loss": 0,
                "synth_loss": 0.1,
                "g_adv_loss": 0.1,
                "g_hard_loss": 0.1,
                "g_reg_loss": 0.1}

# Create the trainer object
trainer = trainers.DeepAdversarialMetricLearning(
    models=models,
    optimizers=optimizers,
    batch_size=120,
    loss_funcs=loss_funcs,
    mining_funcs=mining_funcs,
    num_epochs=50,
    iterations_per_epoch=100,
    dataset=your_dataset,
    loss_weights=loss_weights,
)

trainer.train()
```
Hashes for pytorch_metric_learning-0.9.23.tar.gz:

Algorithm | Hash digest
---|---
SHA256 | 4b2563a2193a8c96a342599e5d3de584fadc25385ec683475f520bf0c44689c5
MD5 | 3b175945d4ebd2212b6a31ac7d66fe00
BLAKE2b-256 | 4d26f505beefd56b50f7c7e43686879f2e1f6c4b883186d13f3f87d0423cb093
Hashes for pytorch_metric_learning-0.9.23-py3-none-any.whl:

Algorithm | Hash digest
---|---
SHA256 | b2cd84c3b47c6bb106b5cbc8c302501bc6185ca6b4a83b168ed91b84718f6872
MD5 | d3a0d0c906f151fb6d0ecd1d95b4c7c0
BLAKE2b-256 | 66dd3928274971231203f7d78c10e5f19097ec8b468794ef94620aacce050e2e