JAX Image Models
Future home of JAX Image Models (jimm). Sibling library of https://github.com/rwightman/pytorch-image-models. Like timm, jimm will become a collection of JAX-based image models w/ pretrained weights focused on transfer learning. The first models to be included here will be the Flax Linen JAX adaptations of the MBConv family (EfficientNet, MobileNetV2/V3, etc.) from https://github.com/rwightman/efficientnet-jax.
jimm will be built while exploring transfer learning and the impact of different augmentation, regularization, and supervised, semi-supervised, and self-supervised pretraining techniques on the transferability of weights to different target tasks.
Specifically, I hope to compare transfer learning on a wide variety of models for a variety of target datasets and tasks across these pretraining setups:
- supervised ImageNet-1k training with heavy augmentation and regularization
- supervised ImageNet-1k training with light augmentation and regularization
- supervised training on a 'larger' dataset (ImageNet-21k, OpenImages) w/ light augmentation and regularization (I may try heavy, since this isn't JFT-300M)
- semi-supervised and self-supervised pretraining (SimCLRv2, BYOL, FixMatch, etc.) on ImageNet-1k/21k and OpenImages
There is already a body of research on transfer learning. Much of my work here will not break new ground; rather, it gives me an opportunity to learn and do what I do best: refine and improve. Some papers in this space:
- A Large-scale Study of Representation Learning with the Visual Task Adaptation Benchmark - https://arxiv.org/abs/1910.04867
- Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
- On Robustness and Transferability of Convolutional Neural Networks - https://arxiv.org/abs/2007.08558
- Which Model to Transfer? Finding the Needle in the Growing Haystack - https://arxiv.org/abs/2010.06402
- Self-supervised Pre-training with Hard Examples Improves Visual Representations - https://arxiv.org/abs/2012.13493
- Do Adversarially Robust ImageNet Models Transfer Better? - https://arxiv.org/abs/2007.08489
- How Useful is Self-Supervised Pretraining for Visual Tasks? - https://arxiv.org/abs/2003.14323
Please file an issue if you have ideas for additional papers.
Papers on self- or semi-supervised techniques that I plan to explore:
- Bootstrap your own latent: A new approach to self-supervised Learning - https://arxiv.org/abs/2006.07733
- Big Self-Supervised Models are Strong Semi-Supervised Learners - https://arxiv.org/abs/2006.10029
- FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence - https://arxiv.org/abs/2001.07685
- Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
Please file an issue if you have ideas for additional papers.
The scope of 'transfer learning' will initially cover fine-tuning (the original head replaced with a newly initialized, task-specific head, and 0..N layers frozen). I may explore linear classifier or low-shot techniques from semi/self-supervised pretraining later.
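The fine-tuning setup described above (replace the head, freeze 0..N layers) can be sketched in plain JAX. This is a minimal illustration, not jimm code: a toy two-layer MLP (hypothetical layer names `stem`, `block1`, `head`) stands in for a pretrained backbone, and freezing is done by multiplying gradients by a 0/1 mask before a plain SGD update. In practice an optimizer library such as Optax (e.g. `optax.multi_transform` with `optax.set_to_zero`) is a common way to express the same idea.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value, not a recommendation


def init_dense(key, in_dim, out_dim):
    """Initialize one dense layer as a dict pytree (hypothetical toy layer)."""
    return {"w": jax.random.normal(key, (in_dim, out_dim)) * 0.01,
            "b": jnp.zeros((out_dim,))}


def forward(params, x):
    """Toy backbone ('stem' -> 'block1') followed by a linear classifier 'head'."""
    for name in ("stem", "block1"):
        x = jax.nn.relu(x @ params[name]["w"] + params[name]["b"])
    return x @ params["head"]["w"] + params["head"]["b"]


def replace_head(params, key, num_classes):
    """Swap the pretrained classifier for a newly initialized task-specific head."""
    feat_dim = params["head"]["w"].shape[0]
    new_params = dict(params)  # shallow copy; backbone arrays are shared, not mutated
    new_params["head"] = init_dense(key, feat_dim, num_classes)
    return new_params


def freeze_mask(params, trainable):
    """Pytree matching params: 1.0 for layers in `trainable`, 0.0 for frozen ones."""
    return {name: jax.tree_util.tree_map(lambda _: 1.0 if name in trainable else 0.0, sub)
            for name, sub in params.items()}


def loss_fn(params, x, y):
    """Softmax cross-entropy with integer class labels."""
    log_probs = jax.nn.log_softmax(forward(params, x))
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))


@jax.jit
def sgd_step(params, mask, x, y):
    """One plain SGD step; masked (frozen) leaves receive a zero update."""
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g, m: p - LEARNING_RATE * m * g,
                                  params, grads, mask)
```

For example, starting from a 1000-class "pretrained" toy model, `replace_head(pretrained, key, num_classes=5)` followed by `freeze_mask(params, trainable={"block1", "head"})` freezes only the stem (N=1), so `sgd_step` updates `block1` and the new head while leaving `stem` untouched.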
I'm currently planning which datasets to select for transfer learning benchmarks. I'm hoping to explore a cross section of natural image datasets (that don't overlap with ImageNet) and other domains (medical, spectrogram, industrial inspection, etc.). I'd like to cover a different cross section of datasets than usual, and I hope to find some interesting options from various Kaggle challenges, etc. that are available with compatible licenses. Dataset suggestions are welcome.
The development of jimm does not mean I'm abandoning timm. I will build the models in a manner that allows easy movement of weights back and forth. I'm interested in building in JAX because a) I've enjoyed my JAX exploration so far, and b) I have some TPU credits that allow more compute-intensive exploration than my open source training budget would otherwise allow. I will be augmenting that with my local NVIDIA GPU resources.
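Moving weights between timm (PyTorch) and a Flax Linen port mostly comes down to tensor layouts. The sketch below is an assumption-laden illustration, not code from timm, jimm, or efficientnet-jax, and the function names are hypothetical. It relies on PyTorch storing `Conv2d` weights as (out, in, kH, kW) and `Linear` weights as (out, in), while Flax `nn.Conv` kernels are (kH, kW, in // feature_group_count, out) and `nn.Dense` kernels are (in, out). Depthwise weights (C, 1, kH, kW) therefore map to (kH, kW, 1, C), which fits `nn.Conv` with `feature_group_count=C`. Parameter naming (e.g. BatchNorm `weight`/`running_mean` vs Flax `scale`/`mean`) and padding differences are not handled here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def torch_to_flax(w: np.ndarray) -> np.ndarray:
    """Convert one PyTorch weight tensor to the layout Flax Linen expects."""
    if w.ndim == 4:  # Conv2d: (out, in, kH, kW) -> nn.Conv kernel (kH, kW, in, out)
        return np.transpose(w, (2, 3, 1, 0))
    if w.ndim == 2:  # Linear: (out, in) -> nn.Dense kernel (in, out)
        return w.T
    return w  # 1-D tensors (biases, norm scale/shift) share the same layout


def flax_to_torch(k: np.ndarray) -> np.ndarray:
    """Inverse of torch_to_flax, for moving weights back to PyTorch."""
    if k.ndim == 4:  # (kH, kW, in, out) -> (out, in, kH, kW)
        return np.transpose(k, (3, 2, 0, 1))
    if k.ndim == 2:  # (in, out) -> (out, in)
        return k.T
    return k


# Naive 'VALID' cross-correlations, used only to sanity-check the conversion.
def conv_nchw(x: np.ndarray, w_oihw: np.ndarray) -> np.ndarray:
    """PyTorch-style layouts: NCHW input, (out, in, kH, kW) weight."""
    kh, kw = w_oihw.shape[2:]
    win = sliding_window_view(x, (kh, kw), axis=(2, 3))  # (N, C, H', W', kh, kw)
    return np.einsum("nchwij,ocij->nohw", win, w_oihw)


def conv_nhwc(x: np.ndarray, k_hwio: np.ndarray) -> np.ndarray:
    """Flax-style layouts: NHWC input, (kH, kW, in, out) kernel."""
    kh, kw = k_hwio.shape[:2]
    win = sliding_window_view(x, (kh, kw), axis=(1, 2))  # (N, H', W', C, kh, kw)
    return np.einsum("nhwcij,ijco->nhwo", win, k_hwio)
```

A non-square kernel (e.g. 3x5) is a good check: converting with `torch_to_flax`, running both reference convolutions, and transposing the NCHW result to NHWC should give matching outputs, while a wrong axis permutation would either fail on shapes or produce different values.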