
Toolkit for training PyTorch models

Project description

🚀 ML Trainer Package

Ruff | Code style: black | PyPI version | License: MIT

A flexible and powerful PyTorch training framework with built-in logging, metrics tracking, and early stopping capabilities!

📦 Key Components

  • Trainer: Main training loop with validation and reporting
  • TrainerSettings: Configuration management for training parameters
  • Models: Collection of CNN and RNN architectures
  • Metrics: Customizable evaluation metrics
  • Preprocessors: Data preparation utilities

🛠️ Installation

Use uv, or, if you want to use the 10-100x slower pip, I won't stop you.

uv add mltrainer # recommended
pip install mltrainer # I can't stop you

🎯 Quick Start

Here's a simple example using a CNN model with MNIST:

import torch
from torch import nn
from pathlib import Path

from trainer import Trainer, TrainerSettings
from imagemodels import CNN
from metrics import Accuracy
from preprocessors import BasePreprocessor
from settings import ReportTypes

# Define training settings
settings = TrainerSettings(
    epochs=10,
    metrics=[Accuracy()],
    logdir=Path("./logs"),
    train_steps=100,
    valid_steps=20,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.TOML],
    optimizer_kwargs={"lr": 0.001},
    scheduler_kwargs={"factor": 0.1, "patience": 5},
    earlystop_kwargs={"patience": 7, "save": True}
)

# Initialize model and trainer
model = CNN(num_classes=10, kernel_size=3, filter1=32, filter2=64)
trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    traindataloader=train_loader,  # Your DataLoader
    validdataloader=valid_loader,  # Your DataLoader
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Start training
trainer.loop()
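The Quick Start fixes train_steps=100 and valid_steps=20. If you want each epoch to cover the whole dataset (an assumption about your setup, not something this package requires), set these to the number of batches each loader yields, which for a PyTorch DataLoader is len(loader). A minimal sketch of that arithmetic, using a hypothetical steps_per_epoch helper:

```python
import math


def steps_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields for one full pass over n_samples."""
    if drop_last:
        # The final incomplete batch is discarded.
        return n_samples // batch_size
    # The final, possibly smaller, batch is kept.
    return math.ceil(n_samples / batch_size)


# MNIST: 60,000 training and 10,000 test images, batch size 32
train_steps = steps_per_epoch(60_000, 32)  # 1875
valid_steps = steps_per_epoch(10_000, 32)  # 313 (the last batch has 16 images)
```

With real loaders, train_steps=len(train_loader) and valid_steps=len(valid_loader) give the same numbers without the helper.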

📊 Report Types

The package supports multiple reporting backends:

  • 📈 TENSORBOARD: Real-time training visualization
  • 📝 TOML: Configuration and model architecture serialization. See https://pypi.org/project/tomlserializer/ for details
  • 📊 MLFLOW: Experiment tracking and model management
  • 🔄 RAY: Distributed training support

Configure them in TrainerSettings:

settings = TrainerSettings(
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    # ... other settings
)

🔍 Metrics

Built-in metrics include:

  • Accuracy: Classification accuracy
  • MAE: Mean Absolute Error
  • MASE: Mean Absolute Scaled Error (for time series)

Metrics are PyTorch-native and handle device placement automatically:

from metrics import Accuracy, MAE

settings = TrainerSettings(
    metrics=[Accuracy(), MAE()],
    # ... other settings
)
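For reference, the formulas behind the three built-in metrics can be sketched in plain Python. This is a minimal illustration on lists, not the package's tensor-based implementation; the MASE scale below uses the common one-step naive forecast on the training series, and the package's MASE may compute its scale differently.

```python
def accuracy(logits: list[list[float]], labels: list[int]) -> float:
    # Predicted class = index of the largest logit in each row
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def mae(y_true: list[float], y_pred: list[float]) -> float:
    # Mean of absolute differences
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mase(y_true: list[float], y_pred: list[float], y_train: list[float]) -> float:
    # Scale: in-sample MAE of the naive forecast "next value = previous value"
    naive_errors = [abs(b - a) for a, b in zip(y_train[:-1], y_train[1:])]
    scale = sum(naive_errors) / len(naive_errors)
    # MASE < 1 means the model beats the naive forecast on average
    return mae(y_true, y_pred) / scale
```

For example, with training series [1, 2, 4, 7] the naive errors are 1, 2, 3 (scale 2.0), so a forecast with MAE 1.0 has a MASE of 0.5.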

🔄 Preprocessors

Two main preprocessors are available:

  1. BasePreprocessor: Standard batch processing for fixed-size inputs

    preprocessor = BasePreprocessor()
    batch_x, batch_y = preprocessor(batch)
    
  2. PaddedPreprocessor: Handles variable-length sequences with padding

    preprocessor = PaddedPreprocessor()
    padded_x, batch_y = preprocessor(sequence_batch)
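What padding does to a variable-length batch can be shown with a minimal pure-Python sketch. pad_batch is a hypothetical helper, not part of this package; in PyTorch the usual tool is torch.nn.utils.rnn.pad_sequence, and PaddedPreprocessor may choose padding values or handle targets differently.

```python
def pad_batch(sequences, pad_value=0):
    """Right-pad variable-length sequences to the longest one in the batch.

    Returns the padded batch and the original lengths (useful for masking).
    """
    max_len = max(len(seq) for seq in sequences)
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    lengths = [len(seq) for seq in sequences]
    return padded, lengths


batch = [[5, 2, 9], [7], [3, 3]]
padded_x, lengths = pad_batch(batch)
# padded_x == [[5, 2, 9], [7, 0, 0], [3, 3, 0]], lengths == [3, 1, 2]
```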
    

🧠 Available Models

The package includes several model architectures:

Image Models

  • CNN with configurable filters
  • Neural Network with customizable layers

RNN Models

  • Base RNN
  • GRU with optional attention
  • NLP models with embedding support

Example using AttentionGRU:

config = {
    "input_size": 10,
    "hidden_size": 64,
    "output_size": 1,
    "num_layers": 2,
    "dropout": 0.1
}
model = AttentionGRU(config)
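The config above describes the GRU itself (input, hidden and output sizes, depth, dropout). One common way to put attention on top of a GRU is attention pooling: score each timestep's hidden state, softmax the scores, and use the weighted sum as the sequence summary. The sketch below shows that step on plain lists; attention_pool and the fixed scoring vector w are hypothetical, and the package's AttentionGRU may use a different attention variant.

```python
import math


def attention_pool(hidden_states, w):
    """Softmax-weighted sum of per-timestep hidden states.

    hidden_states: T vectors (one per timestep), each of size H
    w: scoring vector of size H (learned in a real model; fixed here)
    """
    # One scalar score per timestep: dot product with w
    scores = [sum(h_i * w_i for h_i, w_i in zip(h, w)) for h in hidden_states]
    # Numerically stable softmax over timesteps
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Context vector: attention-weighted sum of the hidden states
    size = len(hidden_states[0])
    context = [sum(a * h[j] for a, h in zip(weights, hidden_states)) for j in range(size)]
    return context, weights


# Equal scores give equal weights, i.e. a plain average over timesteps
context, weights = attention_pool([[1.0, 0.0], [3.0, 2.0]], w=[0.0, 0.0])
```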

⚙️ Advanced Configuration

TrainerSettings supports comprehensive training configuration:

settings = TrainerSettings(
    epochs=100,
    metrics=[Accuracy()],
    logdir=Path("./experiments"),
    train_steps=500,
    valid_steps=50,
    reporttypes=[ReportTypes.TENSORBOARD, ReportTypes.MLFLOW],
    optimizer_kwargs={
        "lr": 1e-3,
        "weight_decay": 1e-5
    },
    scheduler_kwargs={
        "factor": 0.1,
        "patience": 10
    },
    earlystop_kwargs={
        "save": True,
        "verbose": True,
        "patience": 10
    }
)
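Assuming scheduler_kwargs are passed to the ReduceLROnPlateau class given to the Trainer (as in the Quick Start), factor=0.1 with patience=10 means: once the validation loss has failed to improve for more than 10 consecutive epochs, the learning rate is multiplied by 0.1. A simplified sketch of that rule with a hypothetical PlateauScheduler class (it ignores PyTorch's threshold, cooldown and min_lr options):

```python
import math


class PlateauScheduler:
    """Simplified sketch of ReduceLROnPlateau in 'min' mode."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            # Improvement: remember it and reset the counter
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Plateau: shrink the learning rate and start counting again
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauScheduler(lr=1e-3, factor=0.1, patience=2)
lrs = [sched.step(loss) for loss in [1.0, 1.0, 1.0, 1.0]]
# Epoch 4 is the 3rd epoch without improvement (> patience=2), so the lr drops to 1e-4
```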

🔔 Early Stopping

The trainer includes built-in early stopping with model checkpointing:

settings = TrainerSettings(
    earlystop_kwargs={
        "patience": 7,      # Epochs to wait before stopping
        "save": True,       # Save best model
        "verbose": True,    # Print progress
        "delta": 0.001      # Minimum improvement threshold
    },
    # ... other settings
)
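The interplay of patience and delta can be sketched as follows. This is a minimal illustration with a hypothetical EarlyStopper class, not the package's implementation; where the package saves the best model to disk ("save": True), the sketch keeps an in-memory copy of a state dict instead.

```python
import copy
import math


class EarlyStopper:
    """Stop when the validation loss has not improved by more than `delta`
    for `patience` consecutive epochs; keep a copy of the best state."""

    def __init__(self, patience: int = 7, delta: float = 0.0):
        self.patience = patience
        self.delta = delta
        self.best_loss = math.inf
        self.best_state = None
        self.counter = 0

    def step(self, val_loss: float, state: dict) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.delta:
            # Meaningful improvement: snapshot the state (stands in for a checkpoint)
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(state)
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2, delta=0.001)
stop_epoch = None
for epoch, loss in enumerate([1.0, 0.9, 0.8995, 0.91, 0.95]):
    if stopper.step(loss, {"epoch": epoch}):
        stop_epoch = epoch
        break
# 0.8995 beats 0.9 by less than delta, so it counts as no improvement;
# training stops at epoch 3 and the epoch-1 state is kept as the best.
```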

📝 Logging

The package uses loguru for comprehensive logging. All training progress, early stopping events, and potential issues are automatically logged:

from loguru import logger

# Logs are automatically created in your logdir
# Example log message:
# [2024-02-13 14:30:22] INFO: Epoch 5 train 0.3421 test 0.2891 metric [0.8934]
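To pull the per-epoch numbers out of a log file afterwards (for instance, to plot them without TensorBoard), a small regular-expression parser is enough. parse_epoch_line is a hypothetical helper, and it assumes the exact "Epoch N train X test Y metric [...]" layout of the example message above; adjust the pattern if your log lines differ.

```python
import re

# Matches lines like: "Epoch 5 train 0.3421 test 0.2891 metric [0.8934]"
EPOCH_LINE = re.compile(
    r"Epoch (?P<epoch>\d+) train (?P<train>[\d.]+) test (?P<test>[\d.]+) "
    r"metric \[(?P<metrics>[^\]]*)\]"
)


def parse_epoch_line(line: str):
    """Extract epoch number, train/test loss and metric values, or return None."""
    match = EPOCH_LINE.search(line)
    if match is None:
        return None
    # Accept space- or comma-separated metric values inside the brackets
    metrics = [float(v) for v in match["metrics"].replace(",", " ").split()]
    return {
        "epoch": int(match["epoch"]),
        "train": float(match["train"]),
        "test": float(match["test"]),
        "metrics": metrics,
    }


record = parse_epoch_line(
    "[2024-02-13 14:30:22] INFO: Epoch 5 train 0.3421 test 0.2891 metric [0.8934]"
)
```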

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.



Download files


Source Distributions

No source distribution files available for this release.

Built Distribution


mltrainer-0.2.2-py3-none-any.whl (13.1 kB)

Uploaded Python 3

File details

Details for the file mltrainer-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: mltrainer-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 13.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.8

File hashes

Hashes for mltrainer-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 79fe77c150c1437b9a111cc9c2ebe9a0d007d3cd88eb03b905af342b180f8537
MD5 866239f8f1f08fb26db5878ac919b885
BLAKE2b-256 979b3aa8f991c9ab7afda5b1cb16c44ed410f2d22dc49f66fc5488fe939102d4

