image-classification-jax
Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.
Meant to be simple but high quality. Includes:
- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311)
- ability to use schedule-free from `optax.contrib`
- ability to use PSGD optimizers from psgd-jax, with Hessian calculation
- datasets currently implemented include cifar10, cifar100, imagenette, and imagenet
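The z-loss bullet refers to the auxiliary loss from PaLM, which adds 1e-4 * log(Z)**2 per example, where Z is the softmax normalizer, to keep log(Z) close to zero. Below is a minimal sketch in NumPy (`jax.numpy` accepts the same calls); the function name `cross_entropy_with_z_loss` is hypothetical and the default coefficient follows the PaLM paper, not necessarily this package's internals:

```python
import numpy as np


def cross_entropy_with_z_loss(logits, labels, z_loss_coef=1e-4):
    """Mean softmax cross-entropy plus a PaLM-style z-loss (illustrative sketch).

    logits: float array of shape (batch, n_classes)
    labels: int array of shape (batch,)
    Returns (total_loss, mean_z_loss).
    """
    # Numerically stable log of the softmax normalizer Z = sum(exp(logits)).
    m = np.max(logits, axis=-1, keepdims=True)
    log_z = np.log(np.sum(np.exp(logits - m), axis=-1)) + m[:, 0]
    # Log-probabilities of every class, then pick out the true class.
    log_probs = logits - log_z[:, None]
    ce = -log_probs[np.arange(logits.shape[0]), labels]
    # Auxiliary z-loss: coef * log(Z)^2 pushes log(Z) toward 0.
    z_loss = z_loss_coef * log_z**2
    return np.mean(ce + z_loss), np.mean(z_loss)
```

Adding the same constant to every logit leaves the cross-entropy unchanged but raises log(Z), so only the z-loss term grows; that drift is what the penalty discourages.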
Currently there is no model sharding, only data parallelism (each batch is automatically split so that every device gets `batch_size/n_devices` examples).
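Splitting a batch for data parallelism amounts to reshaping its leading axis so that each device receives `batch_size/n_devices` examples. Here is a minimal NumPy sketch of that reshape; `shard_batch` is a hypothetical helper, and in JAX `n_devices` would typically come from `jax.local_device_count()`, though exactly how this package hands the shards to devices is an assumption not covered here:

```python
import numpy as np


def shard_batch(batch, n_devices):
    """Reshape (batch_size, ...) into (n_devices, batch_size // n_devices, ...)."""
    batch_size = batch.shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"batch_size={batch_size} must be divisible by n_devices={n_devices}"
        )
    # Each slice along the new axis 0 is one device's share of the batch.
    return batch.reshape(n_devices, batch_size // n_devices, *batch.shape[1:])


images = np.zeros((64, 32, 32, 3), dtype=np.float32)  # a CIFAR-sized batch
print(shard_batch(images, 8).shape)
```

With `batch_size=64` on 8 devices, each device receives 8 examples, which is why the batch size should be divisible by the device count.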
Installation
```bash
pip install image-classification-jax
```
Usage
Set your wandb API key either in your Python script or from the command line:
```bash
export WANDB_API_KEY=<your_key>
```
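To set the key from inside the Python script instead, one option is to place it in the process environment before the experiment starts, since wandb reads `WANDB_API_KEY` when it authenticates. A minimal sketch; `<your_key>` is a placeholder:

```python
import os

# Set this before run_experiment(...) starts wandb; replace the placeholder with your key.
os.environ["WANDB_API_KEY"] = "<your_key>"
```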
Use `run_experiment` to run an experiment. Here's how you could run an experiment with the PSGD affine optimizer wrapped with schedule-free:
```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine

base_lr = 0.001
warmup = 256

# Linear warmup from 0.0 to base_lr over `warmup` steps, then constant.
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

# Clip gradients by global norm, then apply the PSGD affine optimizer.
psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

# Wrap the PSGD optimizer with schedule-free.
optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",  # your wandb entity (username or team)
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=0.0001,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if optimizer wrapped with schedule_free
    psgd_calc_hessian=False,  # set to True if using PSGD and want to calc and pass in hessian
    psgd_precond_update_prob=1.0,
)
```
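The `lr` schedule in this example ramps linearly from 0.0 to `base_lr` over the first `warmup` steps and then stays constant; `optax.join_schedules` evaluates the second schedule at `step - warmup`, which a constant schedule ignores. Here is a plain-Python sketch of the same curve, handy for sanity-checking values without JAX; `warmup_then_constant` is a hypothetical helper, not part of optax or this package:

```python
def warmup_then_constant(step, base_lr=0.001, warmup=256):
    """Learning rate at `step`: linear warmup from 0.0, then constant base_lr."""
    if step < warmup:
        # Linear ramp: 0.0 at step 0, approaching base_lr as step nears `warmup`.
        return base_lr * step / warmup
    # From the boundary onward, the constant schedule applies.
    return base_lr


for step in (0, 128, 256, 1000):
    print(step, warmup_then_constant(step))
```

With the values above, the learning rate is 0.0 at step 0, 0.0005 at step 128, and 0.001 from step 256 onward.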
Hashes for image_classification_jax-0.1.1.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | e1038889ae321d48a187bcc181c7dfc12c7bd6057836a50a31e9bfea6abfd353 |
| MD5 | 3e9ff343e05ab5773f2dcd2d169f98a8 |
| BLAKE2b-256 | de4550e9126b47306cfab7be1e97035fa0f4ad6e945da96727241c49264dad60 |
Hashes for image_classification_jax-0.1.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 796e29a25715104e46de90fd384f545ed4344eb8c22bfa623547d9d69f1c878e |
| MD5 | 52bbc8cb347db667c6401a8be00fe369 |
| BLAKE2b-256 | ef07a8dbddabae88918b253e80f654188847d2e52ed950d52f21505519252bcf |
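To check a downloaded file against the SHA256 digests, here is a minimal sketch using only Python's standard `hashlib`; `file_sha256` is a hypothetical helper, not part of this package:

```python
import hashlib


def file_sha256(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at `path`, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps calling f.read until it returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, `file_sha256("image_classification_jax-0.1.1-py3-none-any.whl")` should equal the SHA256 digest listed for the wheel.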