A PyTorch utility package for Early Stopping

Project description

Early Stopping for PyTorch

Early stopping is a form of regularization used to avoid overfitting on the training dataset. It keeps track of the validation loss, and if the loss stops decreasing for a given number of consecutive epochs, training stops. The EarlyStopping class in early_stopping_pytorch/early_stopping.py creates an object that tracks the validation loss while a PyTorch model is training, and it saves a checkpoint of the model each time the validation loss decreases. The patience argument of EarlyStopping sets how many epochs to wait after the last improvement in validation loss before breaking out of the training loop. The MNIST_Early_Stopping_example notebook contains a simple example of how to use the EarlyStopping class.
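The tracking logic described above can be sketched in plain Python, without PyTorch. This is a minimal illustration, not the package's implementation: the class name PatienceTracker and its save_checkpoint callback are hypothetical, and "improvement" is assumed to mean a strictly lower validation loss than the best seen so far.

```python
from typing import Callable, Optional


class PatienceTracker:
    """Hypothetical sketch of patience-based early stopping (not the library's code)."""

    def __init__(self, patience: int, save_checkpoint: Optional[Callable[[int], None]] = None):
        self.patience = patience
        self.save_checkpoint = save_checkpoint  # hypothetical hook, called with the epoch on improvement
        self.best_loss = float("inf")
        self.counter = 0  # consecutive epochs without improvement
        self.early_stop = False

    def __call__(self, val_loss: float, epoch: int) -> None:
        if val_loss < self.best_loss:
            # New best loss: remember it, reset the counter and checkpoint
            self.best_loss = val_loss
            self.counter = 0
            if self.save_checkpoint is not None:
                self.save_checkpoint(epoch)
        else:
            # No improvement: count it and flag a stop once patience is used up
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Example run with made-up losses and patience=2
saved = []
tracker = PatienceTracker(patience=2, save_checkpoint=saved.append)
for epoch, loss in enumerate([0.9, 0.7, 0.6, 0.65, 0.62, 0.5]):
    tracker(loss, epoch)
    if tracker.early_stop:
        break
print(saved, epoch)  # [0, 1, 2] 4
```

The loop stops at epoch 4 and never sees the 0.5 at epoch 5, which shows the trade-off patience controls: a larger value tolerates longer plateaus at the cost of more training epochs.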

Below is a plot from the example notebook showing the last checkpoint made by the EarlyStopping object, saved right before the model started to overfit. The run used a patience of 20.
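The relationship between the checkpoint and the stopping point can be checked with a small simulation, assuming a checkpoint is written whenever the validation loss reaches a new minimum and the loop breaks once patience consecutive epochs pass without one. The function simulate_early_stopping and the synthetic loss curve are hypothetical illustrations, not code or data from the notebook.

```python
def simulate_early_stopping(val_losses, patience):
    """Return (checkpoint_epoch, stop_epoch) for per-epoch validation losses.

    Hypothetical helper; assumes finite losses. stop_epoch is None if
    patience is never exhausted.
    """
    best_loss = float("inf")
    best_epoch = None
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # a checkpoint would be saved here
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch  # patience exhausted: the loop breaks
    return best_epoch, None


# Synthetic curve: loss falls for 10 epochs, then rises (overfitting)
losses = [1.0 - 0.05 * e for e in range(10)] + [0.55 + 0.01 * e for e in range(1, 40)]
print(simulate_early_stopping(losses, patience=20))  # (9, 29)
```

Under these assumptions the saved checkpoint always trails the stopping epoch by exactly the patience value, so with patience=20 the model that is kept is 20 epochs older than the last one trained.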

Loss Plot

Installation

Option 1: Install from PyPI (Recommended)

pip install early-stopping-pytorch

Option 2: Install from Source

For development or if you want the latest unreleased changes:

1. Clone the Repository

git clone https://github.com/your_username/early-stopping-pytorch.git
cd early-stopping-pytorch

2. Set Up the Virtual Environment

Run the setup script to create a virtual environment and install all necessary dependencies.

./setup_dev_env.sh

3. Activate the Virtual Environment

Activate the virtual environment:

source dev-venv/bin/activate

4. Install the Package in Editable Mode

Install the package locally in editable mode so you can use it immediately:

pip install -e .

Usage

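import torch
from torch import nn
from early_stopping_pytorch import EarlyStopping

# Toy data and model so the snippet runs as-is; replace with your own
x_train, y_train = torch.randn(256, 10), torch.randn(256, 1)
x_val, y_val = torch.randn(64, 10), torch.randn(64, 1)
model, loss_fn = nn.Linear(10, 1), nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_epochs = 100

# Initialize early stopping object
early_stopping = EarlyStopping(patience=7, verbose=True)

# In your training loop:
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    loss_fn(model(x_train), y_train).backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():  # calculate validation loss
        val_loss = loss_fn(model(x_val), y_val).item()

    # Early stopping call: checkpoints the model when val_loss improves
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break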

For a complete example, see the MNIST Early Stopping Example Notebook.

Citation

If you find this package useful in your research, please consider citing it as:

@misc{early_stopping_pytorch,
  author = {Bjarte Mehus Sunde},
  title = {early-stopping-pytorch: A PyTorch utility package for Early Stopping},
  year = {2024},
  url = {https://github.com/Bjarten/early-stopping-pytorch},
}

References

The EarlyStopping class in early_stopping_pytorch/early_stopping.py is inspired by the EarlyStopping handler in PyTorch Ignite.

Download files

Download the file for your platform.

Source Distribution

early_stopping_pytorch-1.0.10.tar.gz (7.4 kB)

Built Distribution

early_stopping_pytorch-1.0.10-py3-none-any.whl (4.6 kB, Python 3)

File details

Details for the file early_stopping_pytorch-1.0.10.tar.gz.

File metadata

  • Download URL: early_stopping_pytorch-1.0.10.tar.gz
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for early_stopping_pytorch-1.0.10.tar.gz
Algorithm Hash digest
SHA256 8f08302b3fb85307c386b4313cc1eb58b6e46f4dd72661baaba2c341b3d1f345
MD5 c9b8aa3e31b0e0198ad524fd9f2c9e71
BLAKE2b-256 5a498c5a73d7e8a5efc800c83e1e3fc4fe772e542f4611f50b0d5d154b261ccf
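To confirm that a downloaded file matches the SHA256 digest published above, one option is Python's standard-library hashlib module. This is a sketch: the helper name sha256_of_file is hypothetical, and the script assumes the archive sits in the current working directory.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path, chunk_size=1 << 16):
    """Hypothetical helper: hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read() returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest listed above for the source distribution
EXPECTED_SDIST_SHA256 = "8f08302b3fb85307c386b4313cc1eb58b6e46f4dd72661baaba2c341b3d1f345"

archive = Path("early_stopping_pytorch-1.0.10.tar.gz")  # assumed location
if archive.exists():
    actual = sha256_of_file(archive)
    print("OK" if actual == EXPECTED_SDIST_SHA256 else f"MISMATCH: {actual}")
```

The same function works for the wheel with its own digest. pip can also enforce digests at install time through its hash-checking mode, by adding a --hash=sha256:<digest> option to a pinned line in a requirements file.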

Provenance

The following attestation bundles were made for early_stopping_pytorch-1.0.10.tar.gz:

Publisher: publish.yml on Bjarten/early-stopping-pytorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file early_stopping_pytorch-1.0.10-py3-none-any.whl.

File hashes

Hashes for early_stopping_pytorch-1.0.10-py3-none-any.whl
Algorithm Hash digest
SHA256 6015220ec6a191d9bec1cf1a58c89e4d9ae3cef9ff74b060433585e62d18a20b
MD5 ca15e2e97fb6154cf9492153c32f4617
BLAKE2b-256 a515451a596c1bcf0d8efdebde5eb21f646bd879929939560ed471229b52fa73

Provenance

The following attestation bundles were made for early_stopping_pytorch-1.0.10-py3-none-any.whl:

Publisher: publish.yml on Bjarten/early-stopping-pytorch