Lightweight PyTorch training utilities (Trainer, Callbacks, Tuner)
torchflow
A lightweight, dependency-minimal PyTorch training framework providing a clean Trainer API, a set of commonly-used training callbacks, and an Optuna-backed tuner helper for simple hyperparameter searches.
This repository and source package are available at
https://github.com/mobadara/torchflow. The package is published on PyPI under
the name torchflow-core; import it as torchflow in your code.
See the full changelog in CHANGELOG.md for release history.
Contents
- Features
- Installation
- Quick start
- Callbacks
- Tuner (Optuna)
- Examples
- Testing
- Contributing
- License
Features
- Simple, readable Trainer for training and validation loops
- Callback system with lifecycle hooks (on_train_begin, on_epoch_begin, on_batch_end, on_validation_end, on_epoch_end, on_train_end)
- Built-in callbacks: EarlyStopping, ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau, CSVLogger, TensorBoardCallback
- Safe, lazy imports for optional heavy dependencies (TensorBoard, Optuna)
- Small Optuna tuner helper that builds a new Trainer for each trial using a user-supplied build_fn(trial)
Installation
Install from PyPI (package name is torchflow-core):
pip install torchflow-core
Then import normally:
import torchflow
For development from source:
git clone https://github.com/mobadara/torchflow.git
cd torchflow
pip install -e .[dev]
Optional extras:
- TensorBoard logging: pip install tensorboard
- Hyperparameter tuning: pip install optuna
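To confirm which optional extras are present before running the TensorBoard or tuning examples, a quick standard-library check can help. This is a minimal sketch: the distribution name torchflow-core comes from this page, while the helper name report_environment is hypothetical and not part of torchflow.

```python
from importlib import metadata
from importlib.util import find_spec


def report_environment(extras=("tensorboard", "optuna")):
    """Return the installed torchflow-core version (or None) and extra availability.

    report_environment is a hypothetical helper, not part of torchflow.
    """
    try:
        # The PyPI distribution is torchflow-core, even though the import name is torchflow
        version = metadata.version("torchflow-core")
    except metadata.PackageNotFoundError:
        version = None
    # find_spec returns None for a missing top-level module without importing it
    available = {name: find_spec(name) is not None for name in extras}
    return version, available


if __name__ == "__main__":
    version, available = report_environment()
    print("torchflow-core:", version or "not installed")
    for name, ok in available.items():
        print(f"{name}: {'available' if ok else 'missing (optional)'}")
```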
Quick start
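Minimal training example using random synthetic data:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torchflow.trainer import Trainer
# Synthetic regression data: 10 input features, 1 target; first 200 rows train, last 56 validate
X, y = torch.randn(256, 10), torch.randn(256, 1)
train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)
model = nn.Sequential(nn.Linear(10, 1))
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
trainer = Trainer(model, criterion, optimizer, device='cpu')
trainer.train(train_loader, val_loader=val_loader, num_epochs=5)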
Callbacks
Callbacks are simple objects with lifecycle hooks that the Trainer calls at
key moments during training. They are passed to Trainer as a list and can
perform logging, checkpointing, learning-rate changes, early stopping, and
more.
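The hook names listed under Features suggest a simple dispatch pattern: the Trainer loops over its callbacks and calls the matching method at each stage. The sketch below illustrates that pattern in plain Python; Callback, MiniLoop, and EventRecorder are hypothetical stand-ins, and torchflow's actual base class, hook arguments, and call order may differ.

```python
class Callback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_train_begin(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_batch_end(self, batch, logs=None): pass
    def on_validation_end(self, epoch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_train_end(self, logs=None): pass


class MiniLoop:
    """Hypothetical loop that shows where each hook fires; no real training happens."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def _call(self, hook, *args):
        # Dispatch the same hook to every callback, in list order
        for cb in self.callbacks:
            getattr(cb, hook)(*args)

    def run(self, num_epochs, batches_per_epoch):
        self._call("on_train_begin", {})
        for epoch in range(num_epochs):
            self._call("on_epoch_begin", epoch, {})
            for batch in range(batches_per_epoch):
                self._call("on_batch_end", batch, {"loss": 0.0})
            self._call("on_validation_end", epoch, {"val_loss": 0.0})
            self._call("on_epoch_end", epoch, {})
        self._call("on_train_end", {})


class EventRecorder(Callback):
    """Example callback that records a subset of the hooks it receives."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


recorder = EventRecorder()
MiniLoop([recorder]).run(num_epochs=2, batches_per_epoch=3)
print(recorder.events)  # ['train_begin', 'epoch_end:0', 'epoch_end:1', 'train_end']
```

Overriding only the hooks a callback needs, with no-op defaults for the rest, keeps each callback small.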
Example with TensorBoard logging and early stopping:
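from torchflow.callbacks import TensorBoardCallback, EarlyStopping
from torchflow.trainer import Trainer
tb = TensorBoardCallback(log_dir='runs/myrun')  # uses a safe SummaryWriter factory
early = EarlyStopping()  # default settings assumed; check the class for options such as patience
# model, criterion, optimizer, and the loaders are defined as in the Quick start
trainer = Trainer(model, criterion, optimizer, callbacks=[tb, early])
trainer.train(train_loader, val_loader=val_loader, num_epochs=20)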
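Early stopping is usually driven by a patience counter on a monitored validation metric: training stops once the metric has failed to improve by more than min_delta for patience consecutive epochs. The following is a minimal sketch of that rule, assuming lower is better (as for a validation loss); PatienceStopper and its parameters are hypothetical, and torchflow's EarlyStopping may use different names and defaults.

```python
class PatienceStopper:
    """Hypothetical early-stopping rule for a metric where lower is better."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0  # epochs since the last improvement

    def update(self, value):
        """Record one epoch's metric; return True when training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = PatienceStopper(patience=2)
val_losses = [0.90, 0.75, 0.76, 0.74, 0.80, 0.81]
for epoch, loss in enumerate(val_losses):
    if stopper.update(loss):
        print(f"stopping after epoch {epoch}, best={stopper.best}")  # stopping after epoch 5, best=0.74
        break
```

Note that the counter resets on any improvement, so the brief rise at epoch 2 does not trigger a stop.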
The library exposes a few convenience callbacks out of the box:
- EarlyStopping
- ModelCheckpoint
- LearningRateScheduler
- ReduceLROnPlateau
- CSVLogger
- TensorBoardCallback
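A CSV logger typically appends one row of metrics per epoch and writes the header once. The sketch below shows that idea with Python's csv module; EpochCSVWriter is a hypothetical name, and the columns torchflow's CSVLogger actually writes are not documented here.

```python
import csv
import io


class EpochCSVWriter:
    """Hypothetical per-epoch metrics writer; columns are fixed by the first row."""

    def __init__(self, stream):
        self.stream = stream
        self.writer = None

    def log(self, epoch, logs):
        row = {"epoch": epoch, **logs}
        if self.writer is None:
            # Use the first row's keys as the header, written exactly once
            self.writer = csv.DictWriter(self.stream, fieldnames=list(row))
            self.writer.writeheader()
        self.writer.writerow(row)


buffer = io.StringIO()  # a real logger would open a file with newline=""
logger = EpochCSVWriter(buffer)
logger.log(0, {"loss": 0.5, "val_loss": 0.6})
logger.log(1, {"loss": 0.4, "val_loss": 0.55})
print(buffer.getvalue())
```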
Tuner (Optuna)
torchflow.tuner provides a small wrapper around Optuna. The contract is:
- build_fn(trial) should return a dict with at least model, optimizer, and criterion.
- Optional keys device, callbacks, writer, metrics, and mlflow_tracking may also be returned.
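Because tune() relies on the dictionary returned by build_fn(trial), it can help to check that dictionary before a study runs. This is a minimal sketch under assumptions: the key names come from the contract above, but validate_build_result is a hypothetical helper, and whether torchflow itself rejects unknown keys is not stated here.

```python
REQUIRED_KEYS = {"model", "optimizer", "criterion"}
OPTIONAL_KEYS = {"device", "callbacks", "writer", "metrics", "mlflow_tracking"}


def validate_build_result(result):
    """Check a build_fn(trial) result against the documented contract.

    Raises KeyError if a required key is missing; returns the set of keys that
    are neither required nor optional so the caller can warn about them.
    validate_build_result is hypothetical, not a torchflow function.
    """
    if not isinstance(result, dict):
        raise TypeError(f"build_fn must return a dict, got {type(result).__name__}")
    missing = REQUIRED_KEYS - set(result)
    if missing:
        raise KeyError(f"build_fn result is missing required keys: {sorted(missing)}")
    return set(result) - REQUIRED_KEYS - OPTIONAL_KEYS


# Placeholder objects stand in for a real model, optimizer, and loss
unknown = validate_build_result(
    {"model": object(), "optimizer": object(), "criterion": object(), "device": "cpu", "lr": 1e-3}
)
print(unknown)  # {'lr'}
```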
Example usage:
from torchflow.tuner import tune, example_build_fn
# `example_build_fn` is a tiny helper included for demonstration.
study = tune(example_build_fn, train_loader, val_loader, n_trials=10, num_epochs=3)
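Conceptually, each trial asks build_fn for fresh objects, trains them, and reports a validation score back to the search. The sketch below mimics that flow without Optuna, using plain random search as a stand-in for Optuna's sampler; RandomTrial, run_search, and the toy functions are hypothetical, and only the suggest_float method name mirrors a real Optuna Trial method.

```python
import math
import random


class RandomTrial:
    """Hypothetical stand-in for an Optuna trial that supports suggest_float only."""

    def __init__(self, rng):
        self.rng = rng
        self.params = {}

    def suggest_float(self, name, low, high, log=False):
        if log:
            # Sample uniformly in log space, as is common for learning rates
            value = math.exp(self.rng.uniform(math.log(low), math.log(high)))
        else:
            value = self.rng.uniform(low, high)
        self.params[name] = value
        return value


def run_search(build_fn, evaluate, n_trials, seed=0):
    """Rebuild fresh objects per trial and keep the lowest score (e.g. val loss)."""
    rng = random.Random(seed)
    best_score, best_params = math.inf, None
    for _ in range(n_trials):
        trial = RandomTrial(rng)
        built = build_fn(trial)  # new objects every trial, never reused
        score = evaluate(built)  # in torchflow this step would train a Trainer
        if score < best_score:
            best_score, best_params = score, dict(trial.params)
    return best_score, best_params


def toy_build_fn(trial):
    return {"lr": trial.suggest_float("lr", 1e-4, 1e-1, log=True)}


def toy_evaluate(built):
    # Pretend the validation loss is minimized at lr = 1e-2
    return abs(math.log10(built["lr"]) + 2)


score, params = run_search(toy_build_fn, toy_evaluate, n_trials=20)
print(score, params)
```

Rebuilding the model and optimizer inside every trial is what keeps one trial's trained weights or optimizer state from leaking into the next.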
The tuner imports Optuna lazily; importing torchflow.tuner does not require
Optuna to be installed. Calling tune() will raise a clear error if Optuna is
missing.
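Lazy importing of this kind follows a common pattern: defer the import to the function that needs it and turn an ImportError into an actionable message. A minimal sketch that assumes nothing about torchflow's internals; require_optional, make_study, and the message wording are hypothetical, while optuna.create_study is Optuna's real entry point.

```python
import importlib


def require_optional(module_name, install_hint):
    """Import an optional dependency on first use, or fail with a clear message."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name} is required for this feature but is not installed. "
            f"Install it with: {install_hint}"
        ) from exc


def make_study():
    # Importing this module never touches optuna; only calling this function does
    optuna = require_optional("optuna", "pip install optuna")
    return optuna.create_study(direction="minimize")
```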
Examples
Run the included example scripts in the examples/ directory:
python examples/simple_train.py
python examples/lr_and_logging.py
python examples/tensorboard_example.py
Note: examples/tensorboard_example.py will try to use TensorBoard; install it first with pip install tensorboard.
Testing
Tests use pytest and are located in the tests/ directory. Some tests are skipped
when optional dependencies (such as torch or tensorboard) are not installed.
Run tests locally:
pip install -e .[dev]
pytest -q
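Tests that depend on optional packages can skip themselves instead of failing. The sketch below uses unittest.skipUnless with importlib.util.find_spec so it runs with the standard library alone (pytest also collects unittest.TestCase classes); in a pytest-only suite, pytest.importorskip("tensorboard") achieves the same. The test names are hypothetical and not taken from tests/.

```python
import unittest
from importlib.util import find_spec

HAS_TENSORBOARD = find_spec("tensorboard") is not None


class OptionalDependencyTests(unittest.TestCase):
    @unittest.skipUnless(HAS_TENSORBOARD, "tensorboard not installed")
    def test_tensorboard_importable(self):
        import tensorboard  # only reached when the package is present
        self.assertTrue(hasattr(tensorboard, "__name__"))

    def test_core_logic_without_extras(self):
        # Tests that need no optional packages always run
        self.assertEqual(sum([1, 2, 3]), 6)


if __name__ == "__main__":
    unittest.main()
```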
Contributing
Contributions are welcome. See CONTRIBUTING.md for contribution guidelines,
the project's coding conventions, and testing instructions.
License
This project is released under the terms of the license in the LICENSE file.
By contributing you agree to license your changes under the same terms.
Maintainers
- Author: Muyiwa J. Obadara
- Repository: https://github.com/mobadara/torchflow
If you'd like to contact the maintainer, open an issue or mention the handle on Twitter: @m_obadara
Project & Contact
- GitHub: https://github.com/mobadara/torchflow
- Twitter: https://twitter.com/m_obadara
Download files
File details
Details for the file torchflow_core-0.1.0.tar.gz.
File metadata
- Download URL: torchflow_core-0.1.0.tar.gz
- Upload date:
- Size: 18.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0dcad9414267a21fa1d902165e8dcf959f04857d112a3280dd1b4d7a6f7922ca |
| MD5 | cf73b3c9df792e4d83336d507f355d98 |
| BLAKE2b-256 | f2dd9da9eb194d8292c60e1f43c212586580df187acf3ed752b899dbb5b55232 |
File details
Details for the file torchflow_core-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torchflow_core-0.1.0-py3-none-any.whl
- Upload date:
- Size: 17.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b0ac8308b749248e70235b311cd2401d3719880c7134a089feadc8e600b5b587 |
| MD5 | b76b9f8aaba3309968c2dac9b75ebe0c |
| BLAKE2b-256 | 90992fc14a84e556f6fb0ebb2e657e80195b5edbd505497a94425b65f581cabf |