Small utilities to simplify training of PyTorch models.
Trainer Tools
A lightweight, hook-based training loop for PyTorch. trainer-tools abstracts away training-loop boilerplate while remaining fully customizable through a flexible hook system.
Features
- Hook System: Customize every step of the training lifecycle (before/after batch, step, epoch, fit).
- Built-in Integrations: Ready-made hooks for wandb or trackio logging, progress bars, and checkpointing.
- Optimization: Easy Automatic Mixed Precision (AMP), Gradient Accumulation, and Gradient Clipping.
- Metrics: Robust metric tracking and logging to JSONL or external trackers.
- Memory Profiling: Built-in tools to debug CUDA memory leaks.
Installation
```bash
pip install trainer-tools

# With optional integrations
pip install trainer-tools[wandb]    # Weights & Biases logging
pip install trainer-tools[trackio]  # Trackio logging
pip install trainer-tools[hydra]    # Hydra config management
pip install trainer-tools[all]      # All optional dependencies
```
Quick Start
Here is a minimal example of training a simple model:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from trainer_tools.trainer import Trainer
from trainer_tools.hooks import MetricsHook, Accuracy, Loss, ProgressBarHook

# 1. Prepare data
x = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
ds = TensorDataset(x, y)
dl = DataLoader(ds, batch_size=32)

# 2. Define model
model = nn.Sequential(nn.Linear(10, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 3. Set up hooks
metrics = MetricsHook(metrics=[Accuracy(), Loss()])
pbar = ProgressBarHook()

# 4. Train
trainer = Trainer(
    model=model,
    train_dl=dl,
    valid_dl=dl,
    optim=optimizer,
    loss_func=nn.CrossEntropyLoss(),
    epochs=5,
    hooks=[metrics, pbar],
    device="cuda" if torch.cuda.is_available() else "cpu",
)
trainer.fit()
```
How the Trainer Works
Training Loop
Trainer.fit() runs a standard PyTorch training loop but exposes hook points at every meaningful stage so you can inject logic without touching the core loop:
```
fit()
├── before_fit
├── for each epoch:
│   ├── before_epoch
│   ├── for each training batch:
│   │   ├── before_step
│   │   ├── predict() → trainer.preds
│   │   ├── after_pred
│   │   ├── get_loss() → trainer.loss / trainer.loss_t   # float loss / tensor loss
│   │   ├── after_loss
│   │   ├── loss_t.backward()
│   │   ├── after_backward
│   │   ├── opt.step() / opt.zero_grad()
│   │   └── after_step
│   ├── before_valid
│   ├── for each validation batch:
│   │   └── (same as above, without the backward/optimizer steps)
│   └── after_epoch
└── after_fit (or after_cancel on KeyboardInterrupt)
```
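Each label above is a hook point: any hook that defines a method with that name is called there with the live trainer. As a minimal sketch (assuming, consistent with the examples below, that hook methods take the trainer as their only argument), a hypothetical hook that times each epoch would look like:

```python
import time

from trainer_tools.hooks import BaseHook

class EpochTimerHook(BaseHook):
    """Hypothetical example: report wall-clock time per epoch."""

    def before_epoch(self, trainer):
        self._start = time.perf_counter()

    def after_epoch(self, trainer):
        elapsed = time.perf_counter() - self._start
        print(f"Epoch {trainer.epoch} finished in {elapsed:.1f}s")
```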
Key Trainer Attributes
At any hook point you have access to the live trainer state:
| Attribute | Description |
|---|---|
| trainer.model | The nn.Module being trained |
| trainer.opt | The optimizer |
| trainer.epoch / trainer.step | Current epoch / global step count |
| trainer.batch | The raw batch from the dataloader |
| trainer.preds | Model predictions (set after predict()) |
| trainer.loss / trainer.loss_t | Scalar loss value / loss tensor |
| trainer.training | True during the train phase, False during validation |
| trainer.dl | The DataLoader currently in use |
| trainer.config | Optional config object (e.g. Hydra DictConfig) |
Three boolean flags let a hook short-circuit the default behavior for a single step:
| Flag | Effect when set to True |
|---|---|
| trainer.skip_backward | Skips loss.backward() |
| trainer.skip_opt_step | Skips opt.step() |
| trainer.skip_zero_grad | Skips opt.zero_grad() |
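For illustration only (the built-in GradientAccumulationHook described below is the real implementation, and it also rescales the loss), a bare-bones accumulation scheme could be sketched with these flags:

```python
from trainer_tools.hooks import BaseHook

class NaiveAccumulation(BaseHook):
    """Hypothetical sketch: only run the optimizer every n-th training batch.
    Assumes the skip flags apply to the current step only and that
    trainer.step counts up by one per batch."""

    def __init__(self, n=4):
        self.n = n

    def after_loss(self, trainer):
        if trainer.training and trainer.step % self.n != self.n - 1:
            trainer.skip_opt_step = True   # keep accumulating gradients
            trainer.skip_zero_grad = True  # and do not clear them yet
```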
Hook Execution Order
Hooks are sorted by their ord attribute before each call. Lower values run first. The default is 0. This guarantees correct ordering when hooks depend on one another (e.g., CheckpointHook runs before LRSchedulerHook so a restored scheduler state is used from the beginning).
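A quick sketch, assuming ord can simply be declared on the subclass:

```python
from trainer_tools.hooks import BaseHook

class RunsFirst(BaseHook):
    ord = -10  # lower than the default 0, so this runs before other hooks

class RunsLast(BaseHook):
    ord = 10   # higher than the default 0, so this runs after other hooks
```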
The Hook System
Every hook in trainer-tools derives from BaseHook. You create custom behavior by subclassing it and overriding the lifecycle methods you need:
```python
from trainer_tools.hooks import BaseHook

class MyCustomHook(BaseHook):
    def after_step(self, trainer):
        if trainer.step % 100 == 0:
            print(f"Current Loss: {trainer.loss}")
```
Available Hooks
ProgressBarHook
Displays tqdm progress bars for epochs and batches. Shows a running training loss that updates every freq steps.
```python
from trainer_tools.hooks import ProgressBarHook

pbar = ProgressBarHook(freq=10)
```
MetricsHook
Central hub for computing, aggregating, and logging metrics. Supports logging to the console, a JSONL history file, Weights & Biases, or Trackio.
```python
from trainer_tools.hooks import MetricsHook, Loss
from trainer_tools.hooks.metrics import Accuracy

metrics = MetricsHook(
    metrics=[Loss(), Accuracy()],
    tracker_type="wandb",  # or "trackio" / None
    project="my-project",
    name="run_1",
)
```
Metrics are split into phases ("step" or "epoch") and automatically prefixed with train_ / valid_.
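With the configuration above, for example, the Loss metric would appear as train_loss during the training phase and valid_loss during validation; the exact key spelling is inferred from this prefixing rule.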
LRSchedulerHook
Wraps any PyTorch LRScheduler and calls sched.step() after every optimizer update.
```python
from trainer_tools.hooks import LRSchedulerHook

hook = LRSchedulerHook(
    sched_fn=lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=1000)
)
```
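Because the scheduler is stepped after every optimizer update rather than once per epoch, horizon arguments such as T_max above should be given in optimizer steps.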
AMPHook
Enables Automatic Mixed Precision training by wrapping the forward pass in torch.autocast and managing a GradScaler. Supports both float16 and bfloat16.
```python
from trainer_tools.hooks import AMPHook

amp = AMPHook(dtype=torch.bfloat16, device_type="cuda")
```
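For context: gradient scaling mainly matters for float16, which underflows easily; bfloat16 shares float32's exponent range, so scaling is rarely needed in that mode.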
GradClipHook
Clips gradient norms before each optimizer step to stabilize training.
```python
from trainer_tools.hooks import GradClipHook

clip = GradClipHook(max_norm=1.0)
```
GradientAccumulationHook
Accumulates gradients over multiple micro-batches before calling opt.step(), effectively increasing the batch size without extra GPU memory.
```python
from trainer_tools.hooks import GradientAccumulationHook

accum = GradientAccumulationHook(accumulation_steps=4)
```
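Combined with the Quick Start's DataLoader (batch_size=32), accumulation_steps=4 yields an effective batch size of 32 × 4 = 128.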
CheckpointHook
Saves and restores model, optimizer, scheduler, scaler, and RNG states. Supports both a "best" (by a tracked metric) and "latest" save strategy, and keeps a configurable number of recent checkpoints.
```python
from trainer_tools.hooks import CheckpointHook

ckpt = CheckpointHook(
    save_dir="checkpoints/",
    save_every_steps=500,
    keep_last=3,
    save_strategy="best",  # or "latest"
    metric_name="valid_loss",
    resume_path="checkpoints/step_1000",  # optional
)
```
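Note that save_strategy="best" ranks checkpoints by metric_name, so a MetricsHook producing valid_loss (e.g. one that includes Loss()) presumably needs to be registered for the example above to work.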
EMAHook
Maintains an exponential moving average (EMA) of model weights. Validation is automatically run against the EMA model, which often gives better generalization. The EMA state is saved and restored with CheckpointHook.
```python
from trainer_tools.hooks import EMAHook

ema = EMAHook(decay=0.9999)
```
BatchTransformHook
Applies on-GPU data augmentations or pre-processing transforms to inputs and/or targets at the start of each batch. Separate transforms can be provided for training and validation.
```python
from trainer_tools.hooks import BatchTransformHook
import torchvision.transforms.v2 as T

aug = BatchTransformHook(
    x_tfm=T.RandomHorizontalFlip(),
    x_tfms_valid=None,  # no augmentation during validation
)
```
AccelerateHook
Integrates HuggingFace Accelerate for distributed training (DDP/FSDP), mixed precision, and gradient accumulation in a single hook. When used, do not add AMPHook, GradClipHook, or GradientAccumulationHook — Accelerate handles all of that.
```python
from trainer_tools.hooks.accelerate import AccelerateHook

accel = AccelerateHook(
    gradient_accumulation_steps=4,
    max_grad_norm=1.0,
    mixed_precision="bf16",
)
```
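For multi-GPU runs the script is then launched through Accelerate's standard CLI, e.g. accelerate launch train.py (train.py being a hypothetical script containing the Trainer setup); single-process runs need no special launcher.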