Self Supervised Learning Algorithms with Fastai

Self Supervised Learning with Fastai

Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.

Install

pip install self-supervised

Documentation

Please read the documentation here.

To go back to the GitHub repo, please click here.

Algorithms

Please read the papers or blog posts before getting started with an algorithm. You may also check the documentation page of each algorithm to get a better understanding.

Here is the list of implemented self_supervised.vision algorithms:

  • SimCLR
  • MoCo
  • BYOL
  • SwAV
  • Barlow Twins
  • DINO

Here is the list of implemented self_supervised.multimodal algorithms:

  • CLIP
  • CLIP-MoCo (No paper, own idea)

For vision algorithms, all models from timm and fastai can be used as encoders.

For multimodal training, CLIP currently supports ViT-B/32 and ViT-L/14, following the best architectures from the paper.

Simple Usage

Vision

SimCLR

from self_supervised.vision.simclr import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_simclr_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_simclr_aug_pipelines(size=size)
learn = Learner(dls,model,cbs=[SimCLR(aug_pipelines, temp=0.07)])
learn.fit_flat_cos(100, 1e-2)
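
For intuition about the temp=0.07 argument: in the SimCLR paper it is the softmax temperature of the contrastive (NT-Xent) loss. Below is a minimal pure-Python sketch of that loss, assuming two lists of projection vectors where z1[i] and z2[i] come from two views of the same image. The names nt_xent and l2_normalize are hypothetical and not part of self_supervised; the library's own implementation may differ.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length so dot products are cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def nt_xent(z1, z2, temp=0.07):
    """NT-Xent loss sketch: z1[i] and z2[i] are two views of image i."""
    z = [l2_normalize(v) for v in z1 + z2]
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of the same image
        # Similarities to every other sample (self-similarity is excluded).
        logits = [sum(a * b for a, b in zip(z[i], z[j])) / temp
                  for j in range(2 * n) if j != i]
        pos_logit = sum(a * b for a, b in zip(z[i], z[pos])) / temp
        # Numerically stable log-sum-exp over all candidates.
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denominator - pos_logit
    return total / (2 * n)
```

A lower temperature sharpens the softmax, so hard negatives dominate the gradient; matching pairs give a loss near zero while mismatched pairs are penalized heavily.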

MoCo

from self_supervised.vision.moco import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_moco_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_moco_aug_pipelines(size=size)
learn = Learner(dls, model,cbs=[MOCO(aug_pipelines=aug_pipelines, K=128)])
learn.fit_flat_cos(100, 1e-2)
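
As a rough illustration of the two MoCo mechanisms from the paper, the sketch below shows a momentum (moving-average) update of the key encoder and a fixed-size FIFO queue of negative keys. It assumes K above is the queue size, as in the MoCo paper. KeyQueue and momentum_update are hypothetical names for illustration, and parameters are plain lists of floats rather than tensors.

```python
from collections import deque

def momentum_update(query_params, key_params, m=0.999):
    """Move each key-encoder parameter slowly toward the query encoder."""
    return [m * k + (1.0 - m) * q for q, k in zip(query_params, key_params)]

class KeyQueue:
    """FIFO queue holding at most K keys; the oldest keys are dropped first."""

    def __init__(self, K):
        self.keys = deque(maxlen=K)

    def enqueue(self, batch_keys):
        # deque(maxlen=K) discards items from the left once K is exceeded.
        self.keys.extend(batch_keys)

    def __len__(self):
        return len(self.keys)
```

With a momentum close to 1 the key encoder changes slowly, which keeps the keys stored in the queue consistent with each other across training steps.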

BYOL

from self_supervised.vision.byol import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_byol_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_byol_aug_pipelines(size=size)
learn = Learner(dls, model,cbs=[BYOL(aug_pipelines=aug_pipelines)])
learn.fit_flat_cos(100, 1e-2)
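
BYOL uses no negative pairs; in the paper the online network's prediction is regressed onto the target network's projection, and the loss reduces to 2 - 2 * cosine similarity, summed over both view orderings. A minimal sketch of that formula follows. The function names are hypothetical and the inputs are plain lists of floats; this is not the library's implementation.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def byol_term(online_pred, target_proj):
    """One direction of the BYOL loss: 0 when aligned, 4 when opposite."""
    return 2.0 - 2.0 * cosine(online_pred, target_proj)

def byol_loss(pred1, proj2, pred2, proj1):
    """Symmetrized loss: view 1 predicts view 2 and vice versa."""
    return byol_term(pred1, proj2) + byol_term(pred2, proj1)
```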

SWAV

from self_supervised.vision.swav import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_swav_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[128,96], 
                                       min_scales=[0.25,0.05],
                                       max_scales=[1.0,0.3])
learn = Learner(dls, model, cbs=[SWAV(aug_pipelines=aug_pipelines, crop_assgn_ids=[0,1], K=bs*2**6, queue_start_pct=0.5)])
learn.fit_flat_cos(100, 1e-2)
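
To make the multi-crop arguments above easier to read, the sketch below expands them into one specification per crop. It assumes the four lists are aligned by position (2 crops at size 128 with scale range 0.25 to 1.0, then 6 crops at size 96 with scale range 0.05 to 0.3), which is how multi-crop is described in the SwAV paper. expand_multicrop is a hypothetical helper, not a library function.

```python
def expand_multicrop(num_crops, crop_sizes, min_scales, max_scales):
    """Return one dict per crop, in the order the crop groups are listed."""
    specs = []
    for n, size, lo, hi in zip(num_crops, crop_sizes, min_scales, max_scales):
        for _ in range(n):
            specs.append({"size": size, "min_scale": lo, "max_scale": hi})
    return specs

specs = expand_multicrop(num_crops=[2, 6],
                         crop_sizes=[128, 96],
                         min_scales=[0.25, 0.05],
                         max_scales=[1.0, 0.3])

# Queue size arithmetic from the example above, for a hypothetical bs of 64.
bs = 64
K = bs * 2**6
```

Under this reading, crop_assgn_ids=[0,1] would point at the two large crops, and a batch size of 64 gives a queue of 4096 entries.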

Barlow Twins

from self_supervised.vision.barlow_twins import *
dls = get_dls(resize, bs)
# encoder = create_encoder("xresnet34", n_in=3, pretrained=False) # a fastai encoder
encoder = create_encoder("tf_efficientnet_b4_ns", n_in=3, pretrained=False) # a timm encoder
model = create_barlow_twins_model(encoder, hidden_size=2048, projection_size=128)
aug_pipelines = get_barlow_twins_aug_pipelines(size=size)
learn = Learner(dls,model,cbs=[BarlowTwins(aug_pipelines, lmb=5e-3)])
learn.fit_flat_cos(100, 1e-2)
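
The lmb=5e-3 argument corresponds to the weight on the redundancy-reduction term in the Barlow Twins paper. The sketch below computes that objective in pure Python: each embedding dimension is standardized over the batch, the cross-correlation matrix between the two views is formed, diagonal entries are pushed toward 1 and off-diagonal entries toward 0. It assumes every dimension has non-zero variance in the batch. barlow_twins_loss is a hypothetical name, not the library's function.

```python
import math

def standardized_columns(z):
    """Return each embedding dimension as a zero-mean, unit-variance column."""
    n = len(z)
    cols = []
    for j in range(len(z[0])):
        col = [row[j] for row in z]
        mean = sum(col) / n
        std = math.sqrt(sum((x - mean) ** 2 for x in col) / n)
        cols.append([(x - mean) / std for x in col])
    return cols

def barlow_twins_loss(z1, z2, lmb=5e-3):
    """Invariance term on the diagonal plus lmb-weighted redundancy term."""
    n = len(z1)
    a = standardized_columns(z1)
    b = standardized_columns(z2)
    loss = 0.0
    for i, col_a in enumerate(a):
        for j, col_b in enumerate(b):
            c = sum(x * y for x, y in zip(col_a, col_b)) / n  # entry C[i][j]
            loss += (1.0 - c) ** 2 if i == j else lmb * c ** 2
    return loss
```

A small lmb keeps the off-diagonal penalty from overwhelming the invariance term, since there are many more off-diagonal entries than diagonal ones.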

DINO

from self_supervised.models.vision_transformer import *
from self_supervised.vision.dino import *
dls = get_dls(resize, bs)

deits16 = MultiCropWrapper(deit_small(patch_size=16, drop_path_rate=0.1))
dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
student_model = nn.Sequential(deits16,dino_head)

deits16 = MultiCropWrapper(deit_small(patch_size=16))
dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
teacher_model = nn.Sequential(deits16,dino_head)

dino_model = DINOModel(student_model, teacher_model)
aug_pipelines = get_dino_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[128,96], 
                                       min_scales=[0.25,0.05],
                                       max_scales=[1.0,0.3])
learn = Learner(dls, dino_model, cbs=[DINO(aug_pipelines=aug_pipelines)])
learn.fit_flat_cos(100, 1e-2)
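
In the DINO paper the student is trained to match the teacher's output distribution, where the teacher output is centered and sharpened with a low temperature to avoid collapse. The sketch below illustrates that loss and the moving-average center update for a single pair of views. The temperatures (0.1 and 0.04) and the center momentum (0.9) are values from the paper, not necessarily this library's defaults, and all function names are hypothetical.

```python
import math

def softmax(logits, temp):
    """Temperature-scaled softmax with max-subtraction for stability."""
    scaled = [x / temp for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits, temp):
    """Temperature-scaled log-softmax computed via log-sum-exp."""
    scaled = [x / temp for x in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]

def dino_loss(student_logits, teacher_logits, center,
              student_temp=0.1, teacher_temp=0.04):
    """Cross-entropy between centered, sharpened teacher and student outputs."""
    centered = [t - c for t, c in zip(teacher_logits, center)]
    teacher_probs = softmax(centered, teacher_temp)
    student_log_probs = log_softmax(student_logits, student_temp)
    return -sum(p * lp for p, lp in zip(teacher_probs, student_log_probs))

def update_center(center, teacher_batch, momentum=0.9):
    """Exponential moving average of the teacher's mean output over a batch."""
    batch_mean = [sum(col) / len(teacher_batch) for col in zip(*teacher_batch)]
    return [momentum * c + (1.0 - momentum) * m
            for c, m in zip(center, batch_mean)]
```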

Multimodal

CLIP

from self_supervised.multimodal.clip import *
dls = get_dls(...)
clip_tokenizer = ClipTokenizer()
vitb32_config_dict = vitb32_config(224, clip_tokenizer.context_length, clip_tokenizer.vocab_size)
clip_model = CLIP(**vitb32_config_dict, checkpoint=False, checkpoint_nchunks=0)
learn = Learner(dls, clip_model, loss_func=noop, cbs=[CLIPTrainer()])
learn.fit_flat_cos(100, 1e-2)

CLIP-MoCo

from self_supervised.multimodal.clip_moco import *
dls = get_dls(...)
clip_tokenizer = ClipTokenizer()
vitb32_config_dict = vitb32_config(224, clip_tokenizer.context_length, clip_tokenizer.vocab_size)
clip_model = CLIPMOCO(K=4096,m=0.999, **vitb32_config_dict, checkpoint=False, checkpoint_nchunks=0)
learn = Learner(dls, clip_model, loss_func=noop, cbs=[CLIPMOCOTrainer()])
learn.fit_flat_cos(100, 1e-2)

ImageWang Benchmarks

All of the algorithms implemented in this library have been evaluated on the ImageWang Leaderboard.

Overall, the algorithms rank as SwAV > MoCo > BYOL > SimCLR in most of the benchmarks. For details, you may inspect the history of the ImageWang Leaderboard on GitHub.

BarlowTwins is still under testing on ImageWang.

Note that during these experiments no hyperparameter selection or tuning was done beyond using learn.lr_find() and sanity-checking the data augmentations by visualizing batches. There is therefore still room for improvement, and the overall ranking of the algorithms may change based on your setup. Still, the overall ranking is consistent with the papers.

Contributing

Contributions and/or requests for new self-supervised algorithms are welcome. This repo will try to keep itself up to date with recent SOTA self-supervised algorithms.

Before raising a PR, please create a new branch named <self-supervised-algorithm>. You may refer to previous notebooks before implementing your Callback.

Please refer to the Developers Guide, Abbreviations Guide, and Style Guide sections of https://docs.fast.ai/dev-setup, and note that the same rules apply to this library.
