# JAX Image Models

Future home of JAX Image Models (jimm), a sibling library of https://github.com/rwightman/pytorch-image-models. Like timm, jimm will become a collection of JAX-based image models w/ pretrained weights, focused on transfer learning. The first models to be included here will be the Flax Linen JAX adaptation of the MBConv family (EfficientNet, MobileNetV2/V3, etc.) from https://github.com/rwightman/efficientnet-jax.

jimm will be built while exploring transfer learning and the impact of different augmentation, regularization, and pretraining (supervised, semi-supervised, self-supervised) techniques on the transferability of weights for different target tasks.

Specifically I hope to compare transfer learning on a wide variety of models for a variety of target datasets and tasks across:

• supervised ImageNet-1k training with heavy augmentation and regularization
• supervised ImageNet-1k training with light augmentation and regularization
• supervised 'larger' dataset (ImageNet-21k, OpenImages) training w/ light (may try heavy since this isn't JFT-300M) augmentation and regularization
• semi-supervised and self-supervised pretraining (SimCLRV2, BYOL, FixMatch, etc) on ImageNet-1k/21k, OpenImages

There is already a body of research work on the subject of transfer learning. Much of my work here will not be breaking new ground but providing me with an opportunity to learn and do what I do best -- refine and improve. Some papers in this space:

Papers on self or semi-supervised techniques that I plan to explore

The scope of 'transfer learning' will initially cover fine-tuning (head replaced with newly initialized task-specific head, 0..N layers frozen). I may explore linear classifier or low-shot techniques from semi/self-supervised pretraining later.
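As a rough illustration of the fine-tuning setup described above, here is a minimal sketch in plain Python. It assumes the parameters are a nested dict keyed by top-level module name, as Flax Linen parameter trees are; the module names (`stem`, `block1`, `head`) and the helper `label_for_finetune` are hypothetical and not part of jimm. It swaps in a fresh head and labels every parameter leaf as `frozen` or `trainable`.

```python
from typing import Any, Dict


def label_for_finetune(params: Dict[str, Any], num_frozen: int,
                       head_name: str = "head") -> Dict[str, Any]:
    """Return a tree shaped like `params` whose leaves are 'frozen' or 'trainable'.

    The first `num_frozen` top-level modules (in dict order, head excluded)
    are frozen; everything else, including the head, stays trainable.
    """
    backbone = [name for name in params if name != head_name]
    frozen = set(backbone[:num_frozen])

    def walk(tree: Any, top: str) -> Any:
        if isinstance(tree, dict):
            return {k: walk(v, top) for k, v in tree.items()}
        # Anything that is not a dict is treated as a parameter leaf.
        return "frozen" if top in frozen else "trainable"

    return {name: walk(sub, name) for name, sub in params.items()}


# Hypothetical pretrained parameter tree (lists stand in for arrays).
pretrained = {
    "stem": {"kernel": [0.1], "bias": [0.0]},
    "block1": {"kernel": [0.2]},
    "block2": {"kernel": [0.3]},
    "head": {"kernel": [0.4], "bias": [0.0]},
}

# Replace the head with a newly initialized, task-specific one.
new_head = {"kernel": [0.0], "bias": [0.0]}
params = {**pretrained, "head": new_head}

# Freeze the first two backbone modules; train the rest and the new head.
labels = label_for_finetune(params, num_frozen=2)
```

A label tree like this is the kind of input that `optax.multi_transform` accepts as `param_labels`, with `optax.set_to_zero()` mapped to the frozen group. Note that the sketch relies on dict order matching network depth, which is an assumption; a real parameter tree may be ordered by name instead.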

I'm currently planning which datasets to select for transfer learning benchmarks. I'm hoping to explore a cross section of natural image datasets (that don't overlap with ImageNet) and datasets from other domains (medical, spectrogram, industrial inspection, etc.). I'd like to cover a different cross section of datasets than usual; I hope to find some interesting options from various Kaggle challenges, etc. that are available with compatible licenses. Dataset suggestions welcome.

The development of jimm does not mean I'm abandoning timm. I will build the models in a manner that allows easy movement of weights back and forth. I'm interested in building in JAX because a) I've enjoyed my JAX exploration so far and b) I have some TPU credits that allow more compute-intensive exploration than my open source training budget would allow. I will be augmenting that with my local NVIDIA GPU resources.
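To give a feel for what moving weights back and forth involves, here is a small sketch of the layout conversion for convolution and dense weights. It assumes the default layouts of `torch.nn.Conv2d` (out, in/groups, kH, kW) and `flax.linen.Conv` (kH, kW, in/groups, out). The helper names are hypothetical, and this is not jimm's actual conversion code.

```python
import numpy as np


def conv_torch_to_flax(weight: np.ndarray) -> np.ndarray:
    """(out, in/groups, kH, kW) -> (kH, kW, in/groups, out)."""
    return np.transpose(weight, (2, 3, 1, 0))


def conv_flax_to_torch(kernel: np.ndarray) -> np.ndarray:
    """(kH, kW, in/groups, out) -> (out, in/groups, kH, kW)."""
    return np.transpose(kernel, (3, 2, 0, 1))


def dense_torch_to_flax(weight: np.ndarray) -> np.ndarray:
    """Linear weight (out, in) -> Dense kernel (in, out)."""
    return np.transpose(weight)


# A 3x3 conv with 16 input and 32 output channels, PyTorch layout.
torch_w = np.arange(32 * 16 * 3 * 3, dtype=np.float32).reshape(32, 16, 3, 3)
flax_k = conv_torch_to_flax(torch_w)

# Converting back should recover the original array exactly.
round_trip = conv_flax_to_torch(flax_k)

# A depthwise 3x3 conv over 8 channels (groups == channels) uses the same transpose.
depthwise_k = conv_torch_to_flax(np.zeros((8, 1, 3, 3), dtype=np.float32))
```

The resulting NumPy arrays can be passed to `jax.numpy.asarray` as-is. Other layers, such as batch norm statistics, need name remapping rather than transposes and are left out of this sketch.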
