image-classification-jax
Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.
Meant to be simple but high quality. Includes:
- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311); see the sketch below
- Ability to use schedule-free from optax.contrib
- Ability to use PSGD optimizers from psgd-jax, with optional Hessian calculation
- Datasets currently implemented: cifar10, cifar100, imagenette, and imagenet
There is currently no model sharding, only data parallelism (the batch is automatically split across devices as batch_size/n_devices).
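The z-loss is an auxiliary term that penalizes the log of the softmax normalizer so logits don't drift to large magnitudes; in the training script it is enabled with apply_z_loss=True. A minimal sketch of the idea (the function name and coefficient here are illustrative, not the repo's internals):

import jax
import jax.numpy as jnp

def z_loss(logits, coeff=1e-4):
    # log Z: log of the softmax normalizer for each example
    log_z = jax.nn.logsumexp(logits, axis=-1)
    # PaLM-style auxiliary loss: coeff * (log Z)^2, averaged over the batch
    return coeff * jnp.mean(jnp.square(log_z))

This term is simply added to the usual cross-entropy loss during training.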
Installation
pip install image-classification-jax
Usage
Set your wandb key either in your Python script or through the command line:
export WANDB_API_KEY=<your_key>
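Or set it from inside the script before the run starts (a minimal sketch; setting the environment variable directly is one way to pass the key, not the only one):

import os
os.environ["WANDB_API_KEY"] = "<your_key>"  # must be set before wandb logging begins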
Use run_experiment to run an experiment. Here's how you could run an experiment with the PSGD affine optimizer wrapped with schedule-free:
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine

base_lr = 0.001
warmup = 256

# Linear warmup for `warmup` steps, then hold the learning rate constant
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

# Wrap the base optimizer with schedule-free
optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=0.0001,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if the optimizer is wrapped with schedule_free
    psgd_calc_hessian=False,  # set to True if using PSGD and you want to calculate and pass in the Hessian
    psgd_precond_update_prob=1.0,
)
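The optimizer argument is an ordinary optax gradient transformation, so a plain baseline can be swapped in the same way. A minimal sketch using AdamW with the same warmup schedule (assuming, as in the example above, that run_experiment accepts any optax optimizer and that omitted arguments fall back to their defaults):

import optax
from image_classification_jax.run_experiment import run_experiment

# Same linear-warmup-then-constant schedule as above
base_lr = 3e-4
warmup = 256
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

# Gradient clipping followed by AdamW
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(lr, weight_decay=0.01),
)

run_experiment(
    log_to_wandb=False,  # quick local run without wandb
    global_seed=0,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    using_schedule_free=False,  # AdamW here is not wrapped with schedule_free
    psgd_calc_hessian=False,  # not a PSGD optimizer
)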
TODO:
- Add SAM, ASAM, Momentum-SAM
- Add loss landscape flatness logging
- Add logging for optimizer output norm and Hessian norm