PyTorch training utilities

Project description

PyTorch Essentials

A collection of useful utilities for PyTorch model training, evaluation, and experiment management.

Features

  • Training Loop: Ready-to-use training and evaluation loops with progress tracking
  • Early Stopping: Prevent overfitting with a configurable early stopping callback
  • Visualization: Plot training curves and metrics
  • Model Management: Save/load models with metadata and results
  • Configuration: YAML-based configuration management
  • Utilities: Device detection, seed setting, parameter counting
  • Weights & Biases Integration: Optional logging of classification metrics, including precision, recall, F1-score, and confusion matrices

Installation

From PyPI (recommended)

pip install pytorch-essentials

With optional Weights & Biases support:

pip install "pytorch-essentials[wandb]"

For development:

pip install "pytorch-essentials[dev]"

From source

git clone https://github.com/yourusername/pytorch_essentials.git
cd pytorch_essentials
pip install -e .

Requirements:

  • Python 3.8+
  • PyTorch 2.0+
  • See pyproject.toml for full list of dependencies

Quick Start

Basic Usage

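import torch
from torch import nn

from pytorch_essentials import (
    train, get_device, EarlyStopping,
    plot_loss_curves, set_seeds, load_config
)

# Set random seeds for reproducibility
set_seeds(42)

# Pick the best available device (CUDA > MPS > CPU)
device = get_device()

# Load the YAML configuration; train() takes it as a required argument
config = load_config('config.yaml')

# Initialize model, optimizer, and loss.
# YourModel, train_loader, and val_loader are placeholders for your own
# nn.Module subclass and torch.utils.data.DataLoader objects.
model = YourModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Optional: early stopping
early_stopping = EarlyStopping(patience=5, delta=0.001)

# Train
results = train(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    config=config,
    early_stopping=early_stopping
)

# Plot the loss/accuracy curves and save them to disk
fig = plot_loss_curves(results)
fig.savefig('training_curves.png')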

With Configuration File

Create a config.yaml:

project_name: my_project
hyperparameters:
  learning_rate: 0.001  # PyYAML (YAML 1.1) would load 1e-3 as a string, not a float
  batch_size: 64
  epochs: 100
flags:
  use_subset: false
  save_model: true
  debug: false
  use_wandb: false

Then in your code:

from pytorch_essentials import load_config, print_config

config = load_config('config.yaml')
print_config(config)

Examples

Check the examples/ directory for complete working examples:

  • examples/basic_training.py - Complete MNIST training example
  • examples/config.yaml - Example configuration file

Run the example:

cd examples
python basic_training.py

API Reference

Training Functions

train(model, train_dataloader, val_dataloader, optimizer, loss_fn, num_epochs, device, config, early_stopping=None, scheduler=None)

Main training loop with validation.

Returns: Dictionary with training history including:

  • train_loss: List of training losses
  • train_acc: List of training accuracies
  • test_loss: List of validation losses
  • test_acc: List of validation accuracies
  • best_epoch: Best epoch number (if early stopping used)
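The returned history is plain Python data, so it can be post-processed directly. The sketch below assumes `results` holds the keys listed above with one value per completed epoch, and picks the epoch with the lowest validation loss (stored under `test_loss`). The helper `best_val_epoch` and the sample numbers are hypothetical and are not part of pytorch_essentials.

```python
from typing import Dict, List


def best_val_epoch(results: Dict[str, List[float]]) -> int:
    """Return the 1-based epoch with the lowest validation loss.

    Hypothetical helper: assumes results["test_loss"] holds one
    validation loss per epoch, in training order.
    """
    losses = results["test_loss"]
    if not losses:
        raise ValueError("results['test_loss'] is empty")
    # Index of the smallest loss; ties resolve to the earliest epoch
    best_idx = min(range(len(losses)), key=losses.__getitem__)
    return best_idx + 1


# Made-up history, for illustration only
history = {
    "train_loss": [0.90, 0.60, 0.45, 0.40],
    "train_acc": [0.70, 0.80, 0.85, 0.87],
    "test_loss": [0.80, 0.55, 0.58, 0.61],
    "test_acc": [0.72, 0.82, 0.81, 0.80],
}
epoch = best_val_epoch(history)
print(f"Best epoch: {epoch}, val acc: {history['test_acc'][epoch - 1]:.2f}")
```

Note that the sketch reports epochs as 1-based; whether the package's own `best_epoch` is 0- or 1-based is not stated here.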

evaluate_model(model, test_dataloader, loss_fn, device, class_names, log_to_wandb=False)

Evaluate a model on the test set, reporting a confusion matrix and classification metrics.

Args:

  • log_to_wandb: If True, log all metrics to Weights & Biases including:
    • Test loss and accuracy
    • Precision, recall, F1-score (macro and weighted)
    • Per-class metrics
    • Confusion matrix visualization
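For reference, macro averaging takes the unweighted mean of the per-class scores, while weighted averaging weights each class by its support (its number of true samples). The standalone sketch below computes both from a confusion matrix whose rows are true classes and columns are predicted classes. That layout matches scikit-learn's `confusion_matrix` convention, but it is an assumption about what `evaluate_model` uses internally, and `classification_metrics` is a hypothetical helper, not a function from this package.

```python
from typing import Dict, List


def classification_metrics(cm: List[List[int]]) -> Dict[str, float]:
    """Per-class precision/recall/F1, averaged macro- and support-weighted.

    cm[i][j] = number of samples with true class i predicted as class j.
    A score that would be 0/0 is set to 0.0.
    """
    n = len(cm)
    support = [sum(row) for row in cm]                               # true count per class
    predicted = [sum(cm[i][j] for i in range(n)) for j in range(n)]  # predicted count per class
    total = sum(support)
    if n == 0 or total == 0:
        raise ValueError("confusion matrix is empty")

    precision, recall, f1 = [], [], []
    for k in range(n):
        tp = cm[k][k]
        p = tp / predicted[k] if predicted[k] else 0.0
        r = tp / support[k] if support[k] else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        precision.append(p)
        recall.append(r)
        f1.append(f)

    def macro(scores: List[float]) -> float:
        return sum(scores) / n

    def weighted(scores: List[float]) -> float:
        return sum(s * w for s, w in zip(scores, support)) / total

    return {
        "precision_macro": macro(precision),
        "recall_macro": macro(recall),
        "f1_macro": macro(f1),
        "precision_weighted": weighted(precision),
        "recall_weighted": weighted(recall),
        "f1_weighted": weighted(f1),
    }


cm = [[5, 1], [2, 2]]
for name, value in classification_metrics(cm).items():
    print(f"{name}: {value:.3f}")
```

A useful sanity check: support-weighted recall always equals overall accuracy, which is 7/10 for the matrix above.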

Callbacks

EarlyStopping(patience=5, delta=0.01, verbose=True)

Early stopping to prevent overfitting.

Args:

  • patience: Number of epochs without improvement to tolerate before stopping
  • delta: Minimum change in the monitored metric that counts as an improvement
  • verbose: If True, print status messages
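A common way to implement these three parameters is sketched below: the callback tracks the best validation loss seen so far, counts a new loss as an improvement only if it beats that best by more than `delta`, and signals a stop after `patience` consecutive epochs without improvement. This is a generic, framework-free sketch; the class `SimpleEarlyStopping` and its `step` method are hypothetical and are not the package's actual `EarlyStopping` interface, although the defaults mirror the documented signature.

```python
class SimpleEarlyStopping:
    """Hypothetical early-stopping logic for a lower-is-better metric such as validation loss."""

    def __init__(self, patience: int = 5, delta: float = 0.01, verbose: bool = True):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = float("inf")
        self.counter = 0          # consecutive epochs without improvement
        self.early_stop = False

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True once training should stop."""
        if val_loss < self.best_loss - self.delta:
            # Improvement larger than delta: remember it and reset the counter
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epochs")
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


stopper = SimpleEarlyStopping(patience=2, delta=0.01)
for epoch, loss in enumerate([1.0, 0.8, 0.795, 0.799, 0.7], start=1):
    if stopper.step(loss):
        print(f"Stopping after epoch {epoch}")
        break
```

Here 0.795 and 0.799 do not beat the best loss of 0.8 by more than `delta`, so training stops after epoch 4 even though epoch 5 would have improved.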

Utilities

get_device()

Return best available device (CUDA > MPS > CPU).

save_model(model, fig, results, save_path)

Save model weights, plots, and training results.

count_parameters(model)

Count and display model parameters.

set_seeds(seed=42)

Set random seeds for reproducibility.

load_config(config_path)

Load YAML configuration file.

Visualization

plot_loss_curves(results)

Plot training and validation loss/accuracy curves.

Returns: Matplotlib figure

print_train_time(start, end, device=None)

Print elapsed training time.
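Elapsed-time reporting usually pairs two `time.perf_counter()` readings taken before and after training. The sketch below mirrors the documented `(start, end, device=None)` signature, but its body, return value, and message format are assumptions rather than the package's actual output; the name `report_train_time` is hypothetical.

```python
import time
from typing import Optional


def report_train_time(start: float, end: float, device: Optional[object] = None) -> float:
    """Hypothetical helper: print and return elapsed seconds between two timer readings."""
    total = end - start
    # device may be a string or a torch.device; both format cleanly in an f-string
    where = f" on {device}" if device is not None else ""
    print(f"Train time{where}: {total:.3f} seconds")
    return total


start = time.perf_counter()
sum(i * i for i in range(100_000))  # stand-in for a training loop
end = time.perf_counter()
elapsed = report_train_time(start, end, device="cpu")
```

`time.perf_counter()` is preferred over `time.time()` for this because it is monotonic and has the highest available resolution.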

Project Structure

pytorch_essentials/
├── pytorch_essentials/         # Main package
│   ├── __init__.py            # Package exports
│   ├── engine.py              # Training/evaluation loops
│   ├── callbacks.py           # Early stopping callback
│   ├── utils.py               # Utility functions
│   └── visualization.py       # Plotting functions
├── examples/                   # Example scripts
│   ├── basic_training.py      # MNIST example
│   └── config.yaml            # Example config
├── requirements.txt           # Dependencies
├── README.md                  # This file
└── LICENSE                    # License file

Download files

Source Distribution

pytorch_essentials-0.1.1.tar.gz (18.1 kB)

Uploaded Source

Built Distribution

pytorch_essentials-0.1.1-py3-none-any.whl (12.3 kB)

Uploaded Python 3

File details

Details for the file pytorch_essentials-0.1.1.tar.gz.

File metadata

  • Download URL: pytorch_essentials-0.1.1.tar.gz
  • Size: 18.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.3

File hashes

Hashes for pytorch_essentials-0.1.1.tar.gz
Algorithm     Hash digest
SHA256        047d6006f2da9b609a4b91d258f42ba69b060e1cf7f1efb5e4abf328e05102df
MD5           33b2c30992ef4895b2e0296c6a638036
BLAKE2b-256   6c0ed5cbc4d42f2ea1c6ae05a7469588d00082a2fd70c6c4ee42dfa2d6b5c795

File details

Details for the file pytorch_essentials-0.1.1-py3-none-any.whl.

File hashes

Hashes for pytorch_essentials-0.1.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        7b96d06b93c396a9a66769cf38f7ced1b1b838823fa05eb4a4717f69475957fd
MD5           847447aa5399071b7cf8759cb91736b8
BLAKE2b-256   f066b45de0e5b14caf191cb66ab0a703f38dc6f681312e91b0ae7c39e25ac90c
