

Training wheels, side rails, and helicopter parent for your Deep Learning projects in PyTorch Lightning.

pip install ride

Zero-boilerplate AI research

Though PyTorch Lightning helps remove a lot of boilerplate code, writing and testing Deep Learning models still includes setting up

  • Trainer
  • Checkpointing
  • Metrics
  • Train-val-test step methods
  • Finetuning schemes
  • Hyperparameter search
  • Main function
  • Command-line interface
  • ...

This project is an audacious attempt at disposing of the remaining boilerplate by providing good battle-tested defaults with minimal coding.

Everything you find here is highly opinionated and was first and foremost an attempt at generalising personal research boilerplate. On the other hand, it might be just right, and if not, it is highly extensible and forkable. Suggestions and pull requests are always welcome!

Programming model

Did you ever take a peek at the source code of the LightningModule? This core class of the PyTorch Lightning library makes heavy use of mixins and multiple inheritance to group functionalities and "inject" them into the LightningModule.

In ride we build up our modules the same way, mixing in functionality by inheriting from multiple base classes in our Model definition.

Model definition

Below, we have the complete code for a simple classifier on the MNIST dataset:

# simple_classifier.py
import torch
import ride
import numpy as np
from .examples import MnistDataset


class SimpleClassifier(
    ride.RideModule,
    ride.Lifecycle,
    ride.SgdOneCycleOptimizer,
    ride.TopKAccuracyMetric(1, 3),
    MnistDataset,
):
    def __init__(self, hparams):
        # `self.input_shape` and `self.output_shape` were injected via `MnistDataset`
        self.l1 = torch.nn.Linear(np.prod(self.input_shape), self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, self.output_shape)

    def forward(self, x):
        # Flatten the input image and pass it through the two fully connected layers
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    @staticmethod
    def configs():
        c = ride.Configs()
        c.add(
            name="hidden_dim",
            type=int,
            default=128,
            strategy="choice",
            choices=[128, 256, 512, 1024],
            description="Number of hidden units.",
        )
        return c


if __name__ == "__main__":
    ride.Main(SimpleClassifier).argparse()

That's it! So what's going on, and aren't we missing a bunch of code?

All of the usual boilerplate code has been mixed in using multiple inheritance:

  • RideModule is a base module which includes pl.LightningModule and performs some behind-the-scenes Python magic. For instance, it modifies your __init__ function to automatically initialise all the mixins correctly.
  • Lifecycle mixes in training_step, validation_step, and test_step methods alongside a loss_fn defaulting to cross-entropy (a rough sketch of what this amounts to is shown after this list).
  • SgdOneCycleOptimizer mixes in the configure_optimizers function with SGD and a OneCycleLR scheduler.
  • MnistDataset mixes in train_dataloader, val_dataloader, and test_dataloader functions for the MNIST dataset. Dataset mixins always provide input_shape and output_shape attributes, which are handy for defining the network structure, as seen in __init__.
  • TopKAccuracyMetric adds top1acc and top3acc metrics, which can be used for checkpointing and benchmarking.
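
To make this more concrete, here is a hand-written sketch of roughly what the classification lifecycle contributes. The class name and details below are illustrative only and do not mirror ride's actual implementation:

# lifecycle_sketch.py
import torch.nn.functional as F


class SketchedClassificationLifecycle:
    """Illustrative stand-in for the Lifecycle mixin (not ride's real code)."""

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.forward(x), y)  # the default loss_fn

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.forward(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.forward(x), y)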

In addition to inheriting lifecycle functions etc., the mixins also add configs to your module (powered by co-rider). These define all of the configurable (hyper)parameters including their

  • type
  • default value
  • description in plain text (reflected in the command-line interface)
  • space (the accepted input range)
  • strategy (how hyperparameter search tackles the parameter)

Configs specific to the SimpleClassifier can be added by overriding the configs method, as shown in the example.

The final piece of sorcery is the Main class, which adds a complete command-line interface.

Command-line interface 💻

Let's check out the command-line interface:

$ python simple_classifier.py --help
...

Flow:
  Commands that control the top-level flow of the programme.

  --hparamsearch        Run hyperparameter search. The best hyperparameters
                        will be used for subsequent lifecycle methods
  --train               Run model training
  --validate            Run model evaluation on validation set
  --test                Run model evaluation on test set
  --profile_model       Profile the model

General:
  Settings that apply to the programme in general.

  --id ID               Identifier for the run. If not specified, the current
                        timestamp will be used (Default: 202101011337)
  --seed SEED           Global random seed (Default: 123)
  --logging_backend {tensorboard,wandb}
                        Type of experiment logger (Default: tensorboard)
  ...

Pytorch Lightning:
  Settings inherited from the pytorch_lightning.Trainer
  ...
  --gpus GPUS           number of gpus to train on (int) or which GPUs to
                        train on (list or str) applied per node
  ...

Hparamsearch:
  Settings associated with hyperparameter optimisation
  ...

Module:
  Settings associated with the Module
  --loss {mse_loss,l1_loss,nll_loss,cross_entropy,binary_cross_entropy,...}
                        Loss function used during optimisation. 
                        (Default: cross_entropy)
  --batch_size BATCH_SIZE
                        Dataloader batch size. (Default: 64)
  --num_workers NUM_WORKERS
                        Number of CPU workers to use for dataloading.
                        (Default: 10)
  --learning_rate LEARNING_RATE
                        Learning rate. (Default: 0.1)
  --weight_decay WEIGHT_DECAY
                        Weight decay. (Default: 1e-05)
  --momentum MOMENTUM   Momentum. (Default: 0.9)
  --hidden_dim HIDDEN_DIM {128, 256, 512, 1024}
                        Number of hidden units. (Default: 128)
  ...

Whew, there's a lot going on there (a bunch was even omitted ...)!

First, there are flags for controlling the programme flow (e.g. whether to run hparamsearch or training), then some general parameters (id, seed, etc.), all the parameters from PyTorch Lightning, hparamsearch-related arguments, and finally the Module-specific arguments, which we defined in or mixed into the SimpleClassifier.

Training and testing

$ python simple_classifier.py --train --test --learning_rate 0.01 --hidden_dim 256

Hyperparameter optimization

If we want to perform hyperparameter optimisation across four GPUs, we can run:

$ python simple_classifier.py --hparamsearch --gpus 4

Currently, we use Ray Tune and the ASHA algorithm under the hood.

Model profiling

You can check the timing and FLOPs of the model with:

$ python simple_classifier.py --profile_model

Environment

By default, ride projects are oriented around the current working directory, saving logs to the ~/logs folder and caching to ~/.cache.

This behaviour can be overridden by changing the following environment variables (defaults noted):

ROOT_PATH="~/"
CACHE_PATH=".cache"
DATASETS_PATH="datasets"  # Dir relative to ROOT_PATH
LOGS_PATH="logs"          # Dir relative to ROOT_PATH
RUN_LOGS_PATH="run_logs"  # Dir relative to LOGS_PATH
TUNE_LOGS_PATH="tune_logs"# Dir relative to LOGS_PATH
LOG_LEVEL="INFO"          # One of "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
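
If you prefer setting these from Python, here is a minimal sketch (assuming the variables are read when ride initialises; all paths below are placeholders):

# configure_env.py
import os

# Redirect ride's root directory, cache, logs, and log level to
# project-specific values before importing or launching ride.
os.environ["ROOT_PATH"] = "/home/me/my_project"
os.environ["CACHE_PATH"] = ".cache"
os.environ["LOGS_PATH"] = "logs"
os.environ["LOG_LEVEL"] = "DEBUG"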

Bibtex

If you end up using ride for your research and feel like citing it, here's a BibTeX entry:

@article{hedegaard2021ride,
  title={Ride},
  author={Lukas Hedegaard},
  journal={GitHub. Note: https://github.com/LukasHedegaard/ride},
  year={2021}
}
