
A Clean and Modular PyTorch Training Framework


๐Ÿƒโ€โ™€๏ธโ€โžก๏ธ Treadmill ๐Ÿƒโ€โ™€๏ธโ€โžก๏ธ

Treadmill Training Framework


Treadmill is a lightweight, modular training framework specifically designed for PyTorch. It provides clean, easy-to-understand training loops with beautiful output formatting while maintaining the power and flexibility of vanilla PyTorch.

✨ Features

  • 🎯 Pure PyTorch: Built specifically for PyTorch, no forced abstractions
  • 🔧 Modular Design: Easy to customize and extend via the callback system
  • 📊 Beautiful Output: Rich formatting with progress bars and metrics tables
  • ⚡ Performance Optimizations: Mixed precision, gradient accumulation, gradient clipping
  • 🎛️ Flexible Configuration: Dataclass-based configuration system
  • 📈 Comprehensive Metrics: Built-in metrics with support for custom metrics
  • 💾 Smart Checkpointing: Automatic model saving with customizable triggers
  • 🛑 Early Stopping: Configurable early stopping to prevent overfitting
  • 🔄 Resumable Training: Easy checkpoint loading and training resumption

🛠️ Installation

From PyPI (Recommended)

pip install pytorch-treadmill

Install with Optional Dependencies

# With examples dependencies (torchvision, scikit-learn)
pip install "pytorch-treadmill[examples]"

# With full dependencies (visualization tools, docs, etc.)
pip install "pytorch-treadmill[full]"

# For development
pip install "pytorch-treadmill[dev]"

From Source

For the latest development version or to contribute:

git clone https://github.com/MayukhSobo/treadmill.git
cd treadmill
pip install -e .

Install with Examples (Development)

pip install -e ".[examples]"  # Includes torchvision and additional dependencies

Install Full Version (Development)

pip install -e ".[full]"  # Includes all optional dependencies

🚀 Quick Start

Here's a minimal example to get you started:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from treadmill import Trainer, TrainingConfig, OptimizerConfig
from treadmill.metrics import StandardMetrics

# Define your model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)
    
    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

# Prepare your data (DataLoaders)
train_loader = DataLoader(...)  # Your training data
val_loader = DataLoader(...)    # Your validation data

# Configure training
config = TrainingConfig(
    epochs=10,
    optimizer=OptimizerConfig(optimizer_class="Adam", lr=1e-3),
    device="auto"  # Automatically uses GPU if available
)

# Create and run trainer
trainer = Trainer(
    model=SimpleNet(),
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    metric_fns={"accuracy": StandardMetrics.accuracy}
)

# Start training
history = trainer.train()
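For reference, `StandardMetrics.accuracy` in the example above computes top-1 classification accuracy. A plain-PyTorch equivalent (an illustrative sketch, not necessarily the library's exact implementation) looks like this:

```python
import torch

def accuracy(predictions, targets):
    """Top-1 accuracy: fraction of rows whose argmax matches the target class."""
    return (predictions.argmax(dim=1) == targets).float().mean().item()

preds = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
targets = torch.tensor([1, 0, 0])
print(round(accuracy(preds, targets), 4))  # 2 of 3 correct -> 0.6667
```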

📖 Core Components

TrainingConfig

The main configuration class that controls all aspects of training:

from treadmill import TrainingConfig, OptimizerConfig, SchedulerConfig

config = TrainingConfig(
    # Basic settings
    epochs=20,
    device="auto",  # "auto", "cpu", "cuda", or specific device
    
    # Optimizer configuration
    optimizer=OptimizerConfig(
        optimizer_class="Adam",  # Any PyTorch optimizer
        lr=1e-3,
        weight_decay=1e-4,
        params={"betas": (0.9, 0.999)}  # Additional optimizer parameters
    ),
    
    # Learning rate scheduler
    scheduler=SchedulerConfig(
        scheduler_class="StepLR",
        params={"step_size": 10, "gamma": 0.1}
    ),
    
    # Training optimizations
    mixed_precision=True,
    grad_clip_norm=1.0,
    accumulate_grad_batches=4,
    
    # Validation and early stopping
    validate_every=1,
    early_stopping_patience=5,
    
    # Display and logging
    print_every=50,
    progress_bar=True
)
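The optimization options above map onto a standard PyTorch pattern. As a sketch of what gradient accumulation plus clipping amount to in vanilla PyTorch (not Treadmill's internal code), with `accumulate_grad_batches=4` the optimizer steps once per 4 micro-batches, giving a 4× effective batch size:

```python
import torch
import torch.nn as nn

def train_with_accumulation(model, batches, accumulate_grad_batches=4, grad_clip_norm=1.0):
    """Step the optimizer once per `accumulate_grad_batches` micro-batches."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    optimizer_steps = 0
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        # Scale the loss so the accumulated gradient averages over micro-batches
        loss = loss_fn(model(inputs), targets) / accumulate_grad_batches
        loss.backward()
        if (i + 1) % accumulate_grad_batches == 0:
            # Clip the accumulated gradient, then step and reset
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            optimizer_steps += 1
    return optimizer_steps

model = nn.Linear(8, 1)
batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]
print(train_with_accumulation(model, batches))  # 8 micro-batches / 4 -> 2 optimizer steps
```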

Callbacks System

Extend functionality with callbacks:

from treadmill.callbacks import EarlyStopping, ModelCheckpoint, LearningRateLogger

callbacks = [
    EarlyStopping(monitor="val_loss", patience=10, verbose=True),
    ModelCheckpoint(
        filepath="./checkpoints/model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt",
        monitor="val_accuracy",
        mode="max",
        save_best_only=True
    ),
    LearningRateLogger(verbose=True)
]

trainer = Trainer(..., callbacks=callbacks)
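The `ModelCheckpoint` filepath contains standard Python format fields; assuming they are filled str.format-style (as the pattern suggests), each saved file gets the epoch and monitored metric baked into its name:

```python
filepath = "./checkpoints/model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt"

# Fields are substituted with the current epoch and metric values:
print(filepath.format(epoch=7, val_accuracy=0.91234))
# ./checkpoints/model_epoch_007_0.9123.pt
```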

Custom Metrics

Define your own metrics or use built-in ones:

from treadmill.metrics import StandardMetrics

# Built-in metrics
metric_fns = {
    "accuracy": StandardMetrics.accuracy,
    "top5_acc": lambda p, t: StandardMetrics.top_k_accuracy(p, t, k=5),
    "f1": StandardMetrics.f1_score
}

# Custom metrics: any callable taking (predictions, targets)
def custom_metric(predictions, targets):
    # Example: mean absolute error between predictions and targets
    return (predictions - targets).abs().mean().item()

metric_fns["custom"] = custom_metric

🔧 Advanced Usage

Custom Forward/Backward Functions

For complex models with multiple components or special training procedures:

def custom_forward_fn(model, batch):
    """Custom forward pass for complex models."""
    inputs, targets = batch
    
    # Your custom forward logic
    outputs = model(inputs)
    additional_outputs = model.some_other_forward(inputs)
    
    return (outputs, additional_outputs), targets

def custom_backward_fn(loss, model, optimizer):
    """Custom backward pass with special handling."""
    loss.backward()
    # Add any custom gradient processing here

config = TrainingConfig(
    custom_forward_fn=custom_forward_fn,
    custom_backward_fn=custom_backward_fn,
    # ... other config
)

Model with Built-in Loss

Your model can implement its own loss computation:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ... model definition
    
    def forward(self, x):
        # ... forward pass
        return outputs
    
    def compute_loss(self, outputs, targets):
        """Custom loss computation."""
        return your_loss_calculation(outputs, targets)

# No need to provide loss_fn to trainer
trainer = Trainer(
    model=MyModel(),
    config=config,
    train_dataloader=train_loader,
    # loss_fn=None  # Will use model's compute_loss method
)

Checkpointing and Resuming

# Save checkpoint
trainer.save_checkpoint("my_checkpoint.pt")

# Load checkpoint
trainer.load_checkpoint("my_checkpoint.pt", resume_training=True)

# Or create new trainer and load
new_trainer = Trainer(...)
checkpoint = new_trainer.load_checkpoint("my_checkpoint.pt", resume_training=False)
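Under the hood, resumable checkpoints in PyTorch typically bundle the model weights, optimizer state, and training progress in one file. A minimal vanilla-PyTorch sketch (not necessarily Treadmill's exact checkpoint format):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save: bundle everything needed to resume training
torch.save({
    "epoch": 5,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "my_checkpoint.pt")

# Load: restore weights *and* optimizer state, then continue from the next epoch
checkpoint = torch.load("my_checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
```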

📊 Output Examples

Treadmill provides beautiful, informative output during training:

============================================================
🚀 Starting Training with Treadmill
============================================================

┌─────────────────────────────────────────────────────────┐
│                       Model Info                        │
├─────────────────────────────────────────────────────────┤
│ Model: SimpleCNN                                        │
│ Total Parameters: 1.2M                                  │
│ Trainable Parameters: 1.2M                              │
└─────────────────────────────────────────────────────────┘

Epoch 1/20
────────────────────────────────────────
Batch   50/391 ( 12.8%) | loss: 2.1234 | accuracy: 0.2341
Batch  100/391 ( 25.6%) | loss: 1.8765 | accuracy: 0.3456
...

┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ Metric         ┃ Train      ┃ Validation     ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ Loss           │ 1.2345     │ 1.3456         │
│ Accuracy       │ 0.6789     │ 0.6234         │
│ Epoch Time     │ 2m 34.5s   │ 2m 34.5s       │
│ Total Time     │ 2m 34.5s   │ 2m 34.5s       │
└────────────────┴────────────┴────────────────┘

🎯 Examples

Check out the /examples directory for complete examples:

  • basic_training.py: Simple CNN on CIFAR-10
  • advanced_training.py: VAE with custom forward/backward functions

Run examples:

cd examples
python basic_training.py
python advanced_training.py

๐Ÿค Contributing

I welcome contributions! Please see our contributing guidelines for more details.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Inspired by the need for clean, modular PyTorch training
  • Built with ❤️ for the PyTorch community
  • Uses Rich for beautiful terminal output

Happy Training with Treadmill! 🚀

Documentation will be available at: https://mayukhsobo.github.io/treadmill/
