A Clean and Modular PyTorch Training Framework

🏃‍♀️‍➡️ Treadmill 🏃‍♀️‍➡️

Treadmill is a lightweight, modular training framework specifically designed for PyTorch. It provides clean, easy-to-understand training loops with beautiful output formatting while maintaining the power and flexibility of vanilla PyTorch.

✨ Features

  • 🎯 Pure PyTorch: Built specifically for PyTorch, no forced abstractions
  • 🔧 Modular Design: Easy to customize and extend with a callback system
  • 📊 Beautiful Output: Rich formatting with progress bars and metrics tables
  • ⚡ Performance Optimizations: Mixed precision, gradient accumulation, gradient clipping
  • 🎛️ Flexible Configuration: Dataclass-based configuration system
  • 📈 Comprehensive Metrics: Built-in metrics with support for custom metrics
  • 💾 Smart Checkpointing: Automatic model saving with customizable triggers
  • 🛑 Early Stopping: Configurable early stopping to prevent overfitting
  • 🔄 Resumable Training: Easy checkpoint loading and training resumption

🛠️ Installation

From PyPI (Recommended)

pip install pytorch-treadmill

Install with Optional Dependencies

# With examples dependencies (torchvision, scikit-learn)
pip install "pytorch-treadmill[examples]"

# With full dependencies (visualization tools, docs, etc.)
pip install "pytorch-treadmill[full]"

# For development
pip install "pytorch-treadmill[dev]"

From Source

For the latest development version or to contribute:

git clone https://github.com/MayukhSobo/treadmill.git
cd treadmill
pip install -e .

Install with Examples (Development)

pip install -e ".[examples]"  # Includes torchvision and additional dependencies

Install Full Version (Development)

pip install -e ".[full]"  # Includes all optional dependencies
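After installing, one way to confirm which version is present is to query the package metadata. The sketch below uses only the standard-library `importlib.metadata`; the helper name `installed_version` is hypothetical, and only the distribution name `pytorch-treadmill` comes from the commands above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="pytorch-treadmill"):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version()
print(found if found else "pytorch-treadmill is not installed")
```

Returning `None` instead of raising keeps the check usable in setup scripts that want to branch on whether the package is available.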

🚀 Quick Start

Here's a minimal example to get you started:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from treadmill import Trainer, TrainingConfig, OptimizerConfig
from treadmill.metrics import StandardMetrics

# Define your model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)
    
    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

# Prepare your data (DataLoaders)
train_loader = DataLoader(...)  # Your training data
val_loader = DataLoader(...)    # Your validation data

# Configure training
config = TrainingConfig(
    epochs=10,
    optimizer=OptimizerConfig(optimizer_class="Adam", lr=1e-3),
    device="auto"  # Automatically uses GPU if available
)

# Create and run trainer
trainer = Trainer(
    model=SimpleNet(),
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    metric_fns={"accuracy": StandardMetrics.accuracy}
)

# Start training
history = trainer.train()

📖 Core Components

TrainingConfig

The main configuration class that controls all aspects of training:

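# Assumption: SchedulerConfig is exported from the package root like the other config classes
from treadmill import TrainingConfig, OptimizerConfig, SchedulerConfig

config = TrainingConfig(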
    # Basic settings
    epochs=20,
    device="auto",  # "auto", "cpu", "cuda", or specific device
    
    # Optimizer configuration
    optimizer=OptimizerConfig(
        optimizer_class="Adam",  # Any PyTorch optimizer
        lr=1e-3,
        weight_decay=1e-4,
        params={"betas": (0.9, 0.999)}  # Additional optimizer parameters
    ),
    
    # Learning rate scheduler
    scheduler=SchedulerConfig(
        scheduler_class="StepLR",
        params={"step_size": 10, "gamma": 0.1}
    ),
    
    # Training optimizations
    mixed_precision=True,
    grad_clip_norm=1.0,
    accumulate_grad_batches=4,
    
    # Validation and early stopping
    validate_every=1,
    early_stopping_patience=5,
    
    # Display and logging
    print_every=50,
    progress_bar=True
)
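To make two of these settings concrete, here is a small framework-free sketch of the arithmetic behind them. It assumes the scheduler is stepped once per epoch (the usual way `StepLR` is used) and that gradient accumulation simply multiplies the number of samples per optimizer step; the helper names and the batch size of 32 are hypothetical and not part of Treadmill.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate in effect during a zero-based epoch under a StepLR-style schedule."""
    return base_lr * gamma ** (epoch // step_size)


def effective_batch_size(loader_batch_size, accumulate_grad_batches):
    """Samples contributing to each optimizer step when gradients are accumulated."""
    return loader_batch_size * accumulate_grad_batches


# Values from the config above: lr=1e-3, step_size=10, gamma=0.1, epochs=20.
for epoch in (0, 9, 10, 19):
    print(epoch, step_lr(1e-3, epoch, step_size=10, gamma=0.1))

# A DataLoader batch size of 32 is a hypothetical value, not taken from the config.
print(effective_batch_size(32, accumulate_grad_batches=4))  # 128
```

With these values the learning rate stays at its initial value for epochs 0 through 9 and is multiplied by 0.1 from epoch 10 onward.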

Callbacks System

Extend functionality with callbacks:

from treadmill.callbacks import EarlyStopping, ModelCheckpoint, LearningRateLogger

callbacks = [
    EarlyStopping(monitor="val_loss", patience=10, verbose=True),
    ModelCheckpoint(
        filepath="./checkpoints/model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt",
        monitor="val_accuracy",
        mode="max",
        save_best_only=True
    ),
    LearningRateLogger(verbose=True)
]

trainer = Trainer(..., callbacks=callbacks)
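The `monitor`, `mode`, and `patience` arguments follow a common pattern. The sketch below shows, in plain Python, one way patience-based stopping can work and how a filename template like the one above expands with `str.format`. `PatienceTracker` is a hypothetical class written for illustration, not Treadmill's `EarlyStopping`, and whether Treadmill fills the template with `str.format` is an assumption.

```python
class PatienceTracker:
    """Tracks a monitored metric and reports when patience has run out."""

    def __init__(self, patience, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.bad_epochs = 0

    def update(self, value):
        """Record one epoch's value; return True if training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best
        else:
            improved = value > self.best
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: the best value is reached at epoch 1.
tracker = PatienceTracker(patience=2, mode="min")
val_losses = [0.90, 0.80, 0.85, 0.82, 0.70]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if tracker.update(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 3: two epochs in a row without beating 0.80

# Expanding the checkpoint filename template with illustrative values.
template = "model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt"
print(template.format(epoch=7, val_accuracy=0.9123))  # model_epoch_007_0.9123.pt
```

Note that stopping at epoch 3 means the later improvement to 0.70 is never seen, which is the usual trade-off when choosing a small `patience`.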

Custom Metrics

Define your own metrics or use built-in ones:

from treadmill.metrics import StandardMetrics

# Built-in metrics
metric_fns = {
    "accuracy": StandardMetrics.accuracy,
    "top5_acc": lambda p, t: StandardMetrics.top_k_accuracy(p, t, k=5),
    "f1": StandardMetrics.f1_score
}

# Custom metrics
def custom_metric(predictions, targets):
    # Your custom metric calculation
    return some_value

metric_fns["custom"] = custom_metric
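For readers unfamiliar with top-k accuracy, the following framework-free sketch shows what the metric measures, using plain Python lists in place of tensors. The function name `top_k_accuracy_lists` is hypothetical; `StandardMetrics.top_k_accuracy` operates on model outputs and its implementation is not shown here.

```python
def top_k_accuracy_lists(scores, targets, k=5):
    """Fraction of samples whose true class is among the k highest-scoring classes."""
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    if not scores:
        return 0.0
    hits = 0
    for row, target in zip(scores, targets):
        # Indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top_k
    return hits / len(scores)


scores = [
    [0.1, 0.7, 0.2],  # true class 1 has the highest score
    [0.5, 0.3, 0.2],  # true class 1 has the second-highest score
    [0.6, 0.3, 0.1],  # true class 2 has the lowest score
]
targets = [1, 1, 2]
print(top_k_accuracy_lists(scores, targets, k=1))  # one of three samples hits
print(top_k_accuracy_lists(scores, targets, k=2))  # two of three samples hit
```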

🔧 Advanced Usage

Custom Forward/Backward Functions

For complex models with multiple components or special training procedures:

def custom_forward_fn(model, batch):
    """Custom forward pass for complex models."""
    inputs, targets = batch
    
    # Your custom forward logic
    outputs = model(inputs)
    additional_outputs = model.some_other_forward(inputs)
    
    return (outputs, additional_outputs), targets

def custom_backward_fn(loss, model, optimizer):
    """Custom backward pass with special handling."""
    loss.backward()
    # Add any custom gradient processing here

config = TrainingConfig(
    custom_forward_fn=custom_forward_fn,
    custom_backward_fn=custom_backward_fn,
    # ... other config
)

Model with Built-in Loss

Your model can implement its own loss computation:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ... model definition
    
    def forward(self, x):
        # ... forward pass
        return outputs
    
    def compute_loss(self, outputs, targets):
        """Custom loss computation."""
        return your_loss_calculation(outputs, targets)

# No need to provide loss_fn to trainer
trainer = Trainer(
    model=MyModel(),
    config=config,
    train_dataloader=train_loader,
    # loss_fn=None  # Will use model's compute_loss method
)
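One plausible way a trainer can choose between an explicit `loss_fn` and the model's own `compute_loss` is sketched below. This is an illustration of the dispatch idea only: `resolve_loss_fn` and `ToyModel` are hypothetical names, and the precedence shown (explicit `loss_fn` first) is an assumption rather than documented Treadmill behavior.

```python
def resolve_loss_fn(model, loss_fn=None):
    """Pick the explicit loss function if given, else fall back to model.compute_loss."""
    if loss_fn is not None:
        return loss_fn
    compute_loss = getattr(model, "compute_loss", None)
    if callable(compute_loss):
        return compute_loss
    raise ValueError("Provide loss_fn or define model.compute_loss(outputs, targets)")


class ToyModel:
    """Stand-in for an nn.Module that carries its own loss."""

    def compute_loss(self, outputs, targets):
        # Mean squared error over plain Python lists.
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


loss_fn = resolve_loss_fn(ToyModel())
print(loss_fn([1.0, 2.0], [1.0, 4.0]))  # 2.0
```

Raising an error when neither source of a loss is available surfaces a misconfiguration before the first batch instead of partway through training.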

Checkpointing and Resuming

# Save checkpoint
trainer.save_checkpoint("my_checkpoint.pt")

# Load checkpoint
trainer.load_checkpoint("my_checkpoint.pt", resume_training=True)

# Or create new trainer and load
new_trainer = Trainer(...)
checkpoint = new_trainer.load_checkpoint("my_checkpoint.pt", resume_training=False)

📊 Output Examples

Treadmill provides beautiful, informative output during training:

============================================================
🚀 Starting Training with Treadmill
============================================================

┌─────────────────────────────────────────────────────────┐
│                       Model Info                        │
├─────────────────────────────────────────────────────────┤
│ Model: SimpleCNN                                        │
│ Total Parameters: 1.2M                                  │
│ Trainable Parameters: 1.2M                              │
└─────────────────────────────────────────────────────────┘

Epoch 1/20
────────────────────────────────────────
Batch   50/391 ( 12.8%) | loss: 2.1234 | accuracy: 0.2341
Batch  100/391 ( 25.6%) | loss: 1.8765 | accuracy: 0.3456
...

┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ Metric         ┃ Train      ┃ Validation     ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ Loss           │ 1.2345     │ 1.3456         │
│ Accuracy       │ 0.6789     │ 0.6234         │
│ Epoch Time     │ 2m 34.5s   │ 2m 34.5s       │
│ Total Time     │ 2m 34.5s   │ 2m 34.5s       │
└────────────────┴────────────┴────────────────┘
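The per-batch progress lines in the sample output can be reproduced with ordinary Python format specifiers. The sketch below is not Treadmill's code; `format_batch_line` is a hypothetical helper that shows how the batch counter, percentage, and four-decimal metrics in those lines relate to the underlying numbers.

```python
def format_batch_line(batch_idx, total_batches, metrics):
    """Build a progress line such as 'Batch   50/391 ( 12.8%) | loss: 2.1234'."""
    percent = 100.0 * batch_idx / total_batches
    parts = [f"Batch {batch_idx:>4}/{total_batches} ({percent:5.1f}%)"]
    # Each metric is shown as name: value with four decimal places.
    parts += [f"{name}: {value:.4f}" for name, value in metrics.items()]
    return " | ".join(parts)


print(format_batch_line(50, 391, {"loss": 2.1234, "accuracy": 0.2341}))
print(format_batch_line(100, 391, {"loss": 1.8765, "accuracy": 0.3456}))
```

The percentage is simply the batch index divided by the number of batches per epoch, so with `print_every=50` and 391 batches the first line appears at 12.8%.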

🎯 Examples

Check out the /examples directory for complete examples:

  • basic_training.py: Simple CNN on CIFAR-10
  • advanced_training.py: VAE with custom forward/backward functions

Run examples:

cd examples
python basic_training.py
python advanced_training.py

🤝 Contributing

Contributions are welcome! Please see the contributing guidelines for more details.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Inspired by the need for clean, modular PyTorch training
  • Built with ❤️ for the PyTorch community
  • Uses Rich for beautiful terminal output

Happy Training with Treadmill! 🚀

Documentation will be available at: https://mayukhsobo.github.io/treadmill/
