
A flexible PyTorch training engine with interface-based design


PyTorch TrainEngine

A flexible and modular PyTorch training engine with interface-based design.

Features

  • Interface-based Architecture: Define trainable models by implementing the ITrainable interface
  • Automatic Hardware Detection: Uses CUDA when it is available, with no device configuration required
  • Lazy Optimizer Initialization: The optimizer is created only when training starts, once the model's parameters are known
  • Standard Training Pipeline: Implements the canonical PyTorch training loop (zero gradients, forward pass, loss, backward pass, optimizer step)
  • Easy Integration: Works with custom models (via ITrainable) and custom datasets (via a dataloader)
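The "canonical PyTorch training loop" mentioned above follows a fixed order of operations. Below is a minimal plain-PyTorch sketch of that order, for reference. `canonical_train_step` is a hypothetical name. This is not TrainEngine's own code, only the standard pattern it describes.

```python
import torch
from torch import nn


def canonical_train_step(model, optimizer, criterion, x, y):
    """One optimization step in the usual PyTorch order (hypothetical helper)."""
    model.train()                # enable training-mode behaviour (dropout, batch norm)
    optimizer.zero_grad()        # clear gradients left over from the previous step
    output = model(x)            # forward pass
    loss = criterion(output, y)  # compute the loss
    loss.backward()              # backpropagate to fill .grad on each parameter
    optimizer.step()             # update the parameters
    return loss.item()           # return a plain Python float


torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
x, y = torch.randn(32, 10), torch.randn(32, 1)

# Repeated steps on one fixed batch should drive the loss down
losses = [canonical_train_step(model, optimizer, criterion, x, y) for _ in range(50)]
print(f"first: {losses[0]:.4f}  last: {losses[-1]:.4f}")
```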

Installation

pip install pytorch-TrainEngine

Quick Start

import torch
from pytorch_trainengine import ITrainable, TrainingEngine

# Define your model
class MyModel(ITrainable):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)
    
    def forward_pass(self, x):
        return self.linear(x)
    
    def parameters(self):
        return self.linear.parameters()

# Create training engine
criterion = torch.nn.MSELoss()
engine = TrainingEngine(criterion=criterion, lr=0.01)

# Run one training step on a single batch
model = MyModel()
x = torch.randn(32, 10)
y = torch.randn(32, 1)

loss = engine.train_step(model, x, y)
print(f"Loss: {loss}")

API Reference

ITrainable

Abstract base class for trainable models. Subclass it and implement both of the methods below.

Methods:

  • forward_pass(x): Compute the model's output for input x
  • parameters(): Return the parameters the optimizer should update
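This sketch assumes ITrainable is built on Python's `abc` module. The "abstract base class" wording suggests this, but the package source is not shown here. Under that assumption, the contract could look roughly like the code below. `TrainableContract`, `Incomplete` and `Scale` are hypothetical names. The point is that a subclass missing either method cannot be instantiated.

```python
from abc import ABC, abstractmethod


class TrainableContract(ABC):
    """Hypothetical stand-in for an ITrainable-style interface."""

    @abstractmethod
    def forward_pass(self, x):
        """Return the model output for input x."""

    @abstractmethod
    def parameters(self):
        """Return the parameters an optimizer should update."""


class Incomplete(TrainableContract):
    # parameters() is not implemented, so this class stays abstract
    def forward_pass(self, x):
        return x


class Scale(TrainableContract):
    # Implements both methods, so it can be instantiated
    def __init__(self, w):
        self.w = w

    def forward_pass(self, x):
        return self.w * x

    def parameters(self):
        return [self.w]


try:
    Incomplete()
except TypeError as err:
    print("rejected:", err)

print(Scale(3).forward_pass(2))  # 6
```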

TrainingEngine

Main training engine class. It runs single training steps or full epochs on an ITrainable model.

Parameters:

  • criterion: Loss function, e.g. torch.nn.MSELoss()
  • optimizer_cls: Optimizer class, not an instance (default: torch.optim.SGD)
  • lr: Learning rate passed to the optimizer (default: 0.01)
  • device: Computation device (auto-detected if None; CUDA is used when available)
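The documented `device` rule ("auto-detected if None") is commonly written as shown below. `resolve_device` is a hypothetical helper. Falling back to CPU when CUDA is absent is an assumption about what "auto-detected" means in this package.

```python
import torch


def resolve_device(device=None):
    """Hypothetical helper: use the given device, else CUDA if available, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


print(resolve_device())       # cuda on a machine with a usable GPU, cpu otherwise
print(resolve_device("cpu"))  # cpu
```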

Methods:

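  • setup_optimizer(model): Create the optimizer for the model's parameters (called automatically when training starts)
  • train_step(model, x, y): Run a single training step on one batch (x, y) and return the loss
  • train_epoch(model, dataloader): Train for one full pass over dataloader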
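Here is a self-contained sketch of how lazy optimizer setup can fit together with `train_step` and `train_epoch`. `MiniEngine` is hypothetical and differs from the package in two ways. It calls the model directly as a `torch.nn.Module` instead of going through `forward_pass`. Its `train_epoch` returns the mean batch loss, which is an assumption, since the return value of the package's `train_epoch` is not documented here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class MiniEngine:
    """Hypothetical engine sketch with lazy optimizer setup (not the package's code)."""

    def __init__(self, criterion, optimizer_cls=torch.optim.SGD, lr=0.01, device=None):
        self.criterion = criterion
        self.optimizer_cls = optimizer_cls
        self.lr = lr
        if device is None:  # auto-detect: CUDA if present, else CPU (assumption)
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.optimizer = None  # created on the first training step

    def setup_optimizer(self, model):
        # Move the model before building the optimizer, as PyTorch recommends
        model.to(self.device)
        self.optimizer = self.optimizer_cls(model.parameters(), lr=self.lr)

    def train_step(self, model, x, y):
        if self.optimizer is None:  # lazy initialization on first use
            self.setup_optimizer(model)
        model.train()
        x, y = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()
        loss = self.criterion(model(x), y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train_epoch(self, model, dataloader):
        # One pass over the data; returns the mean batch loss (assumed behaviour)
        losses = [self.train_step(model, x, y) for x, y in dataloader]
        return sum(losses) / len(losses)


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=16, shuffle=True)
model = nn.Linear(10, 1)
engine = MiniEngine(criterion=nn.MSELoss(), lr=0.01)
print(engine.optimizer)  # None: no optimizer has been created yet
mean_loss = engine.train_epoch(model, loader)
print(type(engine.optimizer).__name__, round(mean_loss, 4))
```

In this sketch, lazy setup has one design consequence: the optimizer is bound to the first model the engine sees. Training a second model therefore calls for a new engine instance.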

Requirements

  • Python >= 3.8
  • torch >= 1.9.0

License

MIT

Author

15330613622@163.com


Download files

Source Distribution

pytorch_trainengine-0.1.0.tar.gz (4.1 kB, source)

Built Distribution

pytorch_trainengine-0.1.0-py3-none-any.whl (4.5 kB, Python 3)

File details

Details for the file pytorch_trainengine-0.1.0.tar.gz.

File metadata

  • Download URL: pytorch_trainengine-0.1.0.tar.gz
  • Upload date:
  • Size: 4.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.7

File hashes

Hashes for pytorch_trainengine-0.1.0.tar.gz

  • SHA256: 67f73f876e54751c4c955d9a0e708abd4ca0b6bd6136319b8223d36d9e973ca1
  • MD5: 37425354e61f6d11d31cec9c0374168b
  • BLAKE2b-256: 20d2db899aa42c133e52c0b9f599a53a0e00a7e64bd276620972cf9dbc43828a

File details

Details for the file pytorch_trainengine-0.1.0-py3-none-any.whl.

File hashes

Hashes for pytorch_trainengine-0.1.0-py3-none-any.whl

  • SHA256: b76f41d30690772eedd83e6734c12c8b3823a4ca9f0511e698301a4f0ffbd19b
  • MD5: a67773e5462004886a8b757ef4396be7
  • BLAKE2b-256: 1caff24b701c442c3b08424598b1df0b66880d33aa65a3cc42af8f15af8c8f47
