6 projects
distributed-kron
An implementation of the PSGD Kron optimizer in JAX/Optax for large-scale distributed training.
quad-torch
An implementation of the PSGD-QUAD optimizer in PyTorch.
kron-torch
An implementation of the PSGD Kron optimizer in PyTorch.
psgd-jax
An implementation of the PSGD optimizer in JAX.
image-classification-jax
Run image classification experiments in JAX with ViT and ResNet models on CIFAR-10, CIFAR-100, Imagenette, and ImageNet.
psgd-torch