blowtorch
Intuitive, high-level training framework for research and development. Abstracts away boilerplate normally associated with training and evaluating PyTorch models, without limiting flexibility. Blowtorch provides the following:
- A way to specify training runs at a high level, while not giving up on fine-grained control over individual training parts
- Automated checkpointing, logging and resuming of runs
- A Sacred-inspired configuration management system
- Reproducibility by keeping track of configuration, code and random state of each run
Installation
Make sure you have numpy and torch installed, then install with pip:
pip install --upgrade blowtorch
Example
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import vgg16
from blowtorch import Run
run = Run(random_seed=123)
@run.train_step
@run.validate_step
def step(batch, model):
    x, y = batch
    y_hat = model(x)
    loss = (y - y_hat).square().mean()
    return loss
# will be called when model has been moved to the desired device
@run.configure_optimizers
def configure_optimizers(model):
    return Adam(model.parameters())
train_loader = DataLoader(CIFAR10('.', train=True, download=True, transform=ToTensor()))
val_loader = DataLoader(CIFAR10('.', train=False, download=True, transform=ToTensor()))
run(vgg16(num_classes=10), train_loader, val_loader)
Configuration
You can pass multiple configuration files in YAML format to your Run, e.g.
run = Run(config_files=['config/default.yaml'])
Configuration values can then be accessed via e.g. run['model']['num_layers']. Dotted notation is also supported, e.g. run['model.num_layers']. When executing your training script, individual configuration values can be overwritten as follows:
python train.py with model.num_layers=4 model.use_dropout=True
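For illustration, a hypothetical config/default.yaml (the keys below are made up for this example, not part of blowtorch) and the corresponding access in the training script could look like this:
# config/default.yaml (illustrative keys)
model:
  num_layers: 4
  use_dropout: True
# train.py
run = Run(config_files=['config/default.yaml'])
num_layers = run['model.num_layers']        # 4, unless overridden on the command line
use_dropout = run['model']['use_dropout']   # nested access works as well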
Run options
Run.run() takes the following options:
- model: torch.nn.Module
- train_loader: torch.utils.data.DataLoader
- val_loader: torch.utils.data.DataLoader
- loggers: Optional[List[blowtorch.loggers.BaseLogger]] (list of loggers that subscribe to various logging events, see the logging section)
- max_epochs: int (default 1)
- use_gpu: bool (default True)
- gpu_id: int (default 0)
- resume_checkpoint: Optional[Union[str, pathlib.Path]] (path to a checkpoint directory to resume training from, default None)
- save_path: Union[str, pathlib.Path] (path to the directory that blowtorch will save logs and checkpoints to, default 'train_logs')
- run_name: Optional[str] (name associated with the run, randomly created if None, default None)
- optimize_metric: Optional[str] (train metric used for optimization, picks the first returned one if None, default None)
- checkpoint_metric: Optional[str] (validation metric used for checkpointing, picks the first returned one if None, default None)
- smaller_is_better: bool (default True)
- optimize_first: bool (whether optimization should occur during the first epoch, default False)
- detect_anomalies: bool (enable autograd anomaly detection, default False)
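Putting a few of these options together, a call could look like the following sketch (the keyword-call form and the metric name 'loss' are assumptions based on the option list above, not documented behaviour):
run(
    vgg16(num_classes=10),
    train_loader,
    val_loader,
    max_epochs=50,                 # documented default is 1
    use_gpu=True,
    gpu_id=0,
    save_path='train_logs',
    run_name='vgg16-cifar10',      # randomly created if left as None
    checkpoint_metric='loss',      # assumed metric name; falls back to the first returned metric if None
    smaller_is_better=True,
    resume_checkpoint=None,        # or a path to a previous run's checkpoint directory
)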
Logging
Blowtorch will create a folder named "[timestamp]-[name]-[sequential integer]" for each run inside the run.save_path directory. There it saves the run's configuration, metrics, a model summary, checkpoints, and source code. Additional loggers can be added through Run's loggers parameter:
- blowtorch.loggers.WandbLogger: logs to Weights & Biases
- blowtorch.loggers.TensorBoardLogger: logs to TensorBoard
Custom loggers can be created by subclassing blowtorch.loggers.BaseLogger.
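For example, metrics can additionally be sent to TensorBoard or Weights & Biases (a sketch; the logger constructors are assumed to take no required arguments, which is not verified here):
from blowtorch.loggers import TensorBoardLogger

run(
    vgg16(num_classes=10),
    train_loader,
    val_loader,
    loggers=[TensorBoardLogger()],   # or e.g. a WandbLogger instance
)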
Decorators
Blowtorch uses the decorator syntax to specify parts of the training pipeline:
- @run.train_step, @run.val_step: specify the train/val steps with one or two functions (see the sketch after this list). Arguments: batch, model, is_validate, device, epoch
- @run.train_epoch, @run.val_epoch: specify a whole train/val epoch, in case more flexibility over iteration/optimization is required. Arguments: data_loader, model, is_validate, optimizers
- @run.configure_optimizers: return optimizers and learning rate schedulers. Can either return a single optimizer object or a dictionary with multiple optimizers/schedulers. Arguments: model
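For illustration, a single function can be registered for both phases and receive the additional arguments listed above (a sketch; the loss is taken from the earlier example, and the comments mark assumptions):
@run.train_step
@run.validate_step
def step(batch, model, is_validate, device, epoch):
    # `is_validate` distinguishes the train and validation phases, `device` is the
    # device the model was moved to, and `epoch` is the current epoch index.
    # Whether a step function may request only a subset of these arguments is an
    # assumption not verified here.
    x, y = batch
    x, y = x.to(device), y.to(device)   # assumption: batches may need to be moved manually
    y_hat = model(x)
    return (y - y_hat).square().mean()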