🏃‍♀️‍➡️ Treadmill Training Framework 🏃‍♀️‍➡️

A Clean and Modular PyTorch Training Framework

Treadmill is a lightweight, modular training framework for PyTorch. It provides clean, easy-to-understand training loops with rich output formatting while retaining the power and flexibility of vanilla PyTorch.

✨ Features

  • 🎯 Pure PyTorch: Built specifically for PyTorch, no forced abstractions
  • 🔧 Modular Design: Easy to customize and extend with the callback system
  • 📊 Beautiful Output: Rich formatting with progress bars and metrics tables
  • ⚡ Performance Optimizations: Mixed precision, gradient accumulation, gradient clipping
  • 🎛️ Flexible Configuration: Dataclass-based configuration system
  • 📈 Comprehensive Metrics: Built-in metrics with support for custom metrics
  • 💾 Smart Checkpointing: Automatic model saving with customizable triggers
  • 🛑 Early Stopping: Configurable early stopping to prevent overfitting
  • 🔄 Resumable Training: Easy checkpoint loading and training resumption

๐Ÿ› ๏ธ Installation

From PyPI (Recommended)

pip install pytorch-treadmill

Install with Optional Dependencies

# With examples dependencies (torchvision, scikit-learn)
pip install "pytorch-treadmill[examples]"

# With full dependencies (visualization tools, docs, etc.)
pip install "pytorch-treadmill[full]"

# For development
pip install "pytorch-treadmill[dev]"

From Source

For the latest development version or to contribute:

git clone https://github.com/MayukhSobo/treadmill.git
cd treadmill
pip install -e .

Install with Examples (Development)

pip install -e ".[examples]"  # Includes torchvision and additional dependencies

Install Full Version (Development)

pip install -e ".[full]"  # Includes all optional dependencies

🚀 Quick Start

Here's a minimal example to get you started:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from treadmill import Trainer, TrainingConfig, OptimizerConfig
from treadmill.metrics import StandardMetrics

# Define your model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)
    
    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

# Prepare your data (DataLoaders)
train_loader = DataLoader(...)  # Your training data
val_loader = DataLoader(...)    # Your validation data

# Configure training
config = TrainingConfig(
    epochs=10,
    optimizer=OptimizerConfig(optimizer_class="Adam", lr=1e-3),
    device="auto"  # Automatically uses GPU if available
)

# Create and run trainer
trainer = Trainer(
    model=SimpleNet(),
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    metric_fns={"accuracy": StandardMetrics.accuracy}
)

# Start training
history = trainer.train()

📖 Core Components

TrainingConfig

The main configuration class that controls all aspects of training:

from treadmill import TrainingConfig, OptimizerConfig, SchedulerConfig

config = TrainingConfig(
    # Basic settings
    epochs=20,
    device="auto",  # "auto", "cpu", "cuda", or specific device
    
    # Optimizer configuration
    optimizer=OptimizerConfig(
        optimizer_class="Adam",  # Any PyTorch optimizer
        lr=1e-3,
        weight_decay=1e-4,
        params={"betas": (0.9, 0.999)}  # Additional optimizer parameters
    ),
    
    # Learning rate scheduler
    scheduler=SchedulerConfig(
        scheduler_class="StepLR",
        params={"step_size": 10, "gamma": 0.1}
    ),
    
    # Training optimizations
    mixed_precision=True,
    grad_clip_norm=1.0,
    accumulate_grad_batches=4,
    
    # Validation and early stopping
    validate_every=1,
    early_stopping_patience=5,
    
    # Display and logging
    print_every=50,
    progress_bar=True
)
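For orientation, `accumulate_grad_batches=4` corresponds to the standard PyTorch pattern of scaling and accumulating losses from four micro-batches before a single optimizer step. The sketch below shows that pattern in vanilla PyTorch with dummy data; it is illustrative only, not Treadmill's internals:

```python
import torch
import torch.nn as nn

# Illustrative sketch (not Treadmill internals) of what
# accumulate_grad_batches=4 means in plain PyTorch: losses from four
# micro-batches are scaled and accumulated before one optimizer step.
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

accumulate_grad_batches = 4
optimizer.zero_grad()
for micro_batch in range(accumulate_grad_batches):
    x, y = torch.randn(16, 8), torch.randn(16, 1)        # dummy micro-batch
    loss = loss_fn(model(x), y) / accumulate_grad_batches  # scale so grads average
    loss.backward()                                        # grads accumulate in param.grad
optimizer.step()       # one update covering all four micro-batches
optimizer.zero_grad()  # reset before the next accumulation window
```

This trades memory for effective batch size: each micro-batch fits on the device, while the weight update behaves like one large batch.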

Callbacks System

Extend functionality with callbacks:

from treadmill.callbacks import EarlyStopping, ModelCheckpoint, LearningRateLogger

callbacks = [
    EarlyStopping(monitor="val_loss", patience=10, verbose=True),
    ModelCheckpoint(
        filepath="./checkpoints/model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt",
        monitor="val_accuracy",
        mode="max",
        save_best_only=True
    ),
    LearningRateLogger(verbose=True)
]

trainer = Trainer(..., callbacks=callbacks)

Custom Metrics

Define your own metrics or use built-in ones:

from treadmill.metrics import StandardMetrics

# Built-in metrics
metric_fns = {
    "accuracy": StandardMetrics.accuracy,
    "top5_acc": lambda p, t: StandardMetrics.top_k_accuracy(p, t, k=5),
    "f1": StandardMetrics.f1_score
}

# Custom metrics
def custom_metric(predictions, targets):
    # Example: fraction of samples where the argmax class matches the target
    return (predictions.argmax(dim=1) == targets).float().mean().item()

metric_fns["custom"] = custom_metric

🔧 Advanced Usage

Custom Forward/Backward Functions

For complex models with multiple components or special training procedures:

def custom_forward_fn(model, batch):
    """Custom forward pass for complex models."""
    inputs, targets = batch
    
    # Your custom forward logic
    outputs = model(inputs)
    additional_outputs = model.some_other_forward(inputs)
    
    return (outputs, additional_outputs), targets

def custom_backward_fn(loss, model, optimizer):
    """Custom backward pass with special handling."""
    loss.backward()
    # Add any custom gradient processing here

config = TrainingConfig(
    custom_forward_fn=custom_forward_fn,
    custom_backward_fn=custom_backward_fn,
    # ... other config
)

Model with Built-in Loss

Your model can implement its own loss computation:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ... model definition
    
    def forward(self, x):
        # ... forward pass
        return outputs
    
    def compute_loss(self, outputs, targets):
        """Custom loss computation."""
        return your_loss_calculation(outputs, targets)

# No need to provide loss_fn to trainer
trainer = Trainer(
    model=MyModel(),
    config=config,
    train_dataloader=train_loader,
    # loss_fn=None  # Will use model's compute_loss method
)
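Conceptually, this fallback works like the sketch below: when no `loss_fn` is supplied, the trainer calls the model's own `compute_loss`. This is an illustrative reimplementation of the convention, not Treadmill's actual dispatch code; `TinyModel` and `resolve_loss` are hypothetical names used only here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch of the "built-in loss" convention -- not
# Treadmill's actual dispatch code. TinyModel and resolve_loss are
# hypothetical names for this example only.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

    def compute_loss(self, outputs, targets):
        return F.cross_entropy(outputs, targets)

def resolve_loss(model, loss_fn, outputs, targets):
    if loss_fn is not None:
        return loss_fn(outputs, targets)          # explicit loss wins
    return model.compute_loss(outputs, targets)   # model-provided fallback

model = TinyModel()
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = resolve_loss(model, None, model(x), y)     # scalar loss tensor
```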

Checkpointing and Resuming

# Save checkpoint
trainer.save_checkpoint("my_checkpoint.pt")

# Load checkpoint
trainer.load_checkpoint("my_checkpoint.pt", resume_training=True)

# Or create new trainer and load
new_trainer = Trainer(...)
checkpoint = new_trainer.load_checkpoint("my_checkpoint.pt", resume_training=False)
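A checkpoint of this kind typically bundles the model's and optimizer's `state_dict` along with progress counters. Below is a generic PyTorch save/restore sketch using an in-memory buffer in place of a `.pt` file; the dict keys are assumptions for illustration, not Treadmill's exact on-disk format:

```python
import io
import torch
import torch.nn as nn

# Generic PyTorch checkpoint sketch. The dict keys here are
# illustrative assumptions, not Treadmill's exact on-disk format.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

checkpoint = {
    "epoch": 7,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}

buffer = io.BytesIO()           # in-memory stand-in for a .pt file
torch.save(checkpoint, buffer)

# Restoring: rebuild the objects, then load the saved state back in.
buffer.seek(0)
restored = torch.load(buffer)
model.load_state_dict(restored["model_state_dict"])
optimizer.load_state_dict(restored["optimizer_state_dict"])
start_epoch = restored["epoch"] + 1  # resume from the next epoch
```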

📊 Output Examples

Treadmill provides beautiful, informative output during training:

============================================================
🚀 Starting Training with Treadmill
============================================================

┌─────────────────────────────────────────────────────────┐
│                       Model Info                        │
├─────────────────────────────────────────────────────────┤
│ Model: SimpleCNN                                        │
│ Total Parameters: 1.2M                                  │
│ Trainable Parameters: 1.2M                              │
└─────────────────────────────────────────────────────────┘

Epoch 1/20
────────────────────────────────────────
Batch   50/391 ( 12.8%) | loss: 2.1234 | accuracy: 0.2341
Batch  100/391 ( 25.6%) | loss: 1.8765 | accuracy: 0.3456
...

┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ Metric         ┃ Train      ┃ Validation     ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ Loss           │ 1.2345     │ 1.3456         │
│ Accuracy       │ 0.6789     │ 0.6234         │
│ Epoch Time     │ 2m 34.5s   │ 2m 34.5s       │
│ Total Time     │ 2m 34.5s   │ 2m 34.5s       │
└────────────────┴────────────┴────────────────┘

🎯 Examples

Check out the /examples directory for complete examples:

  • basic_training.py: Simple CNN on CIFAR-10
  • advanced_training.py: VAE with custom forward/backward functions

Run examples:

cd examples
python basic_training.py
python advanced_training.py

๐Ÿค Contributing

Contributions are welcome! Please see the contributing guidelines for more details.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

๐Ÿ™ Acknowledgments

  • Inspired by the need for clean, modular PyTorch training
  • Built with ❤️ for the PyTorch community
  • Uses Rich for beautiful terminal output

Happy Training with Treadmill! 🚀

Documentation will be available at: https://mayukhsobo.github.io/treadmill/
