GeDML is an easy-to-use generalized deep metric learning library, which contains state-of-the-art deep metric learning algorithms and auxiliary modules to build end-to-end computer vision systems.
News
- [2021-9-13]: v0.0.1 has been released: a `config.yaml` file is now created in the experiment folder to store the configuration.
- [2021-9-6]: v0.0.0 has been released.
Introduction
GeDML is an easy-to-use generalized deep metric learning library, which contains:
- State-of-the-art DML algorithms: We include 18+ loss functions and 6+ sampling strategies, divided into three categories (i.e., collectors, selectors, and losses).
- Bridge between DML and SSL: We attempt to bridge the gap between deep metric learning and self-supervised learning through specially designed modules, such as `collectors`.
- Auxiliary modules to assist in building: We also provide a high-level interface so users can start programs quickly, and we keep code and configs separate so hyper-parameters are easy to manage.
Installation
Pip
pip install gedml
Quickstart
Please first set the environment variable `WORKSPACE` to indicate where your project is managed, and download `config`, which includes `args.csv`, `assert.yaml`, `links`, `param`, and `wrapper`.
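Before launching, it can help to verify that the workspace is set and the config contains the entries listed above. The sketch below is a hypothetical helper, not part of GeDML; in particular, placing the config under a `config` subfolder of `WORKSPACE` is an assumption.

```python
import os
from pathlib import Path

# Entries the Quickstart says the downloaded config contains.
EXPECTED_CONFIG_ENTRIES = ["args.csv", "assert.yaml", "links", "param", "wrapper"]


def resolve_config_dir(subdir="config"):
    """Return <WORKSPACE>/<subdir>. The 'config' subfolder name is an assumption."""
    workspace = os.environ.get("WORKSPACE")
    if workspace is None:
        raise RuntimeError("Please set the WORKSPACE environment variable first.")
    return Path(workspace) / subdir


def find_missing_config_entries(config_dir):
    """Return the expected entries that do not exist in config_dir (hypothetical helper)."""
    config_dir = Path(config_dir)
    return [name for name in EXPECTED_CONFIG_ENTRIES if not (config_dir / name).exists()]
```

An empty list from `find_missing_config_entries(resolve_config_dir())` means every expected file or folder is present.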
(Demo of convenient and fast switching between DML and SSL)
Setting `launch.json` in VS Code:
"env": {
    "CUDA_VISIBLE_DEVICES": "0"
},
"args": [
    "--device", "0",
    "--delete_old",
    "--batch_size", "180",
    "--test_batch_size", "180",
    "--setting", "margin_loss",
    "--margin_alpha", "1",
    "--margin_beta", "0.5",
    "--lr", "0.00003",
    // "--use_wandb",
]
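The `"args"` list is a flat sequence of flags and values. If you keep experiment settings in a Python dict, a small hypothetical helper (not part of GeDML) can produce the same list; the convention that `True` emits a bare switch like `--delete_old` and `False` drops the flag is an assumption of this sketch.

```python
def to_launch_args(settings):
    """Flatten {flag: value} into a launch.json-style "args" list (hypothetical helper).

    True -> bare flag, False/None -> flag omitted, anything else -> flag followed by str(value).
    """
    args = []
    for flag, value in settings.items():
        if value is None or value is False:
            continue
        args.append(f"--{flag}")
        if value is not True:
            args.append(str(value))
    return args


settings = {
    "device": 0,
    "delete_old": True,
    "batch_size": 180,
    "test_batch_size": 180,
    "setting": "margin_loss",
    "margin_alpha": 1,
    "margin_beta": 0.5,
    "lr": "0.00003",  # kept as a string: str(0.00003) would give "3e-05"
    "use_wandb": False,  # disabled, like the commented-out line above
}
print(to_launch_args(settings))
```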
Initialization
Use `ParserWithConvert` to get parameters:
>>> from gedml.launcher.misc import ParserWithConvert
>>> csv_path = ...
>>> parser = ParserWithConvert(csv_path=csv_path, name="...")
>>> opt, convert_dict = parser.render()
Use `ConfigHandler` to create all objects:
>>> from gedml.launcher.creators import ConfigHandler
>>> link_path = ...
>>> assert_path = ...
>>> param_path = ...
>>> wrapper_path = ...
>>> config_handler = ConfigHandler(
...     convert_dict=convert_dict,
...     link_path=link_path,
...     assert_path=assert_path,
...     params_path=param_path,
...     wrapper_path=wrapper_path,
...     is_confirm_first=True
... )
>>> config_handler.get_params_dict()
>>> objects_dict = config_handler.create_all()
Start
Use `manager` to automatically call `trainer` and `tester`:
>>> from gedml.launcher.misc import utils
>>> manager = utils.get_default(objects_dict, "managers")
>>> manager.run()
Or use `trainer` and `tester` directly:
>>> from gedml.launcher.misc import utils
>>> trainer = utils.get_default(objects_dict, "trainers")
>>> tester = utils.get_default(objects_dict, "testers")
>>> recorder = utils.get_default(objects_dict, "recorders")
>>> # start training
>>> utils.func_params_mediator(
...     [objects_dict],
...     trainer.__call__
... )
>>> # start testing
>>> metrics = utils.func_params_mediator(
...     [
...         {"recorders": recorder},
...         objects_dict,
...     ],
...     tester.__call__
... )
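`func_params_mediator` receives a list of dicts and a callable. One plausible reading, assumed here and not confirmed by the source, is that it fills each parameter of the callable by name from the first dict that has a matching key, so `{"recorders": recorder}` overrides whatever `objects_dict` holds. The sketch below implements that idea with a hypothetical `call_with_params`; GeDML's actual implementation may differ.

```python
import inspect


def call_with_params(sources, func):
    """Call func, filling each parameter by name from the first dict in sources that has it.

    Hypothetical sketch of the idea behind utils.func_params_mediator.
    """
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # *args / **kwargs cannot be filled by name
        for source in sources:
            if name in source:
                kwargs[name] = source[name]
                break
        else:
            if param.default is param.empty:
                raise KeyError(f"no value found for required parameter {name!r}")
    return func(**kwargs)


class DummyTester:
    """Stand-in for a tester; its parameter names are made up for this example."""

    def __call__(self, models, recorders, epochs=1):
        return (models, recorders, epochs)


objects = {"models": "model-A", "recorders": "default-recorder"}
result = call_with_params([{"recorders": "custom-recorder"}, objects], DummyTester().__call__)
```

Here `result` is `("model-A", "custom-recorder", 1)`: the earlier dict wins for `recorders`, and `epochs` falls back to its default.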
Framework
This project is modular in design. The pipeline diagram is as follows:
Code structure
Method
Collectors
method | description |
---|---|
BaseCollector | Base class |
DefaultCollector | Do nothing |
ProxyCollector | Maintain a set of proxies |
MoCoCollector | paper: Momentum Contrast for Unsupervised Visual Representation Learning |
SimSiamCollector | paper: Exploring Simple Siamese Representation Learning |
HDMLCollector | paper: Hardness-Aware Deep Metric Learning |
DAMLCollector | paper: Deep Adversarial Metric Learning |
DVMLCollector | paper: Deep Variational Metric Learning |
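
`MoCoCollector` follows Momentum Contrast, which relies on two mechanisms: a key encoder updated as a moving average of the query encoder, θ_k ← m·θ_k + (1 − m)·θ_q, and a fixed-size FIFO queue of past keys used as negatives. The sketch below shows both on plain lists of floats; it illustrates the MoCo idea and is not `MoCoCollector`'s code, which would operate on tensors.

```python
from collections import deque


def momentum_update(key_params, query_params, m=0.999):
    """MoCo key-encoder update: theta_k <- m * theta_k + (1 - m) * theta_q."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


class KeyQueue:
    """Fixed-size FIFO of past keys: the newest batch is enqueued and,
    once the queue is full, the oldest keys are dropped."""

    def __init__(self, size):
        self.keys = deque(maxlen=size)

    def enqueue(self, batch_keys):
        self.keys.extend(batch_keys)

    def negatives(self):
        return list(self.keys)
```

With a large `m` the key encoder changes slowly, which keeps the keys stored in the queue consistent with one another across batches.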
Losses
classifier-based
method | description |
---|---|
CrossEntropyLoss | Cross entropy loss for unsupervised methods |
LargeMarginSoftmaxLoss | paper: Large-Margin Softmax Loss for Convolutional Neural Networks |
ArcFaceLoss | paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition |
CosFaceLoss | paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition |
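
ArcFace and CosFace both apply a margin to the target-class logit of a scaled cosine softmax: CosFace uses s·(cos θ − m), ArcFace uses s·cos(θ + m), and all other classes keep s·cos θ. The sketch below computes the resulting cross entropy for a single sample from precomputed cosines. It is a simplified illustration, not GeDML's `ArcFaceLoss` or `CosFaceLoss`; for instance, it omits ArcFace's special handling when θ + m exceeds π.

```python
import math


def cosface_logit(cos_theta, s, m):
    # CosFace (additive cosine margin): s * (cos(theta) - m)
    return s * (cos_theta - m)


def arcface_logit(cos_theta, s, m):
    # ArcFace (additive angular margin): s * cos(theta + m); cosine clamped before acos
    theta = math.acos(max(-1.0, min(1.0, cos_theta)))
    return s * math.cos(theta + m)


def margin_softmax_loss(cosines, target, s, m, kind="arcface"):
    """Cross entropy over scaled cosines, with the margin applied to the target class only."""
    margin_fn = arcface_logit if kind == "arcface" else cosface_logit
    logits = [margin_fn(c, s, m) if i == target else s * c for i, c in enumerate(cosines)]
    top = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[target]
```

Because the margin lowers the target logit, the loss with m > 0 is larger than the plain scaled softmax loss, which pushes the model to separate classes by more than the margin.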
pair-based
method | description |
---|---|
ContrastiveLoss | paper: Learning a Similarity Metric Discriminatively, with Application to Face Verification |
MarginLoss | paper: Sampling Matters in Deep Embedding Learning |
TripletLoss | paper: Learning local feature descriptors with triplets and shallow convolutional neural networks |
AngularLoss | paper: Deep Metric Learning with Angular Loss |
CircleLoss | paper: Circle Loss: A Unified Perspective of Pair Similarity Optimization |
FastAPLoss | paper: Deep Metric Learning to Rank |
LiftedStructureLoss | paper: Deep Metric Learning via Lifted Structured Feature Embedding |
MultiSimilarityLoss | paper: Multi-Similarity Loss With General Pair Weighting for Deep Metric Learning |
NPairLoss | paper: Improved Deep Metric Learning with Multi-class N-pair Loss Objective |
SignalToNoiseRatioLoss | paper: Signal-To-Noise Ratio: A Robust Distance Metric for Deep Metric Learning |
PosPairLoss | paper: Exploring Simple Siamese Representation Learning |
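
The margin loss from "Sampling Matters in Deep Embedding Learning" scores each pair as max(0, α + y_ij·(D_ij − β)), where D_ij is the Euclidean distance, y_ij = +1 for same-label pairs and −1 otherwise, α is the margin, and β is the boundary between positives and negatives. The sketch below averages this over all pairs with a fixed β; the paper learns β, implementations differ in how they normalize, and whether `--margin_alpha`/`--margin_beta` map exactly to α/β here is an assumption.

```python
import math
from itertools import combinations


def margin_loss(embeddings, labels, alpha, beta):
    """Mean over all pairs i < j of max(0, alpha + y_ij * (D_ij - beta))."""
    losses = []
    for i, j in combinations(range(len(embeddings)), 2):
        d = math.dist(embeddings[i], embeddings[j])  # Euclidean distance (Python 3.8+)
        y = 1.0 if labels[i] == labels[j] else -1.0
        losses.append(max(0.0, alpha + y * (d - beta)))
    return sum(losses) / len(losses)
```

A positive pair stops contributing once D ≤ β − α, and a negative pair once D ≥ β + α.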
proxy-based
method | description |
---|---|
ProxyLoss | paper: No Fuss Distance Metric Learning Using Proxies |
ProxyAnchorLoss | paper: Proxy Anchor Loss for Deep Metric Learning |
SoftTripleLoss | paper: SoftTriple Loss: Deep Metric Learning Without Triplet Sampling |
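
Proxy-based losses compare each embedding with one learned proxy per class instead of with other samples. For Proxy-NCA ("No Fuss Distance Metric Learning Using Proxies"), the loss for an embedding x with class y is −log( exp(−d(x, p_y)) / Σ_{z≠y} exp(−d(x, p_z)) ). The sketch below uses squared Euclidean distance on unnormalized vectors and fixed proxies; it is an illustration, not GeDML's `ProxyLoss`, where proxies would be trainable parameters.

```python
import math


def proxy_nca_loss(x, proxies, target):
    """-log(exp(-d(x, p_target)) / sum over z != target of exp(-d(x, p_z)))."""
    dists = [sum((a - b) ** 2 for a, b in zip(x, p)) for p in proxies]
    pos = -dists[target]
    negs = [-d for i, d in enumerate(dists) if i != target]
    top = max(negs)  # stable log-sum-exp over the negative proxies
    log_sum_neg = top + math.log(sum(math.exp(n - top) for n in negs))
    return log_sum_neg - pos
```

Because the positive proxy is excluded from the denominator, this loss can be negative; it decreases as x moves toward its own proxy and away from the others.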
Selectors
method | description |
---|---|
BaseSelector | Base class |
DefaultSelector | Do nothing |
DenseTripletSelector | Select all triplets |
DensePairSelector | Select all pairs |
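
Dense selection means every valid combination in the batch is used. Given integer labels, all pairs split into positives (same label) and negatives, and all triplets are (anchor, positive, negative) with a distinct positive of the anchor's class and a negative from another class. The sketch below enumerates indices on plain Python lists; GeDML's selectors presumably work on batched tensors, so treat this as an illustration of the selection rule only.

```python
from itertools import combinations


def dense_pairs(labels):
    """All index pairs (i, j) with i < j, split into positive and negative pairs."""
    positives, negatives = [], []
    for i, j in combinations(range(len(labels)), 2):
        (positives if labels[i] == labels[j] else negatives).append((i, j))
    return positives, negatives


def dense_triplets(labels):
    """All (anchor, positive, negative) index triplets."""
    n = len(labels)
    return [
        (a, p, neg)
        for a in range(n)
        for p in range(n)
        if p != a and labels[p] == labels[a]
        for neg in range(n)
        if labels[neg] != labels[a]
    ]
```

For labels `[0, 0, 1]` this yields one positive pair, two negative pairs, and two triplets; an anchor with no other sample of its class contributes no triplets.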
Document
For more information, please refer to the Docs.
Some specific guides:
Configs
We will continually update the best-performing parameters for different configs on TsinghuaCloud.
Code Reference
- KevinMusgrave / pytorch-metric-learning
- KevinMusgrave / powerful-benchmarker
- Confusezius / Deep-Metric-Learning-Baselines
- facebookresearch / moco
- PatrickHua / SimSiam
- ujjwaltiwari / Deep_Variational_Metric_Learning
- idstcv / SoftTriple
- wzzheng / HDML
- google-research / simclr
- kunhe / FastAP-metric-learning
- wy1iu / LargeMargin_Softmax_Loss
- tjddus9597 / Proxy-Anchor-CVPR2020
- facebookresearch / deit
TODO:
- Add assertions to validate parameters.
- Support both distributed and non-distributed methods.
- Write a GitHub Action to automate unit tests, package publishing, and docs building.
- Add a cross-validation split protocol.
Important TODO:
- Write DML-to-SSL demos.
- Write complete configs (so experiments run easily).