image-classification-jax

Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.

Aims to be simple while remaining high quality. Includes:

  • ViT with QK normalization, SwiGLU, and empty registers
  • PaLM-style z-loss (https://arxiv.org/pdf/2204.02311)
  • Optional schedule-free optimization via optax.contrib
  • Datasets currently implemented: cifar10, cifar100, imagenette, and imagenet
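QK normalization normalizes queries and keys before their dot product, which keeps attention logits from growing without bound as weights grow. The following is a minimal single-example sketch in JAX. It assumes RMS normalization with no learnable gain and raw weight matrices with hypothetical names (w_q, w_k, w_v, w_o), so it is not necessarily how this package's ViT implements the feature.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Scale the last axis to unit root-mean-square (no learnable gain here).
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def qk_norm_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """Self-attention over x of shape (seq_len, dim) with normalized q and k."""
    seq_len, dim = x.shape
    head_dim = dim // n_heads

    def split_heads(t):
        # (seq_len, dim) -> (n_heads, seq_len, head_dim)
        return t.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)

    # QK normalization: normalize each head's queries and keys before the dot product.
    q = rms_norm(split_heads(x @ w_q))
    k = rms_norm(split_heads(x @ w_k))
    v = split_heads(x @ w_v)

    logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("hqk,hkd->hqd", weights, v)
    # (n_heads, seq_len, head_dim) -> (seq_len, dim), then output projection.
    out = out.transpose(1, 0, 2).reshape(seq_len, dim)
    return out @ w_o


# Usage with random weights: 16 tokens, dim 64, 4 heads.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
dim, n_heads, seq_len = 64, 4, 16
x = jax.random.normal(keys[0], (seq_len, dim))
w_q, w_k, w_v, w_o = (jax.random.normal(k, (dim, dim)) / jnp.sqrt(dim) for k in keys[1:])
y = qk_norm_attention(x, w_q, w_k, w_v, w_o, n_heads)
```

In this gain-free variant each normalized query and key has norm close to sqrt(head_dim), so every scaled logit is bounded by roughly sqrt(head_dim) in magnitude.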
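SwiGLU replaces the usual GELU MLP with a SiLU-gated linear unit, and register tokens are extra learnable tokens added to the patch sequence. A rough sketch of both follows. Reading "empty" registers as tokens that receive no supervision and are dropped after the encoder is an assumption, as are the helper names (swiglu_mlp, add_registers, drop_registers). The example's n_empty_registers argument presumably sets how many registers are used.

```python
import jax
import jax.numpy as jnp


def swiglu_mlp(x, w_gate, w_up, w_down):
    # SiLU(x W_gate) gates (x W_up) elementwise; W_down projects back to dim.
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down


def add_registers(patch_tokens, registers):
    # patch_tokens: (batch, n_patches, dim); registers: (n_registers, dim).
    # The same learnable registers are prepended to every example.
    batch = patch_tokens.shape[0]
    regs = jnp.broadcast_to(registers, (batch,) + registers.shape)
    return jnp.concatenate([regs, patch_tokens], axis=1)


def drop_registers(tokens, n_registers):
    # Discard register outputs after the encoder; no loss is computed on them.
    return tokens[:, n_registers:]


keys = jax.random.split(jax.random.PRNGKey(0), 5)
batch, n_patches, dim, hidden, n_registers = 2, 16, 64, 128, 4
tokens = jax.random.normal(keys[0], (batch, n_patches, dim))
registers = jax.random.normal(keys[1], (n_registers, dim)) * 0.02
w_gate = jax.random.normal(keys[2], (dim, hidden)) / jnp.sqrt(dim)
w_up = jax.random.normal(keys[3], (dim, hidden)) / jnp.sqrt(dim)
w_down = jax.random.normal(keys[4], (hidden, dim)) / jnp.sqrt(hidden)

with_regs = add_registers(tokens, registers)  # (2, 20, 64)
# The MLP acts per token; in a full encoder, registers interact with patches via attention.
mixed = swiglu_mlp(with_regs, w_gate, w_up, w_down)
out = drop_registers(mixed, n_registers)  # (2, 16, 64)
```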

There is currently no model sharding, only data parallelism: each batch is split automatically across devices, so every device processes batch_size / n_devices examples.
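A minimal sketch of that kind of batch splitting with jax.sharding, assuming a 1-D mesh over all devices with a hypothetical axis name "data". The package may implement it differently (for example with pmap).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all devices; "data" is a hypothetical axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n_devices = devices.size

batch_size = 64
assert batch_size % n_devices == 0, "batch_size must divide evenly across devices"

batch = {
    "images": jnp.ones((batch_size, 32, 32, 3)),
    "labels": jnp.zeros((batch_size,), dtype=jnp.int32),
}

# Shard only the leading (batch) axis; the other axes stay whole on each device.
sharding = NamedSharding(mesh, P("data"))
sharded = jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), batch)

# Each device now holds batch_size // n_devices examples.
per_device = sharded["images"].addressable_shards[0].data.shape[0]
```

A jax.jit-compiled train step called on such inputs propagates the sharding, so each device computes on its own slice and the compiler inserts the cross-device reductions for the mean loss and gradients.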

Installation

pip install image-classification-jax

Usage

Set your wandb API key either in your Python script or from the command line:

export WANDB_API_KEY=<your_key>
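To set it from inside the script instead, one minimal option is to put the same environment variable into os.environ before run_experiment is called, since wandb reads WANDB_API_KEY from the environment. The key value here is a placeholder.

```python
import os

# "<your_key>" is a placeholder. setdefault keeps a key already exported in
# the shell and only fills one in when the variable is missing.
os.environ.setdefault("WANDB_API_KEY", "<your_key>")
```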

Use run_experiment to run an experiment. For example, here is how you could train with the PSGD affine optimizer (from the separate psgd_jax package) wrapped with schedule-free:

import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine

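base_lr = 0.001
warmup = 256  # number of linear warmup steps
# Linear warmup from 0 to base_lr over `warmup` steps, then hold constant.
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

# Clip gradients to a global norm of 1.0, then apply PSGD affine.
psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,  # no base momentum; schedule-free supplies its own averaging
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

# Wrap the whole chain with schedule-free; b1 here is schedule-free's interpolation parameter.
optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)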

run_experiment(
    log_to_wandb=True,
    wandb_entity="",  # your wandb username or team name
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if optimizer wrapped with schedule_free
)

Download files

Source Distribution

image_classification_jax-0.1.4.tar.gz (17.9 kB)

Built Distribution

image_classification_jax-0.1.4-py3-none-any.whl (21.3 kB, Python 3)

File details

Details for the file image_classification_jax-0.1.4.tar.gz.

File metadata

  • Download URL: image_classification_jax-0.1.4.tar.gz
  • Size: 17.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for image_classification_jax-0.1.4.tar.gz

  • SHA256: 35db2de3ffa15dd6b29304118ea3f0b34510bc8bd413233331252fe24458064f
  • MD5: e980e0cca17ffc864537810e7b82e801
  • BLAKE2b-256: 8ae3edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44

File details

Details for the file image_classification_jax-0.1.4-py3-none-any.whl.

File hashes

Hashes for image_classification_jax-0.1.4-py3-none-any.whl

  • SHA256: 1e2e004ee4fa18ce385d5eb1eac8edde66ae8d23b1dd0ad758e8e0b0c326d899
  • MD5: 72656b7a7ec8a601f155f687466aacf4
  • BLAKE2b-256: a94ce1d8f73faf937f9750607d0395bfdcf49f63ad1332c33b5b16b876b1c496
