Minimal PyTorch training loop with hooks and checkpointing.
trainloop
Minimal PyTorch training loop with hooks for logging, checkpointing, and customization.
Docs: https://karimknaebel.github.io/trainloop/
Install
```bash
pip install trainloop
```
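After installing, you can confirm which trainloop release ended up in the environment by querying the package metadata with the standard library. This is a minimal sketch. `installed_version` is a hypothetical helper name and is not part of trainloop.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="trainloop"):
    """Return the installed version string of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints a version string such as "0.5.2", or None if the install did not succeed
print(installed_version())
```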
Basic example
```python
import logging

import torch
import torch.nn as nn

from trainloop import BaseTrainer, CheckpointingHook, ProgressHook

logging.basicConfig(level=logging.INFO)


class MyTrainer(BaseTrainer):
    def build_data_loader(self):
        # Toy dataset: an endless stream of random 784-dim inputs with random labels in [0, 10)
        class ToyDataset(torch.utils.data.IterableDataset):
            def __iter__(self):
                while True:
                    data = torch.randn(784)
                    target = torch.randint(0, 10, (1,)).item()
                    yield data, target

        return torch.utils.data.DataLoader(ToyDataset(), batch_size=32)

    def build_model(self):
        # Small MLP classifier: 784 inputs -> 128 hidden units -> 10 logits
        return nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ).to(self.device)

    def build_optimizer(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-4)

    def build_hooks(self):
        return [
            ProgressHook(interval=50, with_records=True),
            CheckpointingHook(interval=500, keep_previous=2),
        ]

    def forward(self, batch):
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)
        accuracy = (logits.argmax(1) == y).float().mean().item()
        # Return the loss plus a dict of extra metrics (here: batch accuracy)
        return loss, {"accuracy": accuracy}


trainer = MyTrainer(max_steps=2000, device="cpu", workspace="runs/demo")
trainer.train()
```
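To make the role of interval-based hooks more concrete, here is a framework-free sketch of a loop that calls hooks after every step. It is not trainloop's implementation. The class names (`IntervalLogger`, `RotatingCheckpointer`), the `on_step_end` callback, and the "keep the newest N" rotation policy are all hypothetical. A "checkpoint" here is just a recorded step number rather than saved model state.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Protocol, Tuple


class Hook(Protocol):
    """Anything with an on_step_end callback can act as a hook (hypothetical interface)."""

    def on_step_end(self, step: int, loss: float, records: Dict[str, float]) -> None: ...


@dataclass
class IntervalLogger:
    """Hypothetical progress hook: formats loss and metrics every `interval` steps."""

    interval: int
    lines: List[str] = field(default_factory=list)

    def on_step_end(self, step: int, loss: float, records: Dict[str, float]) -> None:
        if step % self.interval == 0:
            self.lines.append(f"step {step}: loss={loss:.4f} {records}")


@dataclass
class RotatingCheckpointer:
    """Hypothetical checkpoint hook: 'saves' every `interval` steps and keeps
    only the `keep` most recent checkpoints, discarding the oldest."""

    interval: int
    keep: int
    saved: List[int] = field(default_factory=list)

    def on_step_end(self, step: int, loss: float, records: Dict[str, float]) -> None:
        if step % self.interval == 0:
            # A real hook would write model/optimizer state to disk here.
            self.saved.append(step)
            while len(self.saved) > self.keep:
                self.saved.pop(0)  # a real hook would also delete the oldest file


def train(
    step_fn: Callable[[int], Tuple[float, Dict[str, float]]],
    max_steps: int,
    hooks: List[Hook],
) -> None:
    """Run step_fn for steps 1..max_steps, calling every hook after each step."""
    for step in range(1, max_steps + 1):
        loss, records = step_fn(step)
        for hook in hooks:
            hook.on_step_end(step, loss, records)


logger = IntervalLogger(interval=50)
ckpt = RotatingCheckpointer(interval=500, keep=2)
train(lambda step: (1.0 / step, {"accuracy": 0.5}), max_steps=2000, hooks=[logger, ckpt])
print(len(logger.lines), ckpt.saved)  # 40 [1500, 2000]
```

Keeping the loop ignorant of what each hook does is what lets logging and checkpointing be added or removed without touching the training step itself.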
Download files
Source distribution: trainloop-0.5.2.tar.gz (13.3 kB)
Built distribution: trainloop-0.5.2-py3-none-any.whl (15.2 kB)
File details
Details for the file trainloop-0.5.2.tar.gz.
File metadata
- Download URL: trainloop-0.5.2.tar.gz
- Size: 13.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.11.14 (uv publish, running in CI on Ubuntu 24.04)
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 317e3824c11938253b7343e28ca3a8b9616beee0287d95c9d4aa3cbc6c2dcd00 |
| MD5 | 00ab22b94f8b391243c0049ad7821ad6 |
| BLAKE2b-256 | 755d438e4be1f7d29a690b2f98781a3b2bdbf7ff2e83dd820ba05502e4466b74 |
File details
Details for the file trainloop-0.5.2-py3-none-any.whl.
File metadata
- Download URL: trainloop-0.5.2-py3-none-any.whl
- Size: 15.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.11.14 (uv publish, running in CI on Ubuntu 24.04)
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 3b49d7dc61d3b21a9655a466b4b363140406c6a2f0397d50294be221a4072191 |
| MD5 | 0a4ab816bad1f781bda1e9128a37ba75 |
| BLAKE2b-256 | bb84025446ee8003eb167eed0bc41ef1bf3da59ec57c43d1e4f13ac98eeca3ce |
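To check a downloaded file against the published digests, a short standard-library sketch like the following could be used. It is an illustration only, not part of trainloop or of PyPI tooling. The helper names `file_digests` and `verify_sha256` are hypothetical, and only the two SHA256 digests from the listings above are copied into the lookup table. pip can also enforce hashes itself through its hash-checking mode.

```python
import hashlib
from pathlib import Path

# SHA256 digests published for the trainloop 0.5.2 release files (copied from the listings above)
EXPECTED_SHA256 = {
    "trainloop-0.5.2.tar.gz": "317e3824c11938253b7343e28ca3a8b9616beee0287d95c9d4aa3cbc6c2dcd00",
    "trainloop-0.5.2-py3-none-any.whl": "3b49d7dc61d3b21a9655a466b4b363140406c6a2f0397d50294be221a4072191",
}


def file_digests(path, chunk_size=1 << 16):
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file in a single read pass."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),  # 32-byte (256-bit) BLAKE2b
    }
    with open(path, "rb") as f:
        # Read in chunks so large files never have to fit in memory at once
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify_sha256(path):
    """Return True if the file's SHA256 matches the published digest for its file name."""
    name = Path(path).name
    expected = EXPECTED_SHA256.get(name)
    if expected is None:
        raise KeyError(f"no published digest for {name}")
    return file_digests(path)["SHA256"] == expected
```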