Toolkit for training PyTorch models
🚀 ML Trainer Package
A flexible and powerful PyTorch training framework with built-in logging, metrics tracking, and early stopping capabilities!
📦 Key Components
- Trainer: Main training loop with validation and reporting
- TrainerSettings: Configuration management for training parameters
- Models: Collection of CNN and RNN architectures
- Metrics: Customizable evaluation metrics
- Preprocessors: Data preparation utilities
🛠️ Installation
Use uv or, if you want the 10-100x slower pip, I won't stop you.
```shell
uv add mltrainer       # recommended
pip install mltrainer  # I can't stop you
```
🎯 Quick Start
Here's a simple example using a CNN model with MNIST:
```python
from pathlib import Path

import torch
from torch import nn

from mltrainer import Trainer, TrainerSettings, ReportTypes
from mltrainer.imagemodels import CNN
from mltrainer.metrics import Accuracy

# train_loader and valid_loader are your own MNIST DataLoaders
# (e.g. built from torchvision.datasets.MNIST); they are not defined in this snippet.

# Define training settings
settings = TrainerSettings(
    epochs=10,
    metrics=[Accuracy()],
    logdir=Path("./logs"),
    train_steps=100,
    valid_steps=20,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.TOML],
    optimizer_kwargs={"lr": 0.001},
    scheduler_kwargs={"factor": 0.1, "patience": 5},
    earlystop_kwargs={"patience": 7, "save": True},
)

# Initialize model and trainer
model = CNN(num_classes=10, kernel_size=3, filter1=32, filter2=64)
trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,  # passed as a class; optimizer_kwargs above hold its arguments
    traindataloader=train_loader,  # Your DataLoader
    validdataloader=valid_loader,  # Your DataLoader
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,  # configured via scheduler_kwargs
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Start training
trainer.loop()
```
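The settings above cap each epoch at `train_steps` training batches and `valid_steps` validation batches. The README does not spell out the semantics, so the following reading is an assumption: the trainer draws exactly that many batches per epoch and restarts the loader if it runs dry. Here is a minimal stdlib sketch of that idea, using the hypothetical helpers `take_steps` and `mean_loss`:

```python
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def take_steps(loader: Iterable[T], steps: int) -> Iterator[T]:
    """Yield exactly `steps` batches from `loader`, restarting it when exhausted.

    Hypothetical helper illustrating train_steps / valid_steps; not mltrainer's code.
    """
    produced = 0
    while produced < steps:
        yielded_this_pass = False
        for batch in loader:  # each pass re-iterates a re-iterable loader (list, DataLoader)
            yielded_this_pass = True
            yield batch
            produced += 1
            if produced == steps:
                return
        if not yielded_this_pass:
            # An exhausted one-shot iterator would otherwise loop forever.
            raise ValueError("loader produced no batches")


def mean_loss(loader: Iterable[T], steps: int, step_fn: Callable[[T], float]) -> float:
    """Run `step_fn` on `steps` batches and return the average of its results."""
    losses = [step_fn(batch) for batch in take_steps(loader, steps)]
    return sum(losses) / len(losses)
```

A PyTorch `DataLoader` creates a new iterator on each pass. With `shuffle=True`, this means every restart sees a fresh shuffle.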
📊 Report Types
The package supports multiple reporting backends:
- 📈 TENSORBOARD: Real-time training visualization
- 📝 TOML: Configuration and model architecture serialization. See https://pypi.org/project/tomlserializer/ for details
- 📊 MLFLOW: Experiment tracking and model management
- 🔄 RAY: Distributed training support
Configure them in TrainerSettings:
```python
settings = TrainerSettings(
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    # ... other settings
)
```
🔍 Metrics
Built-in metrics include:
- Accuracy: Classification accuracy
- MAE: Mean Absolute Error
- MASE: Mean Absolute Scaled Error (for time series)
Metrics are PyTorch-native and handle device placement automatically:
```python
from mltrainer.metrics import Accuracy, MAE

settings = TrainerSettings(
    metrics=[Accuracy(), MAE()],
    # ... other settings
)
```
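To make the three metric definitions concrete, here is a plain-Python sketch. It works on lists rather than tensors and is not the package's implementation. The real `Accuracy` presumably takes model outputs such as logits and picks the predicted class itself, while `accuracy` below expects class labels directly. For `mase`, the scale is the in-sample mean absolute error of a naive forecast with seasonal lag `m`, which is the standard textbook definition. How mltrainer's `MASE` chooses its scale is not stated in this README.

```python
from typing import Sequence


def _check(a: Sequence[float], b: Sequence[float]) -> None:
    """Reject empty or mismatched inputs."""
    if len(a) != len(b) or len(a) == 0:
        raise ValueError("inputs must be non-empty and of equal length")


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predicted class labels that match the true labels."""
    _check(y_true, y_pred)
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error."""
    _check(y_true, y_pred)
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mase(
    y_true: Sequence[float],
    y_pred: Sequence[float],
    y_train: Sequence[float],
    m: int = 1,
) -> float:
    """MAE of the forecast, scaled by the in-sample MAE of a lag-m naive forecast."""
    if len(y_train) <= m:
        raise ValueError("y_train must be longer than the seasonal lag m")
    naive_errors = [abs(y_train[i] - y_train[i - m]) for i in range(m, len(y_train))]
    scale = sum(naive_errors) / len(naive_errors)
    if scale == 0:
        raise ValueError("naive forecast is perfect on y_train; MASE is undefined")
    return mae(y_true, y_pred) / scale
```

A MASE below 1 means the model beats the naive "repeat the value from `m` steps ago" forecast on average.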
🔄 Preprocessors
Two main preprocessors are available:
- BasePreprocessor: Standard batch processing for fixed-size inputs
```python
preprocessor = BasePreprocessor()
batch_x, batch_y = preprocessor(batch)
```
- PaddedPreprocessor: Handles variable-length sequences with padding
```python
preprocessor = PaddedPreprocessor()
padded_x, batch_y = preprocessor(sequence_batch)
```
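Padding a batch of variable-length sequences can be sketched in plain Python. Every sequence is right-padded to the length of the longest one, and the original lengths are kept so a model can ignore the padding. This is an illustration only. A tensor-based `PaddedPreprocessor` would more likely rely on something like `torch.nn.utils.rnn.pad_sequence`. The names `pad_batch` and `padded_preprocessor` are hypothetical, and the batch is assumed to be a list of `(sequence, label)` pairs.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[Sequence[float], int]]


def pad_batch(
    sequences: Sequence[Sequence[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[int]]:
    """Right-pad every sequence to the longest length; also return the original lengths."""
    lengths = [len(s) for s in sequences]
    max_len = max(lengths, default=0)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in sequences]
    return padded, lengths


def padded_preprocessor(batch: Batch) -> Tuple[List[List[float]], List[int]]:
    """Split (sequence, label) pairs, pad the sequences, and return (padded_x, batch_y)."""
    xs, ys = zip(*batch)
    padded_x, _ = pad_batch(xs)
    return padded_x, list(ys)
```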
🧠 Available Models
The package includes several model architectures:
Image Models
- CNN with configurable filters
- Neural Network with customizable layers
RNN Models
- Base RNN
- GRU with optional attention
- NLP models with embedding support
Example using AttentionGRU:
```python
from mltrainer.rnn_models import AttentionGRU

config = {
    "input_size": 10,
    "hidden_size": 64,
    "output_size": 1,
    "num_layers": 2,
    "dropout": 0.1,
}
model = AttentionGRU(config)
```
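The README does not say how `AttentionGRU` computes its attention. The following is therefore only a generic illustration of dot-product attention pooling over a sequence of GRU hidden states: score each time step against a query vector, turn the scores into weights with a softmax, and return the weighted sum. Plain lists stand in for tensors. The `query` argument is a hypothetical stand-in for whatever learned vector a real model would use.

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(
    hidden_states: Sequence[Sequence[float]], query: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Return (context vector, attention weights) for one sequence of hidden states."""
    if not hidden_states:
        raise ValueError("need at least one time step")
    # One score per time step: dot product of that hidden state with the query.
    scores = [sum(h_i * q_i for h_i, q_i in zip(h, query)) for h in hidden_states]
    weights = softmax(scores)
    dim = len(hidden_states[0])
    # Context = attention-weighted sum of hidden states, per feature dimension.
    context = [sum(w * h[d] for w, h in zip(weights, hidden_states)) for d in range(dim)]
    return context, weights
```

Compared with keeping only the last hidden state, pooling like this lets the output head draw on every time step, weighted by relevance.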
⚙️ Advanced Configuration
TrainerSettings supports comprehensive training configuration:
```python
from pathlib import Path

from mltrainer import TrainerSettings, ReportTypes
from mltrainer.metrics import Accuracy

settings = TrainerSettings(
    epochs=100,
    metrics=[Accuracy()],
    logdir=Path("./experiments"),
    train_steps=500,
    valid_steps=50,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    optimizer_kwargs={
        "lr": 1e-3,
        "weight_decay": 1e-5,
    },
    scheduler_kwargs={
        "factor": 0.1,
        "patience": 10,
    },
    earlystop_kwargs={
        "save": True,
        "verbose": True,
        "patience": 10,
    },
)
```
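The `scheduler_kwargs` keys match arguments of PyTorch's `ReduceLROnPlateau`, the scheduler used in the quick start. `factor` multiplies the learning rate once the monitored value has failed to improve for more than `patience` epochs. Below is a simplified, stdlib-only sketch of that rule for a loss being minimised. It deliberately ignores PyTorch's `threshold`, `cooldown`, and `min_lr` options, and `PlateauScheduler` is a hypothetical name.

```python
import math


class PlateauScheduler:
    """Simplified ReduceLROnPlateau in 'min' mode (no threshold, cooldown, or min_lr)."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        """Record one epoch's validation loss and return the (possibly reduced) lr."""
        if metric < self.best:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only after *more than* `patience` bad epochs, then start counting again.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr
```

With `patience=10` here and early-stopping `patience=10` as well, training may stop before the first learning-rate drop takes effect. Giving the scheduler the smaller patience, as the quick start does (5 vs. 7), leaves room for at least one reduction.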
🔔 Early Stopping
The trainer includes built-in early stopping with model checkpointing:
```python
settings = TrainerSettings(
    earlystop_kwargs={
        "patience": 7,    # Epochs to wait without improvement before stopping
        "save": True,     # Save best model
        "verbose": True,  # Print progress
        "delta": 0.001,   # Minimum improvement threshold
    },
    # ... other settings
)
```
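These keys suggest the usual early-stopping rule. Keep the best validation loss seen so far, and count epochs that fail to beat it by more than `delta`. Stop once that count reaches `patience`, and save the model whenever it improves. Below is a minimal sketch under that assumption; it is not mltrainer's class. The `on_improve` callback is a hypothetical stand-in for checkpoint saving (a real trainer would typically call `torch.save` there).

```python
import math
from typing import Any, Callable, Optional


class EarlyStopping:
    """Signal a stop once validation loss has not improved by more than `delta`
    for `patience` consecutive epochs. Hypothetical sketch, not the package's class."""

    def __init__(
        self,
        patience: int = 7,
        delta: float = 0.0,
        save: bool = True,
        verbose: bool = False,
        on_improve: Optional[Callable[[Any], None]] = None,
    ) -> None:
        self.patience = patience
        self.delta = delta
        self.save = save
        self.verbose = verbose
        self.on_improve = on_improve  # hypothetical hook, e.g. a checkpoint writer
        self.best_loss = math.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss: float, model: Any = None) -> bool:
        if val_loss < self.best_loss - self.delta:
            # Improvement larger than delta: remember it and reset the counter.
            if self.verbose:
                print(f"Validation loss improved {self.best_loss:.4f} -> {val_loss:.4f}")
            self.best_loss = val_loss
            self.counter = 0
            if self.save and self.on_improve is not None:
                self.on_improve(model)
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epochs")
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
```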
📝 Logging
The package uses loguru for comprehensive logging. All training progress, early stopping events, and potential issues are automatically logged:
```python
from loguru import logger

# Logs are automatically created in your logdir
# Example log message:
# [2024-02-13 14:30:22] INFO: Epoch 5 train 0.3421 test 0.2891 metric [0.8934]
```
🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
File details
Details for the file mltrainer-0.2.6.tar.gz.
File metadata
- Download URL: mltrainer-0.2.6.tar.gz
- Upload date:
- Size: 94.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.6.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 09e7d59d533f3b1f7e9e0fd9f48b91f3edf786724073677bebbcd292d8a86c62 |
| MD5 | 00c84fe2723cbd88935da7ef7f213919 |
| BLAKE2b-256 | 1d6f0ae0dcbb90477be3e74ddcadb31b9b3386adff9c601025d8650a00652091 |

File details
Details for the file mltrainer-0.2.6-py3-none-any.whl.
File metadata
- Download URL: mltrainer-0.2.6-py3-none-any.whl
- Upload date:
- Size: 16.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.6.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 052ab2c7f7409365145e0daead20e3a66cee3fda490fa24409e7792d06287f5b |
| MD5 | d0b3828290daa618ddb381db845b9d62 |
| BLAKE2b-256 | 244a25b74222b068e6f068db7173a33dcb07625d0b98db7c04bfe82660104ee6 |