image-classification-jax

Run image classification experiments in JAX with ViT and ResNet models on the CIFAR-10, CIFAR-100, Imagenette, and ImageNet datasets.
Meant to be simple but high quality. Includes:

- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311); see the sketch below this list
- Ability to use schedule-free from optax.contrib
- Ability to use PSGD optimizers from psgd-jax, with optional Hessian calculation
- Datasets currently implemented: CIFAR-10, CIFAR-100, Imagenette, and ImageNet
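The z-loss term penalizes the log of the softmax normalizer so the logits don't drift to large magnitudes. Here is a minimal sketch of how it can be combined with cross-entropy, following the form described in the PaLM paper; the function name and default weight are illustrative, not necessarily the library's exact implementation:

```python
import jax.numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import logsumexp

def softmax_cross_entropy_with_z_loss(logits, labels_onehot, z_loss_weight=1e-4):
    # log Z, the log of the softmax normalizer for each example
    log_z = logsumexp(logits, axis=-1)
    # standard cross-entropy
    ce = -jnp.sum(labels_onehot * log_softmax(logits, axis=-1), axis=-1)
    # auxiliary penalty that keeps log Z close to zero
    z_loss = z_loss_weight * jnp.square(log_z)
    return jnp.mean(ce + z_loss)
```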
There is currently no model sharding, only data parallelism (the batch is automatically split across devices as batch_size / n_devices); a minimal sketch of the split is shown below.
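Purely for illustration, this is roughly what that split looks like (run_experiment handles it internally; the shapes and names here are assumptions):

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
batch_size = 64
per_device_batch = batch_size // n_devices  # each device sees batch_size / n_devices examples

# Add a leading device axis so each device receives its own slice of the global batch.
images = jnp.zeros((batch_size, 32, 32, 3))  # e.g. a CIFAR-10 sized batch
sharded = images.reshape(n_devices, per_device_batch, 32, 32, 3)
```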
Installation
```
pip install image-classification-jax
```
Usage
Set your wandb API key either in your Python script or on the command line:

```bash
export WANDB_API_KEY=<your_key>
```
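Alternatively, set it from inside the script before calling run_experiment, for example:

```python
import os

# Equivalent to the shell export above, scoped to the current process.
os.environ["WANDB_API_KEY"] = "<your_key>"
```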
Use run_experiment to run an experiment. Here's how you could run an experiment with the PSGD affine optimizer wrapped with schedule-free:
```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine

base_lr = 0.001
warmup = 256

lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=0.0001,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if optimizer is wrapped with schedule_free
    psgd_calc_hessian=False,  # set to True if using PSGD and you want to calculate and pass in the Hessian
    psgd_precond_update_prob=1.0,
)
```
TODO:
- Add SAM, ASAM, Momentum-SAM
- Add loss landscape flatness logging
- Add logging for optimizer output norm, hessian norm