
Toolkit for training PyTorch models


🚀 ML Trainer Package

Ruff | Code style: black | PyPI version | License: MIT

A flexible and powerful PyTorch training framework with built-in logging, metrics tracking, and early stopping capabilities!

📦 Key Components

  • Trainer: Main training loop with validation and reporting
  • TrainerSettings: Configuration management for training parameters
  • Models: Collection of CNN and RNN architectures
  • Metrics: Customizable evaluation metrics
  • Preprocessors: Data preparation utilities

🛠️ Installation

Use uv. If you'd rather use pip, which is 10-100x slower, I won't stop you.

uv add mltrainer       # recommended
pip install mltrainer  # I can't stop you

🎯 Quick Start

Here's a simple example using a CNN model with MNIST:

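from pathlib import Path

import torch
from torch import nn

# All modules are imported from the installed mltrainer package
from mltrainer.trainer import Trainer, TrainerSettings
from mltrainer.imagemodels import CNN
from mltrainer.metrics import Accuracy
from mltrainer.settings import ReportTypes

# train_loader and valid_loader are assumed to be torch DataLoaders over MNIST,
# built beforehand (for example with torchvision.datasets.MNIST)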

# Define training settings
settings = TrainerSettings(
    epochs=10,
    metrics=[Accuracy()],
    logdir=Path("./logs"),
    train_steps=100,
    valid_steps=20,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.TOML],
    optimizer_kwargs={"lr": 0.001},
    scheduler_kwargs={"factor": 0.1, "patience": 5},
    earlystop_kwargs={"patience": 7, "save": True}
)

# Initialize model and trainer
model = CNN(num_classes=10, kernel_size=3, filter1=32, filter2=64)
trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    traindataloader=train_loader,  # Your DataLoader
    validdataloader=valid_loader,  # Your DataLoader
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Start training
trainer.loop()

📊 Report Types

The package supports multiple reporting backends:

  • 📈 TENSORBOARD: Real-time training visualization
  • 📝 TOML: Configuration and model architecture serialization. See https://pypi.org/project/tomlserializer/ for details
  • 📊 MLFLOW: Experiment tracking and model management
  • 🔄 RAY: Distributed training support

Configure them in TrainerSettings:

settings = TrainerSettings(
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    # ... other settings
)

🔍 Metrics

Built-in metrics include:

  • Accuracy: Classification accuracy
  • MAE: Mean Absolute Error
  • MASE: Mean Absolute Scaled Error (for time series)

Metrics are PyTorch-native and handle device placement automatically:

from mltrainer.metrics import Accuracy, MAE

settings = TrainerSettings(
    metrics=[Accuracy(), MAE()],
    # ... other settings
)
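For intuition, here is what each built-in metric computes, written as plain Python on lists. This is a sketch, not mltrainer's tensor-based code: the function names are hypothetical, and MASE uses the common textbook scaling (the in-sample MAE of a naive lag-1 forecast over a hypothetical history series), which may differ from how the package scales it.

```python
from typing import Sequence


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return correct / len(labels)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error between targets and predictions."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mase(y_true: Sequence[float], y_pred: Sequence[float], history: Sequence[float]) -> float:
    """Forecast MAE divided by the MAE of a naive lag-1 forecast on `history` (textbook form)."""
    naive_errors = [abs(history[i] - history[i - 1]) for i in range(1, len(history))]
    scale = sum(naive_errors) / len(naive_errors)
    return mae(y_true, y_pred) / scale


print(accuracy([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 2]))  # 0.5
print(mae([1.0, 2.0, 3.0], [1.5, 2.0, 2.0]))  # 0.5
print(mase([4.0, 5.0], [4.5, 4.0], history=[1.0, 2.0, 4.0, 3.0]))
```

A MASE below 1 means the model beats the naive "repeat the last value" forecast.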

🔄 Preprocessors

Two main preprocessors are available:

  1. BasePreprocessor: Standard batch processing for fixed-size inputs

    preprocessor = BasePreprocessor()
    batch_x, batch_y = preprocessor(batch)
    
  2. PaddedPreprocessor: Handles variable-length sequences with padding

    preprocessor = PaddedPreprocessor()
    padded_x, batch_y = preprocessor(sequence_batch)
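Conceptually, padding fills each sequence in a batch up to the length of the longest one so the batch can be stacked into a single tensor. The plain-Python sketch below only illustrates that idea: pad_batch, right-padding, and a pad value of 0 are assumptions, not PaddedPreprocessor's actual code. In PyTorch the same operation is typically done with torch.nn.utils.rnn.pad_sequence(..., batch_first=True).

```python
from typing import List, Sequence, Tuple


def pad_batch(
    batch: Sequence[Tuple[Sequence[int], int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad every input sequence to the longest length in the batch (hypothetical helper)."""
    sequences = [list(x) for x, _ in batch]
    labels = [y for _, y in batch]
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, labels


sequence_batch = [([5, 3, 8], 1), ([7], 0), ([2, 9], 1)]
padded_x, batch_y = pad_batch(sequence_batch)
print(padded_x)  # [[5, 3, 8], [7, 0, 0], [2, 9, 0]]
print(batch_y)   # [1, 0, 1]
```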
    

🧠 Available Models

The package includes several model architectures:

Image Models

  • CNN with configurable filters
  • Neural Network with customizable layers

RNN Models

  • Base RNN
  • GRU with optional attention
  • NLP models with embedding support

Example using AttentionGRU:

from mltrainer.rnn_models import AttentionGRU

config = {
    "input_size": 10,
    "hidden_size": 64,
    "output_size": 1,
    "num_layers": 2,
    "dropout": 0.1
}
model = AttentionGRU(config)

⚙️ Advanced Configuration

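TrainerSettings keeps all training parameters in one object; optimizer_kwargs, scheduler_kwargs, and earlystop_kwargs are passed on to the optimizer, the scheduler, and the early stopper respectively: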

settings = TrainerSettings(
    epochs=100,
    metrics=[Accuracy()],
    logdir=Path("./experiments"),
    train_steps=500,
    valid_steps=50,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    optimizer_kwargs={
        "lr": 1e-3,
        "weight_decay": 1e-5
    },
    scheduler_kwargs={
        "factor": 0.1,
        "patience": 10
    },
    earlystop_kwargs={
        "save": True,
        "verbose": True,
        "patience": 10
    }
)
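Assuming the scheduler is ReduceLROnPlateau (as in the Quick Start) and is stepped with the validation loss once per epoch, scheduler_kwargs means: multiply the learning rate by factor once the loss has gone more than patience epochs without improving. The PlateauLR class below is a hypothetical, simplified pure-Python model of that rule; it ignores PyTorch's threshold, cooldown, and min_lr options and counts any strictly lower loss as an improvement.

```python
class PlateauLR:
    """Simplified ReduceLROnPlateau in 'min' mode (no threshold, cooldown, or min_lr)."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, valid_loss: float) -> float:
        if valid_loss < self.best:
            self.best = valid_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only after MORE than `patience` epochs without improvement
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


sched = PlateauLR(lr=1e-3, factor=0.1, patience=2)
lrs = [sched.step(loss) for loss in [1.0, 0.9, 0.95, 0.93, 0.91, 0.92]]
print(lrs)  # learning rate drops by 10x at the fifth epoch
```

With patience=2, three consecutive epochs without a new best loss are needed before the rate is cut.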

🔔 Early Stopping

The trainer includes built-in early stopping with model checkpointing:

settings = TrainerSettings(
    earlystop_kwargs={
        "patience": 7,      # Epochs without improvement before stopping
        "save": True,       # Save the best model
        "verbose": True,    # Print progress
        "delta": 0.001      # Minimum improvement threshold
    },
    # ... other settings
)
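The early-stopping settings can be read as this rule: track the best validation loss so far, save a checkpoint whenever the loss improves by more than delta, otherwise count the epoch as a miss, and stop once the miss count reaches patience. The class below is a minimal pure-Python sketch under those assumptions, not mltrainer's own early-stopping code; "saving" is simulated by recording the best epoch instead of writing the model to disk.

```python
from typing import Optional


class EarlyStopping:
    """Stop when the validation loss has not improved by `delta` for `patience` epochs."""

    def __init__(self, patience: int = 7, delta: float = 0.0, verbose: bool = False) -> None:
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss: Optional[float] = None
        self.best_epoch: Optional[int] = None  # stands in for a saved checkpoint
        self.counter = 0
        self.early_stop = False

    def __call__(self, valid_loss: float, epoch: int) -> None:
        if self.best_loss is None or valid_loss < self.best_loss - self.delta:
            # Improvement: remember it ("save" the model) and reset the miss counter
            self.best_loss = valid_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epochs")
            if self.counter >= self.patience:
                self.early_stop = True


stopper = EarlyStopping(patience=2, delta=0.001)
for epoch, loss in enumerate([0.50, 0.40, 0.3995, 0.41, 0.38]):
    stopper(loss, epoch)
    if stopper.early_stop:
        print(f"Stopping at epoch {epoch}; best epoch was {stopper.best_epoch}")
        break
```

Note that 0.3995 does not count as an improvement over 0.40 because it is smaller by less than delta.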

📝 Logging

The package uses loguru for comprehensive logging. All training progress, early stopping events, and potential issues are automatically logged:

from loguru import logger

# Logs are automatically created in your logdir
# Example log message:
# [2024-02-13 14:30:22] INFO: Epoch 5 train 0.3421 test 0.2891 metric [0.8934]

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution


mltrainer-0.2.4-py3-none-any.whl (13.3 kB)

Uploaded Python 3

File details

Details for the file mltrainer-0.2.4-py3-none-any.whl.

File metadata

  • Download URL: mltrainer-0.2.4-py3-none-any.whl
  • Size: 13.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.8

File hashes

Hashes for mltrainer-0.2.4-py3-none-any.whl
Algorithm    Hash digest
SHA256       b8517788b506f6ec92656c318982a308cb72d9505c3e8d187474965d65ad03f6
MD5          247ba8edea621d30f1a84514e2fae353
BLAKE2b-256  76b6427fb72612f4bbb4878a324935fa7f54a4e599e038b7fa54bd2a77a855ee
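To check a downloaded wheel against the SHA256 digest listed for it, hash the file with Python's standard hashlib module and compare. The sha256_of_file helper below is an illustrative sketch; the wheel path is an assumption about where you saved the download. pip can also enforce hashes itself in hash-checking mode (a --hash=sha256:<digest> entry in a requirements file, installed with --require-hashes).

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "b8517788b506f6ec92656c318982a308cb72d9505c3e8d187474965d65ad03f6"


def sha256_of_file(path: Path, chunk_size: int = 65536) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    wheel = Path("mltrainer-0.2.4-py3-none-any.whl")  # assumed download location
    if wheel.exists():
        ok = sha256_of_file(wheel) == EXPECTED_SHA256
        print("hash OK" if ok else "hash MISMATCH")
```

Reading in chunks keeps memory use constant regardless of file size.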

