Run image classification experiments in JAX with ViT and ResNet models on CIFAR-10, CIFAR-100, Imagenette, and ImageNet.
image-classification-jax
Installation
pip install image-classification-jax
Usage
Set your Weights & Biases (wandb) API key either in your Python script or from the command line:
export WANDB_API_KEY=<your_key>
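For the in-script option, here is a minimal sketch. It assumes that run_experiment initializes wandb itself when log_to_wandb=True, so it picks up the standard WANDB_API_KEY environment variable as long as the variable is set before the call.

```python
import os

# In-script alternative to `export WANDB_API_KEY=<your_key>`.
# "<your_key>" is a placeholder; replace it with your own wandb API key.
# setdefault keeps any WANDB_API_KEY already exported in the shell.
os.environ.setdefault("WANDB_API_KEY", "<your_key>")
```

Set the key before calling run_experiment. If you keep the key in the shell with export, you do not need this line.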
Use run_experiment to run an experiment. The following example trains a ViT on CIFAR-10 using the xmat optimizer from psgd-jax, wrapped in schedule-free (optax.contrib.schedule_free).
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.xmat import xmat

# Learning rate: linear warmup from 0.0 to 0.01 over the first 256 steps, then constant at 0.01.
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, 0.01, 256),
        optax.constant_schedule(0.01),
    ],
    boundaries=[256],
)

# Wrap xmat in schedule-free. Schedule-free's interpolation (b1=0.95) takes the place
# of momentum, so the base optimizer's own momentum is turned off (b1=0.0).
optimizer = optax.contrib.schedule_free(xmat(lr, b1=0.0), learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",  # set to your wandb username or team
    wandb_project="image_classification_jax",
    wandb_config_update={
        "optimizer": "psgd_xmat",
        "schedule_free": True,
        "learning_rate": 0.01,
        "warmup_steps": 256,
        "b1": 0.95,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=1e-4,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=12,
    enc_dim=768,
    n_heads=12,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,
    psgd_calc_hessian=True,
    psgd_precond_update_prob=0.1,
)
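The learning-rate schedule in the example is a linear warmup followed by a constant rate. The following pure-Python sketch mirrors what optax.linear_schedule(0.0, 0.01, 256) joined to optax.constant_schedule(0.01) at boundary 256 evaluates to. It does not use optax and only illustrates the shape of the schedule. The function name warmup_then_constant is hypothetical.

```python
def warmup_then_constant(step: int, peak_lr: float = 0.01, warmup_steps: int = 256) -> float:
    """Hypothetical pure-Python mirror of the optax.join_schedules example above."""
    if step < warmup_steps:
        # Linear interpolation from 0.0 to peak_lr, as optax.linear_schedule does.
        return peak_lr * step / warmup_steps
    # From the boundary onward, the constant schedule returns peak_lr.
    return peak_lr


for step in (0, 64, 128, 256, 1000):
    print(step, warmup_then_constant(step))
```

The rate reaches half its peak at step 128 and stays at 0.01 from step 256 onward. This matches the "learning_rate" and "warmup_steps" values logged in wandb_config_update.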
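To see how warmup_steps=256 compares with the length of the run, here is a back-of-the-envelope step count for the example configuration. It rests on three assumptions: CIFAR-10's standard 50,000-image training split, a data loader that keeps the last partial batch (use floor division if it drops it), and psgd_precond_update_prob being the per-step probability of refreshing the PSGD preconditioner, as its name suggests.

```python
import math

# Assumed dataset size: the standard CIFAR-10 training split.
train_images = 50_000
batch_size = 64
n_epochs = 10
warmup_steps = 256
precond_update_prob = 0.1  # psgd_precond_update_prob in the example

steps_per_epoch = math.ceil(train_images / batch_size)  # 782 (781 if the partial batch is dropped)
total_steps = steps_per_epoch * n_epochs  # 7820
warmup_fraction = warmup_steps / total_steps  # about 3.3% of training
expected_precond_updates = precond_update_prob * total_steps  # about 782 on average

print(f"{steps_per_epoch=} {total_steps=} {warmup_fraction=:.3f} {expected_precond_updates=:.0f}")
```

Under these assumptions, warmup covers a small fraction of the run. With an update probability of 0.1, the preconditioner is refreshed on roughly one step in ten.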
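The model settings in the example (n_layers=12, enc_dim=768, n_heads=12) match the ViT-Base encoder dimensions, which gives 768 / 12 = 64 dimensions per attention head. As a rough size check, the sketch below counts encoder-block parameters for a standard pre-norm transformer. The MLP ratio of 4, the biases on every dense layer, and the two LayerNorms per block are assumptions, not details confirmed for this package. The helper vit_encoder_params is hypothetical.

```python
def vit_encoder_params(n_layers: int, enc_dim: int, mlp_ratio: int = 4) -> int:
    """Hypothetical parameter count for the transformer blocks of a ViT encoder.

    Assumes biases on every dense layer, two LayerNorms per block and an MLP
    hidden size of mlp_ratio * enc_dim. Excludes the patch embedding, position
    embeddings, class token/registers, final norm and classifier head.
    """
    d, h = enc_dim, mlp_ratio * enc_dim
    attention = 4 * (d * d + d)      # Q, K, V and output projections (weights + biases)
    mlp = (d * h + h) + (h * d + d)  # two dense layers (weights + biases)
    layer_norms = 2 * (2 * d)        # scale and bias for each of two LayerNorms
    return n_layers * (attention + mlp + layer_norms)


n_layers, enc_dim, n_heads = 12, 768, 12
assert enc_dim % n_heads == 0, "enc_dim must be divisible by n_heads"
print(enc_dim // n_heads)  # 64 dimensions per head
print(vit_encoder_params(n_layers, enc_dim))  # 85054464
```

With these assumptions, each block has 12d² + 13d parameters, so the 12-block encoder has about 85M. That is the usual ViT-Base scale before embeddings and the head.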
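The example sets apply_z_loss=True, but this page does not define the term. The usual z-loss, introduced in the PaLM paper, adds coeff * (log Z)² to the training loss, where log Z is the logsumexp of the logits. It discourages the softmax normalizer from drifting far from zero. The sketch below assumes that form and PaLM's coefficient of 1e-4 for a single example's logits. Neither the form nor the coefficient is confirmed for this package, and the function name z_loss is hypothetical.

```python
import math
from typing import Sequence


def z_loss(logits: Sequence[float], coeff: float = 1e-4) -> float:
    """Hypothetical PaLM-style auxiliary z-loss for one example: coeff * logsumexp(logits) ** 2."""
    m = max(logits)
    # Numerically stable logsumexp: subtract the max before exponentiating.
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return coeff * log_z ** 2


print(z_loss([0.0, 0.0]))        # small: log Z = log 2
print(z_loss([20.0, 0.0, 0.0]))  # larger: log Z is close to 20
```

In a batched setting, this term would typically be averaged over examples and added to the cross-entropy loss.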
Download files
Hashes for image_classification_jax-0.1.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | dd6f739f29d5d3d856d1fc32f1f6e9bd445cbb8eed3ea5bfc231960a457a47f2
MD5 | afe187c77fd289592e89c39e9ddde43e
BLAKE2b-256 | be1ecbd0f0e1a3347ddc3009c1bbb16e97082a8335d27a445dafcbe51cc199af
Hashes for image_classification_jax-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 3ca080875e5a4e77acd4b1792d262a7e39aa00b5b416f61e9dbc1d0837fa0b76
MD5 | b0bccc21efd6daa03be3c93bf8fabc97
BLAKE2b-256 | 04a22c7f18a247b8e64c4a3f3b3682cf0854b351ecb6f5ebdd4b87d925cfc1f2