A flexible PyTorch training engine with interface-based design
PyTorch TrainEngine
A flexible and modular PyTorch training engine with interface-based design.
Features
- Interface-based Architecture: Define trainable models through the ITrainable interface contract
- Automatic Hardware Detection: Automatically detects and uses CUDA if available
- Lazy Optimizer Initialization: Optimizers are initialized only when training starts
- Standard Training Pipeline: Implements the canonical PyTorch training loop
- Easy Integration: Simple integration with custom models and datasets
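The hardware-detection feature most likely follows the common PyTorch idiom of preferring CUDA and falling back to CPU. The sketch below shows that idiom in plain PyTorch; `pick_device` is a hypothetical helper, not part of pytorch_trainengine, and the package's actual detection logic may differ.

```python
import torch
from torch import nn


def pick_device(device=None):
    """Hypothetical helper: honor an explicit device, else prefer CUDA, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()
model = nn.Linear(10, 1).to(device)         # module parameters are moved in place
batch = torch.randn(4, 10, device=device)   # inputs must live on the same device
out = model(batch)
```

Whatever device is chosen, the model and every batch must end up on it; otherwise the forward pass raises a device-mismatch error.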
Installation
```bash
pip install pytorch-TrainEngine
```
Quick Start
```python
import torch
from pytorch_trainengine import ITrainable, TrainingEngine

# Define your model
class MyModel(ITrainable):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward_pass(self, x):
        return self.linear(x)

    def parameters(self):
        return self.linear.parameters()

# Create training engine
criterion = torch.nn.MSELoss()
engine = TrainingEngine(criterion=criterion, lr=0.01)

# Train your model
model = MyModel()
x = torch.randn(32, 10)
y = torch.randn(32, 1)
loss = engine.train_step(model, x, y)
print(f"Loss: {loss}")
```
API Reference
ITrainable
Abstract base class for trainable models.
Methods:
- forward_pass(x): Forward pass logic
- parameters(): Return the model parameters for the optimizer
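An interface contract like this is typically expressed in Python with the standard abc module. The sketch below is a hypothetical stand-in (`TrainableContract`, `MissingParameters`, `Doubler` and `is_instantiable` are illustrative names, not the package's code); the real ITrainable may be defined differently, for example as a torch.nn.Module subclass, which the super().__init__() call in the quick start hints at but does not confirm.

```python
from abc import ABC, abstractmethod


class TrainableContract(ABC):
    """Hypothetical stand-in for ITrainable; the real base class may differ."""

    @abstractmethod
    def forward_pass(self, x):
        """Map an input batch to predictions."""

    @abstractmethod
    def parameters(self):
        """Return an iterable of tensors for the optimizer to update."""


class MissingParameters(TrainableContract):
    # Implements only forward_pass, so the class stays abstract.
    def forward_pass(self, x):
        return x


class Doubler(TrainableContract):
    # Implements both methods, so the class can be instantiated.
    def forward_pass(self, x):
        return [2 * v for v in x]

    def parameters(self):
        return []  # no trainable state in this toy example


def is_instantiable(cls):
    """True when cls leaves no abstract method unimplemented."""
    return not cls.__abstractmethods__
```

With an abc-based contract, calling MissingParameters() raises TypeError at construction time, so a model that forgets one of the two methods fails early rather than inside the training loop.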
TrainingEngine
Main training engine class.
Parameters:
- criterion: Loss function
- optimizer_cls: Optimizer class (default: torch.optim.SGD)
- lr: Learning rate (default: 0.01)
- device: Computation device (auto-detected if None)
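Taking an optimizer class rather than an optimizer instance is what makes lazy initialization possible: a torch.optim optimizer needs the model's parameters at construction time. The sketch below assumes the engine eventually calls something like optimizer_cls(model.parameters(), lr=lr); `build_optimizer` is a hypothetical helper, and whether TrainingEngine accepts a functools.partial for extra hyperparameters depends on that assumed call, which the README does not document.

```python
import functools

import torch
from torch import nn


def build_optimizer(model, optimizer_cls=torch.optim.SGD, lr=0.01):
    """Hypothetical helper: instantiate optimizer_cls once the model exists."""
    return optimizer_cls(model.parameters(), lr=lr)


model = nn.Linear(10, 1)
sgd = build_optimizer(model)  # defaults mirror the documented ones: SGD, lr=0.01
adam = build_optimizer(model, optimizer_cls=torch.optim.Adam, lr=1e-3)
# Extra hyperparameters can be pre-bound; lr is still supplied at build time.
momentum_sgd = build_optimizer(
    model, optimizer_cls=functools.partial(torch.optim.SGD, momentum=0.9)
)
```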
Methods:
- setup_optimizer(model): Initialize the optimizer (called automatically)
- train_step(model, x, y): Run a single training step
- train_epoch(model, dataloader): Train for one epoch
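The three methods fit together roughly as follows. This is a minimal sketch written from the parameter and method list above, not the package's source: `SketchEngine` and `TinyRegressor` are hypothetical names, and the assumptions that train_step returns a Python float and that train_epoch returns the mean batch loss are illustrative choices.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class SketchEngine:
    """Hypothetical minimal engine; not the package's TrainingEngine."""

    def __init__(self, criterion, optimizer_cls=torch.optim.SGD, lr=0.01, device=None):
        self.criterion = criterion
        self.optimizer_cls = optimizer_cls
        self.lr = lr
        # Use the given device, else CUDA when available, else CPU.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.optimizer = None  # built lazily on the first train_step

    def setup_optimizer(self, model):
        # Move parameters first, then hand them to the optimizer.
        model.to(self.device)
        self.optimizer = self.optimizer_cls(model.parameters(), lr=self.lr)

    def train_step(self, model, x, y):
        # Lazy init: the optimizer is tied to the first model seen.
        if self.optimizer is None:
            self.setup_optimizer(model)
        x, y = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()                      # clear old gradients
        loss = self.criterion(model.forward_pass(x), y)
        loss.backward()                                 # compute new gradients
        self.optimizer.step()                           # update parameters
        return loss.item()

    def train_epoch(self, model, dataloader):
        # Mean of the per-batch losses over one pass through the data.
        losses = [self.train_step(model, x, y) for x, y in dataloader]
        return sum(losses) / len(losses)


class TinyRegressor(nn.Module):
    """Hypothetical model exposing the forward_pass() method the engine calls."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward_pass(self, x):
        return self.linear(x)


torch.manual_seed(0)
x = torch.randn(64, 10)
y = x @ torch.randn(10, 1)  # noiseless linear target
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

engine = SketchEngine(nn.MSELoss(), lr=0.05)
model = TinyRegressor()
optimizer_before = engine.optimizer  # None: nothing built yet
first_epoch_loss = engine.train_epoch(model, loader)
for _ in range(20):
    last_epoch_loss = engine.train_epoch(model, loader)
```

Because the sketch binds its optimizer to the first model it sees, training a second model with the same engine instance would need a fresh engine or an explicit setup_optimizer call; the README does not say how TrainingEngine handles that case.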
Requirements
- Python >= 3.8
- torch >= 1.9.0
License
MIT
Author
Download files
File details
Details for the file pytorch_trainengine-0.1.0.tar.gz.
File metadata
- Download URL: pytorch_trainengine-0.1.0.tar.gz
- Upload date:
- Size: 4.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 67f73f876e54751c4c955d9a0e708abd4ca0b6bd6136319b8223d36d9e973ca1 |
| MD5 | 37425354e61f6d11d31cec9c0374168b |
| BLAKE2b-256 | 20d2db899aa42c133e52c0b9f599a53a0e00a7e64bd276620972cf9dbc43828a |
File details
Details for the file pytorch_trainengine-0.1.0-py3-none-any.whl.
File metadata
- Download URL: pytorch_trainengine-0.1.0-py3-none-any.whl
- Upload date:
- Size: 4.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b76f41d30690772eedd83e6734c12c8b3823a4ca9f0511e698301a4f0ffbd19b |
| MD5 | a67773e5462004886a8b757ef4396be7 |
| BLAKE2b-256 | 1caff24b701c442c3b08424598b1df0b66880d33aa65a3cc42af8f15af8c8f47 |