Treadmill
A Clean and Modular PyTorch Training Framework
Treadmill is a lightweight, modular training framework specifically designed for PyTorch. It provides clean, easy-to-understand training loops with beautiful output formatting while maintaining the power and flexibility of vanilla PyTorch.
Features
- Pure PyTorch: Built specifically for PyTorch, with no forced abstractions
- Modular Design: Easy to customize and extend through a callback system
- Beautiful Output: Rich formatting with progress bars and metrics tables
- Performance Optimizations: Mixed precision, gradient accumulation, and gradient clipping
- Flexible Configuration: Dataclass-based configuration system
- Comprehensive Metrics: Built-in metrics with support for custom metrics
- Smart Checkpointing: Automatic model saving with customizable triggers
- Early Stopping: Configurable early stopping to prevent overfitting
- Resumable Training: Easy checkpoint loading and training resumption
Installation
From PyPI (Recommended)
```bash
pip install pytorch-treadmill
```
Install with Optional Dependencies
```bash
# With examples dependencies (torchvision, scikit-learn)
pip install "pytorch-treadmill[examples]"
# With full dependencies (visualization tools, docs, etc.)
pip install "pytorch-treadmill[full]"
# For development
pip install "pytorch-treadmill[dev]"
```
From Source
For the latest development version or to contribute:
```bash
git clone https://github.com/MayukhSobo/treadmill.git
cd treadmill
pip install -e .
```
Install with Examples (Development)
```bash
pip install -e ".[examples]"  # Includes torchvision and additional dependencies
```
Install Full Version (Development)
```bash
pip install -e ".[full]"  # Includes all optional dependencies
```
Quick Start
Here's a minimal example to get you started:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from treadmill import Trainer, TrainingConfig, OptimizerConfig
from treadmill.metrics import StandardMetrics

# Define your model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        # Flatten each input (e.g. a 1x28x28 image) into a 784-dim vector
        return self.fc(x.view(x.size(0), -1))

# Prepare your data (DataLoaders)
train_loader = DataLoader(...)  # Your training data
val_loader = DataLoader(...)  # Your validation data

# Configure training
config = TrainingConfig(
    epochs=10,
    optimizer=OptimizerConfig(optimizer_class="Adam", lr=1e-3),
    device="auto"  # Automatically uses GPU if available
)

# Create and run trainer
trainer = Trainer(
    model=SimpleNet(),
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    metric_fns={"accuracy": StandardMetrics.accuracy}
)

# Start training
history = trainer.train()
```
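The `DataLoader(...)` placeholders in the Quick Start accept any map-style dataset. To smoke-test the setup without downloading data, a minimal sketch like the following builds random MNIST-shaped tensors (1x28x28 inputs, 10 classes) with `torch.utils.data.TensorDataset`. The helper name `make_random_loaders`, the dataset sizes, and the batch size are arbitrary assumptions for illustration, not Treadmill requirements.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_random_loaders(num_train=512, num_val=128, batch_size=64, seed=0):
    """Hypothetical helper: train/val DataLoaders over random MNIST-shaped data (smoke tests only)."""
    generator = torch.Generator().manual_seed(seed)

    def make_dataset(n):
        inputs = torch.randn(n, 1, 28, 28, generator=generator)    # 1x28x28 "images"
        targets = torch.randint(0, 10, (n,), generator=generator)  # class indices 0..9
        return TensorDataset(inputs, targets)

    train_loader = DataLoader(make_dataset(num_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(make_dataset(num_val), batch_size=batch_size)
    return train_loader, val_loader

train_loader, val_loader = make_random_loaders()
```

Because `SimpleNet` flattens each 1x28x28 input to 784 features, these loaders can stand in for the placeholders while you check that the rest of the pipeline runs.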
Core Components
TrainingConfig
The main configuration class that controls all aspects of training:
```python
config = TrainingConfig(
    # Basic settings
    epochs=20,
    device="auto",  # "auto", "cpu", "cuda", or specific device

    # Optimizer configuration
    optimizer=OptimizerConfig(
        optimizer_class="Adam",  # Any PyTorch optimizer
        lr=1e-3,
        weight_decay=1e-4,
        params={"betas": (0.9, 0.999)}  # Additional optimizer parameters
    ),

    # Learning rate scheduler
    scheduler=SchedulerConfig(
        scheduler_class="StepLR",
        params={"step_size": 10, "gamma": 0.1}
    ),

    # Training optimizations
    mixed_precision=True,
    grad_clip_norm=1.0,
    accumulate_grad_batches=4,

    # Validation and early stopping
    validate_every=1,
    early_stopping_patience=5,

    # Display and logging
    print_every=50,
    progress_bar=True
)
```
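The `mixed_precision`, `grad_clip_norm`, and `accumulate_grad_batches` options correspond to standard PyTorch techniques. The sketch below shows one common way to combine them in a plain training loop; it is not Treadmill's internal implementation. It assumes bfloat16 autocast (which, unlike float16, does not need a `GradScaler` for loss scaling), and the function name `train_one_epoch` is hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, optimizer, loss_fn,
                    accumulate_grad_batches=4, grad_clip_norm=1.0, mixed_precision=False):
    """Hypothetical plain-PyTorch loop; returns the number of optimizer steps taken."""
    model.train()
    device = next(model.parameters()).device
    optimizer.zero_grad()
    num_batches = len(loader)
    steps = 0
    for i, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Mixed precision: autocast runs eligible ops (e.g. matmuls) in bfloat16
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=mixed_precision):
            loss = loss_fn(model(inputs), targets)
        # Gradient accumulation: divide so the summed gradients approximate one large batch
        (loss / accumulate_grad_batches).backward()
        if (i + 1) % accumulate_grad_batches == 0 or (i + 1) == num_batches:
            # Gradient clipping: rescale gradients so their global L2 norm is <= grad_clip_norm
            if grad_clip_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1
    return steps
```

With 10 batches and `accumulate_grad_batches=4`, this loop steps the optimizer after batches 4, 8, and 10, so a final partial group is still applied rather than dropped.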
Callbacks System
Extend functionality with callbacks:
```python
from treadmill.callbacks import EarlyStopping, ModelCheckpoint, LearningRateLogger

callbacks = [
    EarlyStopping(monitor="val_loss", patience=10, verbose=True),
    ModelCheckpoint(
        filepath="./checkpoints/model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt",
        monitor="val_accuracy",
        mode="max",
        save_best_only=True
    ),
    LearningRateLogger(verbose=True)
]

trainer = Trainer(..., callbacks=callbacks)
```
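The `monitor`, `mode`, `patience`, and `save_best_only` arguments follow a common pattern: track the best value of a monitored metric, treat a strictly better value as an improvement, and stop after `patience` epochs without one. A minimal pure-Python sketch of that logic is shown below. `MetricTracker` is a hypothetical name, and Treadmill's callbacks may differ in details such as a minimum-improvement threshold. The last lines show how the `filepath` template expands with `str.format`, assuming the callback fills it with the epoch number and logged metric values.

```python
import math

class MetricTracker:
    """Hypothetical helper: best-value tracking for early stopping and best-only checkpoints."""

    def __init__(self, mode="min", patience=5):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # epochs since the last improvement

    def update(self, value):
        """Record one epoch's metric; return True if it is a new best (a save_best_only trigger)."""
        improved = value < self.best if self.mode == "min" else value > self.best
        if improved:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return improved

    @property
    def should_stop(self):
        # Early stopping fires once `patience` consecutive epochs show no improvement
        return self.wait >= self.patience

# Expanding a checkpoint filename template like the one passed to ModelCheckpoint
template = "./checkpoints/model_epoch_{epoch:03d}_{val_accuracy:.4f}.pt"
print(template.format(epoch=7, val_accuracy=0.91234))  # ./checkpoints/model_epoch_007_0.9123.pt
```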
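Custom Metrics
Define your own metrics or use built-in ones:
```python
from treadmill.metrics import StandardMetrics

# Built-in metrics
metric_fns = {
    "accuracy": StandardMetrics.accuracy,
    "top5_acc": lambda p, t: StandardMetrics.top_k_accuracy(p, t, k=5),
    "f1": StandardMetrics.f1_score
}

# Custom metrics: any callable taking (predictions, targets)
def custom_metric(predictions, targets):
    # Example: classification error rate, assuming predictions are logits of
    # shape (N, num_classes) and targets are class indices of shape (N,)
    return (predictions.argmax(dim=1) != targets).float().mean().item()

metric_fns["custom"] = custom_metric
```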
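Top-k accuracy counts a prediction as correct when the true class is among the k highest-scoring classes. A minimal PyTorch sketch of that definition follows, assuming `predictions` are scores of shape `(N, num_classes)` and `targets` are class indices of shape `(N,)`. It illustrates the metric itself, not the internals of `StandardMetrics.top_k_accuracy`.

```python
import torch

def top_k_accuracy(predictions: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
    """Fraction of samples whose target class is among the k largest scores."""
    topk = predictions.topk(k, dim=1).indices          # (N, k) indices of the best classes
    hits = (topk == targets.unsqueeze(1)).any(dim=1)   # (N,) True if target is in the top k
    return hits.float().mean().item()

scores = torch.tensor([[0.1, 0.5, 0.4],
                       [0.7, 0.2, 0.1]])
labels = torch.tensor([2, 2])
print(top_k_accuracy(scores, labels, k=2))  # 0.5
print(top_k_accuracy(scores, labels, k=1))  # 0.0
```

A function with this `(predictions, targets)` signature can be dropped into `metric_fns` like `custom_metric` above.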
Advanced Usage
Custom Forward/Backward Functions
For complex models with multiple components or special training procedures:
```python
def custom_forward_fn(model, batch):
    """Custom forward pass for complex models."""
    inputs, targets = batch
    # Your custom forward logic
    outputs = model(inputs)
    additional_outputs = model.some_other_forward(inputs)
    return (outputs, additional_outputs), targets

def custom_backward_fn(loss, model, optimizer):
    """Custom backward pass with special handling."""
    loss.backward()
    # Add any custom gradient processing here

config = TrainingConfig(
    custom_forward_fn=custom_forward_fn,
    custom_backward_fn=custom_backward_fn,
    # ... other config
)
```
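To see how hooks with these signatures fit together, the sketch below runs a single training step that calls `forward_fn(model, batch)` and `backward_fn(loss, model, optimizer)` in plain PyTorch. The `TwoHeadNet` model, its `some_other_forward` method, the tuple-aware `combined_loss`, and `training_step` are all hypothetical stand-ins; how Treadmill itself invokes the hooks (for example, where it calls `optimizer.step()`) is not described in this README.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoHeadNet(nn.Module):
    """Hypothetical model with a main head and an auxiliary head sharing one body."""
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 16)
        self.main_head = nn.Linear(16, 3)
        self.aux_head = nn.Linear(16, 3)

    def forward(self, x):
        return self.main_head(torch.relu(self.body(x)))

    def some_other_forward(self, x):
        return self.aux_head(torch.relu(self.body(x)))

def custom_forward_fn(model, batch):
    inputs, targets = batch
    return (model(inputs), model.some_other_forward(inputs)), targets

def custom_backward_fn(loss, model, optimizer):
    loss.backward()
    # Example of custom gradient processing: clip before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

def combined_loss(outputs, targets):
    # Tuple-aware loss: main loss plus a down-weighted auxiliary loss
    main_out, aux_out = outputs
    return F.cross_entropy(main_out, targets) + 0.5 * F.cross_entropy(aux_out, targets)

def training_step(model, batch, optimizer, loss_fn, forward_fn, backward_fn):
    optimizer.zero_grad()
    outputs, targets = forward_fn(model, batch)
    loss = loss_fn(outputs, targets)
    backward_fn(loss, model, optimizer)
    optimizer.step()
    return loss.item()
```

Because the custom forward returns a tuple of outputs, the loss function has to know how to unpack it, which is why the sketch pairs it with `combined_loss`.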
Model with Built-in Loss
Your model can implement its own loss computation:
```python
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)  # ... model definition

    def forward(self, x):
        # ... forward pass
        return self.fc(x.view(x.size(0), -1))

    def compute_loss(self, outputs, targets):
        """Custom loss computation."""
        return F.cross_entropy(outputs, targets)

# No need to provide loss_fn to the trainer
trainer = Trainer(
    model=MyModel(),
    config=config,
    train_dataloader=train_loader,
    # loss_fn=None  # Will use the model's compute_loss method
)
```
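The rule stated above (no `loss_fn` means the model's `compute_loss` is used) can be expressed as a small fallback. A minimal sketch, with the hypothetical helper name `resolve_loss` and a toy `TinyModel`, might look like this; Treadmill's actual dispatch code may be structured differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def resolve_loss(model, outputs, targets, loss_fn=None):
    """Hypothetical helper: prefer an explicit loss_fn, else fall back to model.compute_loss."""
    if loss_fn is not None:
        return loss_fn(outputs, targets)
    if hasattr(model, "compute_loss"):
        return model.compute_loss(outputs, targets)
    raise ValueError("Provide loss_fn or implement compute_loss(outputs, targets) on the model")

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

    def compute_loss(self, outputs, targets):
        return F.cross_entropy(outputs, targets)
```

Raising an error when neither is available surfaces a misconfigured model immediately instead of failing later inside the training loop.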
Checkpointing and Resuming
```python
# Save checkpoint
trainer.save_checkpoint("my_checkpoint.pt")

# Load checkpoint
trainer.load_checkpoint("my_checkpoint.pt", resume_training=True)

# Or create new trainer and load
new_trainer = Trainer(...)
checkpoint = new_trainer.load_checkpoint("my_checkpoint.pt", resume_training=False)
```
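The README does not document Treadmill's checkpoint file layout. As a point of reference, resumable checkpoints in plain PyTorch usually bundle the model and optimizer `state_dict()`s with the epoch counter, which is what the sketch below does with `torch.save` and `torch.load`. The dictionary keys and the `save_checkpoint`/`load_checkpoint` function names are illustrative assumptions, not Treadmill's format.

```python
import torch
import torch.nn as nn

def save_checkpoint(path, model, optimizer, epoch):
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(path, model, optimizer=None, resume_training=True):
    """Restore weights; when resuming, also restore optimizer state and return the next epoch."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    if resume_training and optimizer is not None:
        # Optimizer state (e.g. Adam moments) is needed to continue training seamlessly
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"] + 1  # epoch to continue from
    return 0
```

Restoring only the weights (`resume_training=False`) is enough for evaluation or fine-tuning from scratch; resuming also needs the optimizer state so adaptive optimizers do not restart their statistics.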
Output Examples
Treadmill provides beautiful, informative output during training:
```text
============================================================
Starting Training with Treadmill
============================================================
┌────────────────────────────────┐
│ Model Info                     │
├────────────────────────────────┤
│ Model: SimpleCNN               │
│ Total Parameters: 1.2M         │
│ Trainable Parameters: 1.2M     │
└────────────────────────────────┘
Epoch 1/20
────────────────────────────────────────
Batch 50/391 ( 12.8%) | loss: 2.1234 | accuracy: 0.2341
Batch 100/391 ( 25.6%) | loss: 1.8765 | accuracy: 0.3456
...
┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric     ┃ Train    ┃ Validation ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━┩
│ Loss       │ 1.2345   │ 1.3456     │
│ Accuracy   │ 0.6789   │ 0.6234     │
│ Epoch Time │ 2m 34.5s │ 2m 34.5s   │
│ Total Time │ 2m 34.5s │ 2m 34.5s   │
└────────────┴──────────┴────────────┘
```
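The progress and timing strings in the sample output can be produced with ordinary f-string formatting. The helpers below are a hypothetical sketch that reproduces the format shown (`Batch 50/391 ( 12.8%)`, `2m 34.5s`); they are not taken from Treadmill's source, which uses Rich for its terminal output.

```python
def format_batch_line(batch, total, metrics):
    """Format a progress line such as 'Batch 50/391 ( 12.8%) | loss: 2.1234'."""
    percent = 100.0 * batch / total
    parts = [f"Batch {batch}/{total} ({percent:5.1f}%)"]  # width 5 pads e.g. 12.8 -> ' 12.8'
    parts += [f"{name}: {value:.4f}" for name, value in metrics.items()]
    return " | ".join(parts)

def format_duration(seconds):
    """Format a duration as '34.5s', '2m 34.5s', or '1h 2m 5.0s'."""
    hours, rem = divmod(seconds, 3600)
    minutes, secs = divmod(rem, 60)
    if hours:
        return f"{int(hours)}h {int(minutes)}m {secs:.1f}s"
    if minutes:
        return f"{int(minutes)}m {secs:.1f}s"
    return f"{secs:.1f}s"

print(format_batch_line(50, 391, {"loss": 2.1234, "accuracy": 0.2341}))
# Batch 50/391 ( 12.8%) | loss: 2.1234 | accuracy: 0.2341
print(format_duration(154.5))  # 2m 34.5s
```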
Examples
Check out the examples/ directory for complete examples:
- basic_training.py: Simple CNN on CIFAR-10
- advanced_training.py: VAE with custom forward/backward functions

Run examples:
```bash
cd examples
python basic_training.py
python advanced_training.py
```
Contributing
Contributions are welcome! Please see the contributing guidelines for more details.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Acknowledgments
- Inspired by the need for clean, modular PyTorch training
- Built with ❤️ for the PyTorch community
- Uses Rich for beautiful terminal output

Happy Training with Treadmill!
Documentation will be available at: https://mayukhsobo.github.io/treadmill/
Download files
File details
Details for the file pytorch_treadmill-0.3.0.tar.gz.
File metadata
- Download URL: pytorch_treadmill-0.3.0.tar.gz
- Upload date:
- Size: 29.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 477d56f0271be0c7594558ea3f40b0378b2e21e34e6b286ccb6cd2a63dd29176 |
| MD5 | d4626ce5c7b988389a3c1153c371d362 |
| BLAKE2b-256 | b54bd14e0544e7d3959302ad167690c83fba4762bf4f135b26a2f324073ac9ce |
File details
Details for the file pytorch_treadmill-0.3.0-py3-none-any.whl.
File metadata
- Download URL: pytorch_treadmill-0.3.0-py3-none-any.whl
- Upload date:
- Size: 24.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 0e4b5c5586ce0fd2706cbccbebe517e7f933dc1fd2720804327e65342cf40888 |
| MD5 | d7b47c7553aa811348505014356f2513 |
| BLAKE2b-256 | 9c625ed985f1f203f5541929f39dc04f84fcb63722617caa2b6481cd862b179d |