Simple self-supervised contrastive learning based on the ReLIC method
ReLIC
A PyTorch implementation of a computer vision self-supervised learning method based on Representation Learning via Invariant Causal Mechanisms (ReLIC).
This simple approach is very similar to BYOL and SimCLR. Training uses an online encoder and a target encoder (an EMA of the online one) with a simple critic MLP projector, while the instance-discrimination half of the loss resembles the contrastive loss used in SimCLR. The other half of the loss acts as a regularizer: an invariance penalty that forces the representations to stay invariant under data augmentations, reducing intra-class distances.
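The two-part loss described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the package's exact code: `relic_loss` is a hypothetical helper, and the invariance penalty is written here as a KL term between the two views' similarity distributions, with `tau` and `alpha` playing the roles of the softmax temperature and regularization factor from the CLI options.

```python
import torch
import torch.nn.functional as F


def relic_loss(online_proj: torch.Tensor, target_proj: torch.Tensor,
               tau: float = 0.1, alpha: float = 1.0) -> torch.Tensor:
    """ReLIC-style loss sketch for two batches of (N, D) projections.

    Combines a SimCLR-like instance-discrimination term with a KL
    invariance penalty between the two augmented views.
    """
    z1 = F.normalize(online_proj, dim=1)
    z2 = F.normalize(target_proj, dim=1)
    # Cross-view similarity logits, scaled by the temperature tau.
    logits = z1 @ z2.t() / tau
    # Positive pairs sit on the diagonal: sample i in view 1 matches
    # sample i in view 2.
    labels = torch.arange(z1.size(0))
    contrastive = F.cross_entropy(logits, labels)
    # Invariance penalty: the similarity distribution should not depend
    # on which view plays the anchor role.
    log_p1 = F.log_softmax(logits, dim=1)
    p2 = F.softmax(logits.t(), dim=1)
    invariance = F.kl_div(log_p1, p2, reduction="batchmean")
    return contrastive + alpha * invariance
```

When `alpha` is 0 this reduces to a plain SimCLR-style cross-view contrastive loss; the invariance term only adds pressure for the two views to induce the same similarity structure.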
Results
First, we learn features using ReLIC on the STL10 unsupervised set. Then, we train a linear classifier on top of the frozen features. The linear model is trained on features extracted from the STL10 train set and evaluated on the STL10 test set.
Models are first trained on the training subsets: 50,000 images for CIFAR10 and 100,000 images for STL10. For evaluation, I trained and tested LogisticRegression on:

- CIFAR10: trained on 50,000 train images, tested on 10,000 test images.
- STL10: features were learned on 100k unlabeled images; LogReg was trained on 5k train images and evaluated on 8k test images.
Linear probing was evaluated on features extracted from the encoders using the scikit-learn LogisticRegression model.
More detailed evaluation steps and results for CIFAR10 and STL10 can be found in the notebooks directory.
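The linear-probe evaluation above amounts to fitting scikit-learn's LogisticRegression on frozen encoder features. A minimal sketch (the `linear_probe` helper and its arguments are illustrative, not the notebooks' exact code):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def linear_probe(train_feats: np.ndarray, train_labels: np.ndarray,
                 test_feats: np.ndarray, test_labels: np.ndarray) -> float:
    """Fit a logistic-regression probe on frozen features and return
    top-1 accuracy on the held-out split."""
    clf = LogisticRegression(max_iter=1000)
    clf.fit(train_feats, train_labels)
    return clf.score(test_feats, test_labels)
```

In practice `train_feats` / `test_feats` would come from running the frozen encoder over the train and test splits; the encoder itself is never updated during probing.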
Evaluation model | Dataset | Feature Extractor | Encoder | Feature dim | Projection Head dim | Epochs | Top1 %
---|---|---|---|---|---|---|---
LogisticRegression | CIFAR10 | ReLIC | ResNet-18 | 512 | 64 | 100 | 71.07
LogisticRegression | STL10 | ReLIC | ResNet-18 | 512 | 64 | 100 | 76.10
LogisticRegression | STL10 | ReLIC | ResNet-50 | 2048 | 64 | 100 | 80.40
Usage
Installation
$ pip install relic-pytorch
Code currently supports ResNet18, ResNet50 and an experimental version of the EfficientNet model. Supported datasets are STL10 and CIFAR10.
All training is done from scratch.
Examples
CIFAR10
ResNet-18 model was trained with this command:
relic_train --dataset_name "cifar10" --encoder_model_name resnet18 --fp16_precision --tau 5 --gamma 0.99 --alpha 1.0
STL10
ResNet-50 model was trained with this command:
relic_train --dataset_name "stl10" --encoder_model_name resnet50 --fp16_precision
Detailed options
Once the code is set up, run the following command with the options listed below:
relic_train [args...]
ReLIC
options:
-h, --help show this help message and exit
--dataset_path DATASET_PATH
Path where datasets will be saved
--dataset_name {stl10,cifar10}
Dataset name
-m {resnet18,resnet50,efficientnet}, --encoder_model_name {resnet18,resnet50,efficientnet}
model architecture: resnet18, resnet50 or efficientnet (default: resnet18)
--save_model_dir SAVE_MODEL_DIR
                      Path where models will be saved
--num_epochs NUM_EPOCHS
Number of epochs for training
-b BATCH_SIZE, --batch_size BATCH_SIZE
Batch size
-lr LEARNING_RATE, --learning_rate LEARNING_RATE
-wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
--fp16_precision Whether to use 16-bit precision GPU training.
--proj_out_dim PROJ_OUT_DIM
Projector MLP out dimension
--proj_hidden_dim PROJ_HIDDEN_DIM
Projector MLP hidden dimension
--log_every_n_steps LOG_EVERY_N_STEPS
Log every n steps
--gamma GAMMA Initial EMA coefficient
--tau TAU Softmax temperature
--alpha ALPHA Regularization loss factor
--update_gamma_after_step UPDATE_GAMMA_AFTER_STEP
Update EMA gamma after this step
--update_gamma_every_n_steps UPDATE_GAMMA_EVERY_N_STEPS
Update EMA gamma after this many steps
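The `--gamma` options above control how the target encoder tracks the online encoder. As a rough illustration of the EMA step itself (`update_target` is a hypothetical helper, not the package's code):

```python
import copy

import torch


@torch.no_grad()
def update_target(online: torch.nn.Module, target: torch.nn.Module,
                  gamma: float) -> None:
    """EMA update of the target encoder:
    target_param <- gamma * target_param + (1 - gamma) * online_param."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(gamma).add_(p_o, alpha=1.0 - gamma)
```

With `gamma` close to 1 the target moves slowly and provides stable regression targets; `--update_gamma_after_step` and `--update_gamma_every_n_steps` then control when and how often the coefficient itself is adjusted during training.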
Citation
@misc{mitrovic2020representation,
title={Representation Learning via Invariant Causal Mechanisms},
author={Jovana Mitrovic and Brian McWilliams and Jacob Walker and Lars Buesing and Charles Blundell},
year={2020},
eprint={2010.07922},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
File details
Details for the file relic-pytorch-0.1.2.tar.gz.
File metadata
- Download URL: relic-pytorch-0.1.2.tar.gz
- Upload date:
- Size: 11.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 68b4ceb60960b5891dbe7de58b9ca188d36ec5ce3db93bc9c31ef338b00caa18
MD5 | 95d16e1f55dbb053344b7bbc8b18a0d8
BLAKE2b-256 | 2dcc4aebcba887b13d94e3fbf94ba392580f67b16a90c03d6a7b3baa8a907ebb
File details
Details for the file relic_pytorch-0.1.2-py3-none-any.whl.
File metadata
- Download URL: relic_pytorch-0.1.2-py3-none-any.whl
- Upload date:
- Size: 11.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | b365c17bbe71fcf863836d3af62e3a66e3f6439dc899c5918c201e0968a32b6f
MD5 | f5c463ffa7268bbd2d79c5c60cfdf3c1
BLAKE2b-256 | a5edae0eb3cba80c0956e56d3ab0a3205b6504a7fb51e5ae7a6bdfb5d039937a