
Toolkit for training PyTorch models


🚀 ML Trainer Package

Linted with Ruff · Code style: black · License: MIT

A flexible and powerful PyTorch training framework with built-in logging, metrics tracking, and early stopping capabilities!

📦 Key Components

  • Trainer: Main training loop with validation and reporting
  • TrainerSettings: Configuration management for training parameters
  • Models: Collection of CNN and RNN architectures
  • Metrics: Customizable evaluation metrics
  • Preprocessors: Data preparation utilities

🛠️ Installation

Use uv (recommended), or pip if you prefer something 10-100x slower; I won't stop you.

uv add mltrainer # recommended
pip install mltrainer # I can't stop you

🎯 Quick Start

Here's a simple example using a CNN model with MNIST:

import torch
from torch import nn
from pathlib import Path

from trainer import Trainer, TrainerSettings
from imagemodels import CNN
from metrics import Accuracy
from preprocessors import BasePreprocessor
from settings import ReportTypes

# Define training settings
settings = TrainerSettings(
    epochs=10,
    metrics=[Accuracy()],
    logdir=Path("./logs"),
    train_steps=100,
    valid_steps=20,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.TOML],
    optimizer_kwargs={"lr": 0.001},
    scheduler_kwargs={"factor": 0.1, "patience": 5},
    earlystop_kwargs={"patience": 7, "save": True}
)

# Initialize model and trainer
model = CNN(num_classes=10, kernel_size=3, filter1=32, filter2=64)
trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    traindataloader=train_loader,  # Your DataLoader
    validdataloader=valid_loader,  # Your DataLoader
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Start training
trainer.loop()

📊 Report Types

The package supports multiple reporting backends:

  • 📈 TENSORBOARD: Real-time training visualization
  • 📝 TOML: Configuration and model architecture serialization. See https://pypi.org/project/tomlserializer/ for details
  • 📊 MLFLOW: Experiment tracking and model management
  • 🔄 RAY: Distributed training support

Configure them in TrainerSettings:

settings = TrainerSettings(
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    # ... other settings
)

🔍 Metrics

Built-in metrics include:

  • Accuracy: Classification accuracy
  • MAE: Mean Absolute Error
  • MASE: Mean Absolute Scaled Error (for time series)

Metrics are PyTorch-native and handle device placement automatically:

from metrics import Accuracy, MAE

settings = TrainerSettings(
    metrics=[Accuracy(), MAE()],
    # ... other settings
)
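To make the three metrics concrete, here is a plain-Python sketch of the standard formulas. It is not the package's implementation (those operate on PyTorch tensors); the function names `accuracy`, `mae`, and `mase` are hypothetical, and `mase` follows the textbook definition that divides the forecast MAE by the in-sample MAE of a naive lag-`m` forecast, which may differ from how the package computes its scale.

```python
def accuracy(logits, targets):
    """Fraction of rows whose highest-scoring class equals the target label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def mae(y_true, y_pred):
    """Mean Absolute Error: the average of |y - y_hat|."""
    return sum(abs(a - b) for a, b in zip(y_true, y_pred)) / len(y_true)


def mase(y_true, y_pred, y_train, m=1):
    """Mean Absolute Scaled Error (textbook form, assumed here).

    The forecast MAE is divided by the in-sample MAE of a naive forecast
    that predicts y[t] with y[t - m]; m=1 is the plain naive forecast.
    """
    naive_errors = [abs(y_train[t] - y_train[t - m]) for t in range(m, len(y_train))]
    scale = sum(naive_errors) / len(naive_errors)
    return mae(y_true, y_pred) / scale


print(accuracy([[0.1, 0.9], [0.8, 0.2]], [1, 1]))  # 0.5
```

A MASE below 1 means the model beats the naive forecast on average; above 1 means it does worse.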

🔄 Preprocessors

Two main preprocessors are available:

  1. BasePreprocessor: Standard batch processing for fixed-size inputs

    preprocessor = BasePreprocessor()
    batch_x, batch_y = preprocessor(batch)
    
  2. PaddedPreprocessor: Handles variable-length sequences with padding

    preprocessor = PaddedPreprocessor()
    padded_x, batch_y = preprocessor(sequence_batch)
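Conceptually, padding brings every sequence in a batch up to the length of the longest one so the batch can be stacked into a single tensor. The sketch below shows that step in plain Python. It assumes a batch is a list of `(sequence, label)` pairs and pads at the end with `0`; the name `pad_batch` and these assumptions are illustrative and do not describe `PaddedPreprocessor`'s actual interface. In PyTorch the same step is usually done with `torch.nn.utils.rnn.pad_sequence`.

```python
def pad_batch(batch, padding_value=0):
    """Right-pad each sequence to the longest length in the batch.

    Assumes `batch` is a list of (sequence, label) pairs and
    returns (padded_sequences, labels).
    """
    sequences, labels = zip(*batch)
    max_len = max(len(seq) for seq in sequences)
    # Append padding_value until every sequence has length max_len.
    padded = [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]
    return padded, list(labels)


padded_x, batch_y = pad_batch([([1, 2, 3], 0), ([4], 1)])
print(padded_x)  # [[1, 2, 3], [4, 0, 0]]
```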
    

🧠 Available Models

The package includes several model architectures:

Image Models

  • CNN with configurable filters
  • Neural Network with customizable layers

RNN Models

  • Base RNN
  • GRU with optional attention
  • NLP models with embedding support

Example using AttentionGRU:

config = {
    "input_size": 10,
    "hidden_size": 64,
    "output_size": 1,
    "num_layers": 2,
    "dropout": 0.1
}
model = AttentionGRU(config)
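The README does not spell out how `AttentionGRU` computes its attention. As a rough illustration of what attention over a recurrent layer typically does, the plain-Python sketch below uses dot-product attention pooling: each GRU hidden state is scored against a query vector, the scores go through a softmax, and the output is the weighted average of the hidden states. The function names and the choice of dot-product scoring are assumptions for illustration, not the package's implementation.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(hidden_states, query):
    """Dot-product attention pooling over a sequence of hidden states.

    hidden_states: list of T vectors (one per time step), each of size H.
    query: vector of size H (learned in a real model).
    Returns (context_vector, attention_weights).
    """
    scores = [sum(h_i * q_i for h_i, q_i in zip(h, query)) for h in hidden_states]
    weights = softmax(scores)
    dim = len(hidden_states[0])
    # Weighted average of the hidden states, one output value per dimension.
    context = [sum(w * h[d] for w, h in zip(weights, hidden_states)) for d in range(dim)]
    return context, weights
```

The pooled context vector would then feed a final linear layer that produces `output_size` values.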

⚙️ Advanced Configuration

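TrainerSettings bundles the full training configuration, including keyword arguments for the optimizer (optimizer_kwargs), the learning-rate scheduler (scheduler_kwargs), and early stopping (earlystop_kwargs):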

settings = TrainerSettings(
    epochs=100,
    metrics=[Accuracy()],
    logdir=Path("./experiments"),
    train_steps=500,
    valid_steps=50,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    optimizer_kwargs={
        "lr": 1e-3,
        "weight_decay": 1e-5
    },
    scheduler_kwargs={
        "factor": 0.1,
        "patience": 10
    },
    earlystop_kwargs={
        "save": True,
        "verbose": True,
        "patience": 10
    }
)
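The `factor` and `patience` keys in `scheduler_kwargs` match the arguments of `torch.optim.lr_scheduler.ReduceLROnPlateau`, the scheduler used in the Quick Start, which suggests they are forwarded to it. The plain-Python sketch below (hypothetical class `PlateauLR`) mimics its core rule in `min` mode: once the validation loss has failed to improve for more than `patience` consecutive epochs, the learning rate is multiplied by `factor`. It deliberately ignores ReduceLROnPlateau's `threshold`, `cooldown`, and `min_lr` options, so treat it as a simplified model rather than an exact reproduction.

```python
class PlateauLR:
    """Simplified ReduceLROnPlateau (mode="min", no threshold/cooldown/min_lr)."""

    def __init__(self, lr, factor=0.1, patience=10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss and return the (possibly reduced) lr."""
        if val_loss < self.best:
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only after MORE than `patience` epochs without improvement.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauLR(lr=1e-3, factor=0.1, patience=2)
for loss in [1.0, 1.0, 1.0, 1.0]:
    lr = sched.step(loss)
```

With `patience=2`, the first reduction happens on the third epoch in a row without improvement, so after the loop `lr` is about `1e-4`.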

🔔 Early Stopping

The trainer includes built-in early stopping with model checkpointing:

settings = TrainerSettings(
    earlystop_kwargs={
        "patience": 7,      # Epochs to wait before stopping
        "save": True,       # Save best model
        "verbose": True,    # Print progress
        "delta": 0.001     # Minimum improvement threshold
    },
    # ... other settings
)
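The interaction between `patience` and `delta` is easiest to see in code. The sketch below uses a hypothetical `SimpleEarlyStopping` class and assumes the common convention (the README does not specify it) that a validation loss counts as an improvement only if it beats the best loss so far by more than `delta`, and that training stops after `patience` consecutive epochs without such an improvement. With `save=True` it keeps an in-memory copy of the best state; the real trainer presumably writes a checkpoint to disk instead.

```python
import copy


class SimpleEarlyStopping:
    """Stop after `patience` epochs without an improvement larger than `delta`."""

    def __init__(self, patience=7, delta=0.0, save=True, verbose=False):
        self.patience = patience
        self.delta = delta
        self.save = save
        self.verbose = verbose
        self.best_loss = None
        self.best_state = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, state):
        """Update with this epoch's validation loss; return True when training should stop."""
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: remember the new best and reset the patience counter.
            self.best_loss = val_loss
            self.counter = 0
            if self.save:
                self.best_state = copy.deepcopy(state)
            if self.verbose:
                print(f"New best validation loss: {val_loss:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epochs")
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
```

For example, with `patience=2, delta=0.01` and losses `1.0, 0.995, 0.5, 0.6, 0.6`, the drop to `0.995` is too small to count, `0.5` resets the counter, and the two following epochs at `0.6` trigger the stop.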

📝 Logging

The package uses loguru for comprehensive logging. All training progress, early stopping events, and potential issues are automatically logged:

from loguru import logger

# Logs are automatically created in your logdir
# Example log message:
# [2024-02-13 14:30:22] INFO: Epoch 5 train 0.3421 test 0.2891 metric [0.8934]
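Reading the example log line: the numbers after `train` and `test` are presumably the mean training and validation losses for the epoch, and the bracketed list holds one value per configured metric. The hypothetical helper below reproduces that layout from those values; it illustrates the message format only and is not the package's logging code. Passing its result to loguru's `logger.info` would produce a line like the example above.

```python
def format_epoch_message(epoch, train_loss, valid_loss, metric_values):
    """Build a log message in the layout of the example above (assumed layout)."""
    metrics = ", ".join(f"{value:.4f}" for value in metric_values)
    return f"Epoch {epoch} train {train_loss:.4f} test {valid_loss:.4f} metric [{metrics}]"


print(format_epoch_message(5, 0.3421, 0.2891, [0.8934]))
# Epoch 5 train 0.3421 test 0.2891 metric [0.8934]
```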

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.


Download files


Source Distribution

mltrainer-0.2.6.tar.gz (94.4 kB)

Uploaded Source

Built Distribution


mltrainer-0.2.6-py3-none-any.whl (16.4 kB)

Uploaded Python 3

File details

Details for the file mltrainer-0.2.6.tar.gz.

File metadata

  • Download URL: mltrainer-0.2.6.tar.gz
  • Upload date:
  • Size: 94.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.8

File hashes

Hashes for mltrainer-0.2.6.tar.gz
Algorithm Hash digest
SHA256 09e7d59d533f3b1f7e9e0fd9f48b91f3edf786724073677bebbcd292d8a86c62
MD5 00c84fe2723cbd88935da7ef7f213919
BLAKE2b-256 1d6f0ae0dcbb90477be3e74ddcadb31b9b3386adff9c601025d8650a00652091


File details

Details for the file mltrainer-0.2.6-py3-none-any.whl.

File metadata

  • Download URL: mltrainer-0.2.6-py3-none-any.whl
  • Upload date:
  • Size: 16.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.8

File hashes

Hashes for mltrainer-0.2.6-py3-none-any.whl
Algorithm Hash digest
SHA256 052ab2c7f7409365145e0daead20e3a66cee3fda490fa24409e7792d06287f5b
MD5 d0b3828290daa618ddb381db845b9d62
BLAKE2b-256 244a25b74222b068e6f068db7173a33dcb07625d0b98db7c04bfe82660104ee6

