Cross validation made easy in Lightning
PL Crossvalidate
Cross validation in pytorch lightning made easy :]
Just import the specialized trainer from `pl_crossvalidate` instead of `pytorch_lightning` and you are set:
```python
# To distinguish it from the original trainer, the new trainer is called KFoldTrainer
# (here imported under the familiar name Trainer)
from pl_crossvalidate import KFoldTrainer as Trainer

# Normal Lightning module (replace MyModel with your own LightningModule)
model = MyModel(...)

# Use a Lightning datamodule or a training dataloader (replace MyDatamodule with your own)
datamodule = MyDatamodule(...)

# The new trainer takes all the original arguments + three new ones for controlling the cross validation
trainer = Trainer(
    num_folds=5,  # number of folds to do
    shuffle=False,  # whether samples should be shuffled before splitting
    stratified=False,  # whether splitting should be done in a stratified manner
    accelerator="auto",  # ...plus any other standard Trainer arguments,
    callbacks=[],  # e.g. callbacks, logger, max_epochs
)

# Returns a dict of stats over the different splits
cross_val_stats = trainer.cross_validate(model, datamodule=datamodule)

# Additionally, we can construct an ensemble from the K trained models
ensemble_model = trainer.create_ensemble(model)
```
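The `create_ensemble` call combines the K fold models into a single model. Conceptually, an ensemble can be queried by averaging the predictions of its members; the sketch below illustrates only that idea with plain Python callables. The `average_ensemble` helper and the toy `fold_models` are hypothetical, and we do not claim this is how `create_ensemble` combines models internally.

```python
from typing import Callable, Sequence


def average_ensemble(models: Sequence[Callable[[float], float]]) -> Callable[[float], float]:
    """Hypothetical helper: return a predictor that averages the outputs of K models."""
    if not models:
        raise ValueError("need at least one model")

    def predict(x: float) -> float:
        # Query every fold model and take the arithmetic mean of their outputs
        return sum(m(x) for m in models) / len(models)

    return predict


# Toy "fold models": each slightly different, as if trained on different folds
fold_models = [lambda x, b=b: 2.0 * x + b for b in (-2, -1, 0, 1, 2)]
ensemble = average_ensemble(fold_models)
print(ensemble(1.0))  # the offsets cancel out: prints 2.0
```

Averaging is only one way of combining ensemble members; it is shown here because it is the simplest to reason about.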
💻 Installation
```bash
pip install pl-crossvalidate
```
Or install the latest version from GitHub:
```bash
pip install https://github.com/SkafteNicki/pl_crossvalidate/archive/master.zip
```
Requires `torch>=2.0`, `lightning>=2.0` and `scikit-learn>=1.0`.
🤔 Cross-validation: why?
The core functionality of machine learning algorithms is that they are able to learn from data. Therefore, it is very interesting to ask the question: how well do our algorithms actually learn? This is an abstract question, because it requires us to define what "well" means. One interpretation of this question is an algorithm's ability to generalize, i.e. a model that generalizes well has actually learned something meaningful.
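The mathematical definition of the generalization error/expected loss/risk is given by

$$\mathcal{E}(f) = \mathbb{E}_{p(x, y)}\left[\ell(f(x), y)\right] = \int \ell(f(x), y)\, p(x, y)\, \mathrm{d}x\, \mathrm{d}y$$

where $f$ is some function, $\ell$ denotes the loss function and $p(x, y)$ is the joint probability distribution between $x$ and $y$. This is the theoretical error an algorithm will make on unobserved data. The problem with this definition is that we cannot compute it: $p(x, y)$ is unknown, and even if we knew it, the integral is in general intractable. The best we can therefore do is an approximation of the generalization error:

$$\hat{\mathcal{E}}_{\mathcal{D}}(f) = \frac{1}{N} \sum_{n=1}^{N} \ell(f(x_n), y_n)$$

which measures the error that our function $f$ makes on the $N$ data points $\mathcal{D} = \{(x_n, y_n)\}_{n=1}^{N}$, as measured by the loss function $\ell$. This quantity we can compute (just think of it as your normal loss function), and we even know that

$$\lim_{N \to \infty} \hat{\mathcal{E}}_{\mathcal{D}}(f) = \mathcal{E}(f)$$

Namely, the approximation of the generalization error becomes the true generalization error if we just evaluate it on enough data (by the law of large numbers, provided the data points are drawn independently from $p(x, y)$). But how does all this relate to cross-validation, you may ask? The problem with the above is that $f$ is not a fixed function but a data-dependent function, i.e. $f_{\mathcal{D}_1}$, fitted on some training data $\mathcal{D}_1$. Thus, the above approximation will only converge to the generalization error if $f_{\mathcal{D}_1}$ is evaluated on a set $\mathcal{D}_2$ of different data points than those in $\mathcal{D}_1$. This is where cross-validation strategies come into play.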
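To make the last point concrete, the sketch below computes the empirical error (the average squared loss over a set of data points) for a deliberately data-dependent model: a hypothetical 1-nearest-neighbour regressor that memorises its training set. Evaluated on its own training set its error is exactly zero, whereas on a separate set drawn from the same distribution it is not. The toy distribution `y = 2x + noise`, the noise level and the sample sizes are arbitrary assumptions for illustration.

```python
import random
from typing import Callable, List, Tuple

Dataset = List[Tuple[float, float]]


def empirical_error(f: Callable[[float], float], data: Dataset) -> float:
    """Average squared loss of f over the data points (x_n, y_n)."""
    return sum((f(x) - y) ** 2 for x, y in data) / len(data)


def fit_nearest_neighbour(train: Dataset) -> Callable[[float], float]:
    """Data-dependent model f_D1: predict the label of the closest training input."""
    def f(x: float) -> float:
        _, y_closest = min(train, key=lambda point: abs(point[0] - x))
        return y_closest
    return f


def sample(n: int, rng: random.Random) -> Dataset:
    """Draw n noisy points from the toy distribution y = 2x + Gaussian noise."""
    points = []
    for _ in range(n):
        x = rng.uniform(0.0, 1.0)
        points.append((x, 2.0 * x + rng.gauss(0.0, 0.5)))
    return points


rng = random.Random(0)
d1 = sample(200, rng)  # training set D_1
d2 = sample(200, rng)  # held-out set D_2: separate draws from the same distribution
f_d1 = fit_nearest_neighbour(d1)

train_error = empirical_error(f_d1, d1)  # 0.0: the model has memorised D_1
held_out_error = empirical_error(f_d1, d2)  # clearly above zero
print(train_error, held_out_error)
```

The training error says nothing about how well the model generalizes; only the error on the separate set is a usable estimate.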
[Figure: hold-out validation and K-fold cross validation]
In general we consider two viable strategies for selecting the validation set $\mathcal{D}_2$ and the training set $\mathcal{D}_1$: hold-out validation and K-fold cross validation. In hold-out we create a separate, independent set of data to evaluate our training on. This is easily done in native PyTorch Lightning by implementing the `validation_step` method. For K-fold we cut our data into K (approximately) equally large chunks, and then iteratively train on K-1 folds and evaluate on the remaining fold, repeating this K times. In general K-fold gives a better approximation of the generalization error than hold-out, as every data point is used for evaluation exactly once, but at the expense of requiring you to train K models.
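The K-fold procedure described above can be sketched in a few lines of plain Python. The helper names `kfold_indices` and `cross_validation_error` are hypothetical, the data are split into contiguous chunks without shuffling, and the "model" is simply the mean of the training labels scored with a squared loss; `pl_crossvalidate` handles the splitting and training for you.

```python
from typing import Iterator, List, Sequence, Tuple


def kfold_indices(n: int, k: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx) pairs: cut range(n) into k contiguous chunks
    and hold out one chunk at a time (no shuffling)."""
    if not 2 <= k <= n:
        raise ValueError("need 2 <= k <= n")
    # Chunk sizes differ by at most one when n is not divisible by k
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n))
        yield train_idx, val_idx
        start += size


def cross_validation_error(y: Sequence[float], k: int) -> float:
    """Average held-out squared error of a 'predict the training mean' model over k folds."""
    fold_errors = []
    for train_idx, val_idx in kfold_indices(len(y), k):
        mean = sum(y[i] for i in train_idx) / len(train_idx)  # "train" on K-1 folds
        err = sum((y[i] - mean) ** 2 for i in val_idx) / len(val_idx)  # evaluate on the held-out fold
        fold_errors.append(err)
    return sum(fold_errors) / k  # average over the K folds


y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
for train_idx, val_idx in kfold_indices(len(y), 3):
    print(train_idx, val_idx)
print(cross_validation_error(y, k=3))  # prints 6.25
```

Each of the K models is evaluated only on data it never saw during training, which is exactly the separation between $\mathcal{D}_1$ and $\mathcal{D}_2$ that the estimate requires.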
🗒️ Some notes
- For the `.cross_validate` method to work, in addition to the standard set of methods in Lightning that need to be implemented (`training_step` and `configure_optimizers`), we also require the `test_step` method to be implemented, as we use this method to evaluate the held-out fold. We do not rely on the `validation_step` method, as your model's training may depend on the validation set (for example if you use early stopping), and your validation set would therefore not be truly separated from the training.
- To do the splitting in cross validation we need the total number of data points in your dataset. For this reason, we require that your dataset implements the `__len__` method.
- Cross validation is always done sequentially, even if the device you are training on could in principle fit parallel training on multiple folds at the same time. We hope to figure out in the future whether the process can be parallelized.
- Logging can be a bit weird. Logging of training progress is not essential to cross-validation, but that does not mean it is not interesting to track. The cross-validation method will hijack the `version` attribute of any logger attached to the trainer and set the logging directory to `f"{version}/fold_{fold_index}"`.
- Stratified splitting assumes that we can extract a 1D label vector from your dataset.
  - If your dataset has a `labels` attribute, we will use that as the labels.
  - If the attribute does not exist, we manually iterate over your dataset trying to extract the labels when creating the splits (this is done as part of the `.setup` phase of the datamodule). By default we assume that, given a `batch`, the labels can be found as the second element, e.g. `batch[1]`. You can adjust this by importing the specialized `KFoldDataModule` and changing the `label_extractor` attribute. For example, if your batches are dictionaries instead, you can do something like this:

    ```python
    from pl_crossvalidate import KFoldDataModule, KFoldTrainer

    model = ...  # your LightningModule
    my_train_dataloader = ...  # your training DataLoader, yielding dict batches with key 'y'

    trainer = KFoldTrainer(num_folds=5, shuffle=False, stratified=True)
    # num_folds, shuffle and stratified should match how the trainer is initialized
    datamodule = KFoldDataModule(
        5,  # num_folds
        False,  # shuffle
        True,  # stratified
        train_dataloader=my_train_dataloader,
    )
    # Change the label extractor function, so that it returns the labels for a given batch
    datamodule.label_extractor = lambda batch: batch["y"]

    trainer.cross_validate(model, datamodule=datamodule)
    ```
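The label extraction and the stratified split itself can be sketched without any Lightning dependency. Everything below is hypothetical and for illustration only: `extract_labels` mimics the behaviour described in the notes (use a `labels` attribute if present, otherwise apply a `label_extractor` to each batch, defaulting to `batch[1]`), and `stratified_folds` is a simple round-robin assignment per class that we chose for clarity; it is not the splitting code used by `pl_crossvalidate`.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Sequence


def extract_labels(
    dataset: Any,
    batches: Iterable[Any],
    label_extractor: Callable[[Any], Sequence[int]] = lambda batch: batch[1],
) -> List[int]:
    """Hypothetical helper: use a `labels` attribute if the dataset has one,
    otherwise apply `label_extractor` to every batch and concatenate the results."""
    if hasattr(dataset, "labels"):
        return list(dataset.labels)
    labels: List[int] = []
    for batch in batches:
        labels.extend(label_extractor(batch))
    return labels


def stratified_folds(labels: Sequence[int], num_folds: int) -> List[List[int]]:
    """Hypothetical helper: deal the indices of each class round-robin over the folds,
    so every fold ends up with (roughly) the same class proportions."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds: List[List[int]] = [[] for _ in range(num_folds)]
    position = 0  # keeps counting across classes so fold sizes stay balanced
    for cls in sorted(by_class):
        for idx in by_class[cls]:
            folds[position % num_folds].append(idx)
            position += 1
    return [sorted(fold) for fold in folds]


# Toy dict batches, with the labels stored under the key "y"
dict_batches = [
    {"x": [0.1, 0.2], "y": [0, 1]},
    {"x": [0.3, 0.4], "y": [0, 1]},
    {"x": [0.5, 0.6], "y": [0, 1]},
]
# object() has no `labels` attribute, so the extractor is used instead
labels = extract_labels(object(), dict_batches, label_extractor=lambda batch: batch["y"])
print(labels)  # [0, 1, 0, 1, 0, 1]
print(stratified_folds(labels, num_folds=3))  # [[0, 1], [2, 3], [4, 5]]
```

Every fold in the example ends up with one sample of each class, which is the property stratification is meant to guarantee.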
😃 Bibtex
If you want to cite the framework feel free to use this:
```bibtex
@article{software:pl_crossvalidate,
  title={PL Crossvalidate},
  author={Nicki S. Detlefsen},
  journal={GitHub. Note: https://github.com/SkafteNicki/pl_crossvalidate},
  year={2023}
}
```
File details
Details for the file `pl_crossvalidate-0.1.0.tar.gz`.
File metadata
- Download URL: pl_crossvalidate-0.1.0.tar.gz
- Size: 52.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.3
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 2e28a672f5fc1e848fc6a1b649edc3d6f1429e897f95cb5b6dc59809ce740b13 |
MD5 | b88c792010593615f9964766a08015ea |
BLAKE2b-256 | 066818a1b167c8b9618c4cf946ec8255c81996e463b700dc2aa63fc85ff7bfa5 |
File details
Details for the file `pl_crossvalidate-0.1.0-py3-none-any.whl`.
File metadata
- Download URL: pl_crossvalidate-0.1.0-py3-none-any.whl
- Size: 19.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.3
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | ee095ded6979a4e6b8d5737c7926c0183bf9461ce99857a5c766a75cba19914f |
MD5 | a7f4e69d910e2b27420a77515691fd1a |
BLAKE2b-256 | d63cc970b3831103d54ccebf0dad70e33b2adce71dea46c34959e95572e6351d |