
a library so simple you will learn it Within An Hour



Install

pip install wah

Requirements

You may want to install PyTorch manually to enable GPU support; see the example after the list below.

lightning
PyYAML
tensorboard
torch
torchaudio
torchmetrics
torchvision
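
For example, a CUDA 12.1 build of PyTorch can be installed with the command below (illustrative; check pytorch.org for the exact command matching your OS and CUDA version):

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121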

Model Training

Let's train ResNet50 [1] on the CIFAR-10 dataset [2]. First, import the package.

import wah

Second, write your own config.yaml file (it configures everything that follows).

num_classes: 10
batch_size: 128
num_workers: 2

epochs: 200
init_lr: 0.1
seed: 0

optimizer: SGD
optimizer_cfg:
  momentum: 0.9
  weight_decay: 5.e-4

lr_scheduler: MultiStepLR
lr_scheduler_cfg:
  milestones: [ 60, 120, 160, ]
  gamma: 0.2

  • num_classes (int) - number of classes in the dataset.

  • batch_size (int) - how many samples per batch to load.

  • num_workers (int) - how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.

  • epochs (int) - stop training once this number of epochs is reached.

  • init_lr (float) - initial learning rate.

  • seed (int) - seed for random number generation. Use a non-negative integer; if a negative value is provided, no seeding occurs.

  • optimizer (str) - specifies which optimizer to use. Must be one of the optimizers supported in torch.optim.

  • optimizer_cfg - parameters for the specified optimizer. The params parameter does not need to be explicitly provided (automatically initialized).

  • lr_scheduler (str) - specifies which scheduler to use. Must be one of the schedulers supported in torch.optim.lr_scheduler.

  • lr_scheduler_cfg - parameters for the specified scheduler. The optimizer parameter does not need to be explicitly provided (automatically initialized); see the sketch after this list for the plain-PyTorch equivalents.
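
For reference, here is what the optimizer and scheduler entries above correspond to in plain PyTorch. This is only an illustrative sketch of the objects wah would build from this config (assuming a model is already defined, and assuming these fields are forwarded to torch.optim), not wah's actual code:

import torch

# Plain-PyTorch equivalents of the config entries above (illustrative).
optimizer = torch.optim.SGD(
    model.parameters(),  # the params argument, filled in automatically by wah
    lr=0.1,              # init_lr
    momentum=0.9,        # optimizer_cfg.momentum
    weight_decay=5e-4,   # optimizer_cfg.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,           # the optimizer argument, filled in automatically by wah
    milestones=[60, 120, 160],
    gamma=0.2,
)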

Third, load your config.yaml file.

config = wah.load_config(PATH_TO_CONFIG)
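
For example, if the file above is saved as config.yaml next to your script (the path here is illustrative):

config = wah.load_config("./config.yaml")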

Fourth, load your dataloaders.

from torchvision.datasets import CIFAR10

train_dataset = CIFAR10(train=True, ...)
val_dataset = CIFAR10(train=False, ...)

train_dataloader = wah.load_dataloader(
    dataset=train_dataset,
    config=config,
    shuffle=True,
)
val_dataloader = wah.load_dataloader(
    dataset=val_dataset,
    config=config,
    shuffle=False,
)
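
The ... in the CIFAR10 calls above stands for the usual torchvision arguments. One way to fill them in, with a standard CIFAR-10 augmentation recipe (the root path and transforms are illustrative choices, not requirements of wah):

from torchvision import transforms
from torchvision.datasets import CIFAR10

# Standard CIFAR-10 preprocessing (illustrative choices).
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
val_transform = transforms.ToTensor()

train_dataset = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
val_dataset = CIFAR10(root="./data", train=False, download=True, transform=val_transform)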

Fifth, load your model.

from torchvision.models import resnet50

model = resnet50(weights=None, num_classes=10)
model = wah.Wrapper(model, config)

Finally, train your model!

tensorboard_logger = wah.load_tensorboard_logger(
    config=config,
    save_dir=TRAIN_LOG_ROOT,
    name="cifar10-resnet50",
)
lr_monitor = wah.load_lr_monitor()
checkpoint_callback = wah.load_checkpoint_callback(SAVE_CKPT_PER_THIS_EPOCH)

trainer = wah.load_trainer(
    config=config,
    logger=[tensorboard_logger, ],
    callbacks=[lr_monitor, checkpoint_callback, ],
)
trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

You can check your training logs by running the following command:

tensorboard --logdir TRAIN_LOG_ROOT

References

[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep Residual Learning for Image Recognition. CVPR, 2016.
[2] Alex Krizhevsky and Geoffrey Hinton. Learning Multiple Layers of Features from Tiny Images. Tech. Rep., University of Toronto, Toronto, Ontario, 2009.
