PyTorch Essentials
A collection of useful utilities for PyTorch model training, evaluation, and experiment management.
Features
- Training Loop: Ready-to-use training and evaluation loops with progress tracking
- Early Stopping: Prevent overfitting with configurable early stopping callback
- Visualization: Plot training curves and metrics
- Model Management: Save/load models with metadata and results
- Configuration: YAML-based configuration management
- Utilities: Device detection, seed setting, parameter counting
- Wandb Integration: Comprehensive logging of classification metrics including precision, recall, F1-score, and confusion matrices
Installation
From PyPI (recommended)
pip install pytorch-essentials
With optional Weights & Biases support:
pip install pytorch-essentials[wandb]
For development:
pip install pytorch-essentials[dev]
From source
git clone https://github.com/yourusername/pytorch_essentials.git
cd pytorch_essentials
pip install -e .
Requirements:
- Python 3.8+
- PyTorch 2.0+
- See pyproject.toml for the full list of dependencies
Quick Start
Basic Usage
import torch
import torch.nn as nn

from pytorch_essentials import (
    train, get_device, EarlyStopping,
    plot_loss_curves, set_seeds, load_config
)

# Set seeds for reproducibility
set_seeds(42)

# Get the best available device
device = get_device()

# Load the run configuration (see "With Configuration File" below)
config = load_config('config.yaml')

# Initialize model, optimizer, loss
model = YourModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Optional: early stopping
early_stopping = EarlyStopping(patience=5, delta=0.001)
# Train
results = train(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    config=config,
    early_stopping=early_stopping
)
# Visualize results
fig = plot_loss_curves(results)
fig.savefig('training_curves.png')
With Configuration File
Create a config.yaml:
project_name: my_project
hyperparameters:
  learning_rate: 1e-3
  batch_size: 64
  epochs: 100

flags:
  use_subset: false
  save_model: true
  debug: false
  use_wandb: false
Then in your code:
from pytorch_essentials import load_config, print_config
config = load_config('config.yaml')
print_config(config)
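Assuming the loaded configuration is a plain dict that mirrors the YAML structure (an assumption; the actual return type of load_config is not shown here), values can be read with ordinary key access:

```python
# Hypothetical: a dict with the same shape as the config.yaml above.
config = {
    "project_name": "my_project",
    "hyperparameters": {"learning_rate": 1e-3, "batch_size": 64, "epochs": 100},
    "flags": {"use_subset": False, "save_model": True},
}

# Pull out individual hyperparameters for use in training code.
lr = config["hyperparameters"]["learning_rate"]
batch_size = config["hyperparameters"]["batch_size"]
```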
Examples
Check the examples/ directory for complete working examples:
- examples/basic_training.py - Complete MNIST training example
- examples/config.yaml - Example configuration file
Run the example:
cd examples
python basic_training.py
API Reference
Training Functions
train(model, train_dataloader, val_dataloader, optimizer, loss_fn, num_epochs, device, config, early_stopping=None, scheduler=None)
Main training loop with validation.
Returns: Dictionary with training history including:
- train_loss: List of training losses
- train_acc: List of training accuracies
- test_loss: List of validation losses
- test_acc: List of validation accuracies
- best_epoch: Best epoch number (if early stopping used)
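The returned history can be consumed directly. For example, picking the epoch with the lowest validation loss from the documented keys (illustrative values, not real output):

```python
# Illustrative history with the documented keys; values are made up.
results = {
    "train_loss": [0.9, 0.5, 0.3, 0.25],
    "train_acc":  [0.60, 0.75, 0.85, 0.88],
    "test_loss":  [0.8, 0.6, 0.4, 0.45],
    "test_acc":   [0.65, 0.72, 0.84, 0.83],
}

# Epoch index (0-based) with the lowest validation loss.
best_epoch = min(range(len(results["test_loss"])),
                 key=results["test_loss"].__getitem__)
```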
evaluate_model(model, test_dataloader, loss_fn, device, class_names, log_to_wandb=False)
Evaluate model on test set with confusion matrix and metrics.
Args:
- log_to_wandb: If True, log all metrics to Weights & Biases, including:
  - Test loss and accuracy
  - Precision, recall, F1-score (macro and weighted)
  - Per-class metrics
  - Confusion matrix visualization
Callbacks
EarlyStopping(patience=5, delta=0.01, verbose=True)
Early stopping to prevent overfitting.
Args:
- patience: Number of epochs to wait for improvement
- delta: Minimum change to qualify as an improvement
- verbose: Print status messages
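A minimal sketch of how such a callback typically works (not the package's actual implementation): track the best validation loss seen so far and raise a stop flag once `patience` epochs pass without an improvement of at least `delta`.

```python
class EarlyStoppingSketch:
    """Illustrative early-stopping logic; a sketch, not the library's code."""

    def __init__(self, patience=5, delta=0.01, verbose=True):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = float("inf")
        self.counter = 0
        self.should_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss  # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epochs")
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop
```

Called once per epoch with the validation loss, it returns True when training should stop; a larger `delta` demands bigger improvements before the counter resets.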
Utilities
get_device()
Return best available device (CUDA > MPS > CPU).
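The selection order can be sketched as follows (a simplified stand-in for the library function; the real implementation may differ):

```python
def pick_device():
    """Return the best available device string: CUDA > MPS > CPU (sketch)."""
    try:
        import torch
    except ImportError:
        return "cpu"  # no PyTorch available at all
    if torch.cuda.is_available():
        return "cuda"
    # MPS backend exists only on recent PyTorch builds for Apple Silicon.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```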
save_model(model, fig, results, save_path)
Save model weights, plots, and training results.
count_parameters(model)
Count and display model parameters.
set_seeds(seed=42)
Set random seeds for reproducibility.
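Seeding typically covers every random number generator in play. A sketch of the idea, guarded so it also runs where NumPy or PyTorch are absent (the library's own function may seed additional sources):

```python
import random

def seed_everything(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs when those libraries are present."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass
```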
load_config(config_path)
Load YAML configuration file.
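A stand-in showing the usual shape of such a loader, assuming PyYAML (the package's actual parser isn't shown here). Note that bare scientific notation like `1e-3` is parsed by PyYAML as a string, so floats are safest written with a decimal point:

```python
import tempfile

import yaml  # PyYAML -- an assumption; not necessarily what the library uses

def read_config(path):
    """Illustrative stand-in: parse a YAML file into a plain dict."""
    with open(path) as f:
        return yaml.safe_load(f)

# Demo with a throwaway file.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("hyperparameters:\n  learning_rate: 0.001\n  batch_size: 64\n")
    path = f.name

cfg = read_config(path)
```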
Visualization
plot_loss_curves(results)
Plot training and validation loss/accuracy curves.
Returns: Matplotlib figure
print_train_time(start, end, device=None)
Print elapsed training time.
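The underlying computation is just a subtraction of timestamps; a hypothetical equivalent (names and exact message format are assumptions):

```python
def report_train_time(start, end, device=None):
    """Print a human-readable elapsed-time message and return the seconds."""
    elapsed = end - start
    where = f" on {device}" if device else ""
    print(f"Train time{where}: {elapsed:.3f} seconds")
    return elapsed
```

In practice `start` and `end` would come from a timer such as `time.time()` taken before and after the training loop.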
Project Structure
pytorch_essentials/
├── pytorch_essentials/ # Main package
│ ├── __init__.py # Package exports
│ ├── engine.py # Training/evaluation loops
│ ├── callbacks.py # Early stopping callback
│ ├── utils.py # Utility functions
│ └── visualization.py # Plotting functions
├── examples/ # Example scripts
│ ├── basic_training.py # MNIST example
│ └── config.yaml # Example config
├── requirements.txt # Dependencies
├── README.md # This file
└── LICENSE # License file