
Cross validation made easy in Lightning


PL Crossvalidate


Cross validation in PyTorch Lightning made easy :]

Just import the specialized trainer from pl_crossvalidate instead of pytorch_lightning and you are set:

# To distinguish it from the original trainer, the new trainer is called KFoldTrainer by default
from pl_crossvalidate import KFoldTrainer as Trainer

# Normal Lightning module
model = MyModel(...)

# Use a Lightning datamodule or training dataloader
datamodule = MyDatamodule(...)

# New trainer takes all original arguments + three new ones for controlling the cross validation
trainer = Trainer(
    num_folds=5,  # number of folds to do
    shuffle=False,  # if samples should be shuffled before splitting
    stratified=False,  # if splitting should be done in a stratified manner
)

# Returns a dict of stats over the different splits
cross_val_stats = trainer.cross_validate(model, datamodule=datamodule)

# Additionally, we can construct an ensemble from the K trained models
ensemble_model = trainer.create_ensemble(model)

💻 Installation

pip install pl-crossvalidate

Or the latest version from GitHub:

pip install

Requires torch>=2.0, lightning>=2.0 and scikit-learn>=1.0.

🤔 Cross-validation: why?

The core functionality of machine learning algorithms is that they are able to learn from data. It is therefore natural to ask: how well do our algorithms actually learn? This is an abstract question, because it requires us to define what "well" means. One interpretation of the question is an algorithm's ability to generalize, i.e. a model that generalizes well has actually learned something meaningful.

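The mathematical definition of the generalization error (also called expected loss or risk) is given by

$$\mathcal{E}(f) = \int_{\mathcal{X} \times \mathcal{Y}} \ell\big(f(x), y\big) \, p(x, y) \, dx \, dy$$

where $f: \mathcal{X} \rightarrow \mathcal{Y}$ is some function, $\ell$ denotes the loss function and $p(x, y)$ is the joint probability distribution between $x$ and $y$. This is the theoretical error an algorithm will make on unobserved data. The problem with this definition is that we cannot compute it: $p(x, y)$ is unknown, and even if we knew it, the integral would be intractable. The best we can therefore do is an approximation of the generalization error:

$$\hat{\mathcal{E}}(f) = \frac{1}{N} \sum_{i=1}^{N} \ell\big(f(x_i), y_i\big)$$

which measures the error that our function $f$ makes on the $N$ data points $\{(x_i, y_i)\}_{i=1}^{N}$, as measured by the loss function $\ell$. This quantity we can compute (just think of it as your normal loss function), and we even know that, for data points drawn independently from $p(x, y)$,

$$\lim_{N \rightarrow \infty} \hat{\mathcal{E}}(f) = \mathcal{E}(f)$$

Namely, the approximation of the generalization error becomes the true generalization error if we just evaluate it on enough data. But how does all this relate to cross-validation, you may ask? The problem with the above is that $f$ is not a fixed function but a data-dependent one, i.e. $f = f_{\mathcal{D}}$, where $\mathcal{D}$ is the data the model was trained on. Thus, the above approximation will only converge to the true generalization error if $\mathcal{D}$ and $\{(x_i, y_i)\}_{i=1}^{N}$ refer to different sets of data points. This is where cross-validation strategies come into play.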
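To make the last point concrete, here is a small pure-Python sketch. The toy data (y = 2x plus Gaussian noise), the 1-nearest-neighbour "memorizer" and all helper names are illustrative assumptions, not part of pl_crossvalidate. It computes the average loss of a data-dependent model twice: once on the data it was trained on, and once on held-out data.

```python
import random


def empirical_risk(f, data, loss):
    """Average loss of f over the N (x, y) pairs in `data`."""
    return sum(loss(f(x), y) for x, y in data) / len(data)


def squared_loss(prediction, target):
    return (prediction - target) ** 2


def fit_memorizer(train):
    """Hypothetical toy model f_D: a 1-nearest-neighbour lookup that memorizes D."""
    def f(x):
        nearest_x, nearest_y = min(train, key=lambda point: abs(point[0] - x))
        return nearest_y
    return f


random.seed(0)
# Toy data: y = 2x + Gaussian noise (an illustrative assumption)
data = [(x, 2 * x + random.gauss(0, 1)) for x in (random.uniform(0, 10) for _ in range(200))]
train, held_out = data[:100], data[100:]

f_D = fit_memorizer(train)
train_risk = empirical_risk(f_D, train, squared_loss)
held_out_risk = empirical_risk(f_D, held_out, squared_loss)
print(f"risk on D: {train_risk:.3f}, risk on held-out data: {held_out_risk:.3f}")
```

Because the memorizer reproduces every training label exactly, its error on its own training data is zero, while its error on held-out data is not; only the latter says anything about how well the model generalizes.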

Hold-out vs. K-fold

In general we consider two viable strategies for selecting the validation and training sets: hold-out validation and K-fold cross validation. In hold-out validation we create a separate, independent set of data to evaluate our training on. This is easily done in native PyTorch Lightning by implementing the validation_step method. For K-fold we split our data into K equally large chunks (folds), then iteratively train on K-1 folds and evaluate on the remaining fold, repeating this K times so that every fold is used for evaluation exactly once. In general, K-fold gives a better approximation of the generalization error than hold-out, but at the expense of requiring you to train K models.
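The K-fold procedure can be sketched in a few lines of plain Python. `kfold_indices` is a hypothetical helper written for illustration only; it distributes samples over the folds the same way scikit-learn's `KFold` does without shuffling (the first `num_samples % num_folds` folds get one extra sample).

```python
def kfold_indices(num_samples, num_folds):
    """Yield (train_indices, test_indices) for each of the K folds (hypothetical helper)."""
    indices = list(range(num_samples))
    # Spread samples as evenly as possible: the first (num_samples % num_folds)
    # folds receive one extra sample each.
    fold_sizes = [num_samples // num_folds] * num_folds
    for i in range(num_samples % num_folds):
        fold_sizes[i] += 1

    start = 0
    for size in fold_sizes:
        test = indices[start:start + size]                   # the held-out fold
        train = indices[:start] + indices[start + size:]     # the remaining K-1 folds
        yield train, test
        start += size


for fold_index, (train, test) in enumerate(kfold_indices(10, 5)):
    print(f"fold {fold_index}: train={train} test={test}")
# first line printed: fold 0: train=[2, 3, 4, 5, 6, 7, 8, 9] test=[0, 1]
```

Every sample lands in exactly one test fold, so the K evaluation scores together cover the whole dataset while each model is always evaluated on data it was not trained on.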

🗒️ Some notes

  • For the .cross_validate method to work, in addition to the standard set of methods that need to be implemented in Lightning (training_step and configure_optimizers), we also require the test_step method to be implemented, as we use this method for evaluating the held-out fold. We do not rely on the validation_step method, as your model's training may depend on the validation set (for example, if you use early stopping), in which case your validation set is not truly separated from the training.

  • To do the splitting in cross validation we need the total number of data points in your dataset. For this reason, we require that your dataset implements the __len__ method.

  • Cross validation is always done sequentially, even if the device you are training on could in principle fit parallel training of multiple folds at the same time. We may look into parallelizing the process in the future.

  • Logging can be a bit weird. Logging of training progress is not essential to cross-validation, but that does not mean it is not interesting to track. The cross-validation method will hijack the version attribute of any logger attached to the trainer and set the logging directory to f"{version}/fold_{fold_index}".

  • Stratified splitting assumes that we can extract a 1D label vector from your dataset.

    • If your dataset has a labels attribute, we will use that as the labels.

    • If the attribute does not exist, we manually iterate over your dataset trying to extract the labels when creating the splits (this is done as part of the .setup phase of the datamodule). By default we assume that, given a batch, the labels can be found as the second element, i.e. batch[1]. You can adjust this by importing the specialized KFoldDataModule and changing its label_extractor attribute. For example, if your batches are dictionaries, you can do something like this:

      from pl_crossvalidate import KFoldDataModule, KFoldTrainer
      model = ...
      trainer = KFoldTrainer(...)
      datamodule = KFoldDataModule(
          num_folds, shuffle, stratified,  # these should match how the trainer is initialized
      )
      # change the label extractor function, such that it will return the labels for a given batch
      datamodule.label_extractor = lambda batch: batch['y']
      trainer.cross_validate(model, datamodule=datamodule)

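The stratification note can be illustrated with a dependency-free sketch. `extract_labels` mimics the label lookup described above (use a `labels` attribute if present, otherwise apply a label extractor to every sample), and `stratified_fold_assignment` is a simple round-robin scheme that spreads each class evenly over the folds. Both are hypothetical helpers, not pl_crossvalidate's internals.

```python
from collections import defaultdict


def extract_labels(dataset, label_extractor=lambda batch: batch[1]):
    """Build a 1D label list (hypothetical helper).
    The default extractor mirrors the batch[1] assumption described above."""
    if hasattr(dataset, "labels"):
        return list(dataset.labels)
    return [label_extractor(dataset[i]) for i in range(len(dataset))]


def stratified_fold_assignment(labels, num_folds):
    """Return a fold index per sample, dealing each class out round-robin
    so every fold gets (roughly) the same class proportions."""
    fold_of = [0] * len(labels)
    indices_per_class = defaultdict(list)
    for idx, label in enumerate(labels):
        indices_per_class[label].append(idx)
    for indices in indices_per_class.values():
        for position, idx in enumerate(indices):
            fold_of[idx] = position % num_folds
    return fold_of


# Dictionary-style samples, as in the label_extractor example above
dataset = [{"x": float(i), "y": i % 2} for i in range(10)]
labels = extract_labels(dataset, label_extractor=lambda batch: batch["y"])
folds = stratified_fold_assignment(labels, num_folds=5)
for k in range(5):
    fold_labels = [labels[i] for i in range(len(labels)) if folds[i] == k]
    print(f"fold {k}: labels={fold_labels}")
```

With two balanced classes and five folds, each fold ends up with one sample of each class, so every held-out fold has the same label distribution as the full dataset.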
😃 Bibtex

If you want to cite the framework feel free to use this:

    title={PL Crossvalidate},
    author={Nicki S. Detlefsen},
    journal={GitHub. Note:},


Download files


Source distribution: pl_crossvalidate-0.1.0.tar.gz (52.6 kB)

Built distribution: pl_crossvalidate-0.1.0-py3-none-any.whl (19.3 kB)
