minimalistic convenience utilities for using PyTorch
Project description
ConvinenceTorch - A Lightweight and Minimalistic Convenience Library for PyTorch
ConvinenceTorch is a Python package that serves as a lightweight and minimalistic convenience library for PyTorch, a popular deep learning framework. It provides a collection of utility modules and classes to simplify and streamline the training process of neural networks in PyTorch.
Features
-
NNtrainer: ConvinenceTorch offers a flexible and extensible
NNtrainer
class designed for training neural networks in PyTorch. TheNNtrainer
handles various common training tasks, such as model initialization, optimization, loss computation, device management, and training/validation loops. By utilizing theNNtrainer
class, developers can focus on defining their models and customizing the training process without being burdened by boilerplate code. -
Memory Usage Report: ConvinenceTorch provides a memory usage report feature that allows users to monitor the memory consumption during the training process. The mem_report() function in NNtrainer provides valuable insights into the memory usage, helping users optimize their models and data loaders efficiently.
-
Metrics and Loss capture: ConvinenceTorch offers a comprehensive and flexible mechanism for capturing and tracking metrics and losses during the training process. The NNtrainer class allows users to easily specify the metrics they wish to monitor, such as accuracy, F1 score, or any custom evaluation metric. Additionally, users can choose to record the training and validation losses for further analysis.
-
Stochastic Weight Averaging (SWA) Support: NNtrainer supports Stochastic Weight Averaging, a technique that can improve model generalization. The class includes functionalities for setting SWA models, starting epochs, and SWA-specific validation.
-
Model Checkpointing: ConvinenceTorch includes model checkpointing functionality, allowing users to save the model's state at specified intervals during training or after every epoch. This feature facilitates easy resumption of training from the last saved checkpoint, enhancing training reliability and fault tolerance.
-
Inference Support : With NNtrainer, users can make predictions using the trained model on new data, enabling them to leverage their models for inference tasks.
Installation
This library is designed with a minimalistic approach to cater to educational applications and general research purposes. It boasts well-documented docstrings for every component, ensuring clarity and ease of understanding throughout its usage.
For those seeking to extend its functionalities, performing an editable install and making necessary tweaks to the library is highly recommended. Consequently, an editable install is the preferred method, allowing users to seamlessly customize the library to suit their specific requirements and leverage its full potential.
$ pip install convinence-torch
Usage
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
# Enable debugging
import os
os.environ['LOG'] = '10'
# Import trainer
from torchutils.trainer import NNtrainer
# Setup your data loaders
train_data = torchvision.datasets.FashionMNIST('./data/', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.FashionMNIST('./data/', train=False, transform=torchvision.transforms.ToTensor(), download=True)
trainloader = DataLoader(train_data, batch_size=B, shuffle=True, num_workers=5)
test_loader = DataLoader(test_data, batch_size=B)
# Constants
D = 28 * 28
n = 32
C = 1
classes = 10
eta_0 = 0.005
eta_1 = 0.001
# Define your PyTorch model, optimizer, and loss function
fc_model = nn.Sequential(
nn.Conv2d(1, n, 3, padding=1), # 28 * 28
nn.LeakyReLU(0.1),
nn.Conv2d(n, n, 3, padding=1), # 28 * 28
nn.LeakyReLU(0.1),
nn.MaxPool2d(2), # 14 * 14
nn.Conv2d(n, n // 2, 3, padding=1),
nn.LeakyReLU(0.1),
nn.Conv2d(n // 2, n // 2, 3, padding=1),
nn.LeakyReLU(0.1),
nn.MaxPool2d(2), # 7 * 7
nn.Conv2d(n // 2, n // 4, 3, padding=1),
nn.LeakyReLU(0.1),
nn.Flatten(),
nn.Linear(7 * 7 * n // 4, 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, 10)
)
# Create an instance of NNtrainer
trainer = NNtrainer(model, optimizer, loss)
# See the memory usage report
trainer.mem_report(batch_shape=(33, 1, 28, 28))
# Add Cosine Annealing
trainer.add_cosineannealing_lrs(lr_initial=eta_0, lr_final=eta_1, dips=5, epoch=20, verbose=True)
# Add Gradient Clipping
trainer.add_gradient_clipping(min=-100, max=100)
# Compile the model before training
trainer.compile()
# Train the model
trainer.train(trainloader=trainloader, valloader=valloader, epoch=20, metrics=['accuracy', 'f1'], record_loss=True, checkpoint_file='train')
# Plot the training and validation metric curves
trainer.plot_train_validation_metric_curve('accuracy')
Contributing
Contributions are welcome! If you find any issues or have suggestions for new features, please open an issue or submit a pull request.
License
This project is licensed under the MIT License.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file convinence-torch-1.0.0.tar.gz
.
File metadata
- Download URL: convinence-torch-1.0.0.tar.gz
- Upload date:
- Size: 969.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.12
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | e9af473d47a0bda635290b84d51b54657b7b9c189bf170c5d21cf601c94ccbe3 |
|
MD5 | b900b0271c8c65a0179d3d7a3686c873 |
|
BLAKE2b-256 | a4cbea36b7ce96e2a6785a5a0e9e63f730a3b7142ccf47c7c3c09bddc0fd1f99 |
File details
Details for the file convinence_torch-1.0.0-py3-none-any.whl
.
File metadata
- Download URL: convinence_torch-1.0.0-py3-none-any.whl
- Upload date:
- Size: 27.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.12
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | c96984bf721a080c419f8d6f736c9c76e6ca55bb3ab57ab82ece9734a9f5345e |
|
MD5 | aac0c266aa1684c573d2f91886e7269a |
|
BLAKE2b-256 | a59e25646a8abc1074d2a2b2b126c81ca1bb092d2e0f999b4ae11b29ebfa5169 |