blowtorch
Intuitive, high-level training framework for research and development. It abstracts away the boilerplate normally associated with training and evaluating PyTorch models without limiting flexibility. Blowtorch provides the following:
- A way to specify training runs at a high level, while not giving up on fine-grained control over individual training parts
- Automated checkpointing, logging and resuming of runs
- Sacred-inspired configuration management
- Reproducibility by keeping track of configuration, code and random state of each run
Installation
Make sure you have numpy and torch installed, then install with pip:
pip install --upgrade blowtorch
Example
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import vgg16
from blowtorch import Run
run = Run(random_seed=123)

# Use the same function for both the train and the validation step.
@run.train_step
@run.validate_step
def step(batch, model):
    x, y = batch
    y_hat = model(x)
    # MSE loss for illustration; a classification task would normally use cross-entropy.
    loss = (y - y_hat).square().mean()
    return loss

# Will be called once the model has been moved to the desired device.
@run.configure_optimizers
def configure_optimizers(model):
    return Adam(model.parameters())

train_loader = DataLoader(CIFAR10('.', train=True, download=True, transform=ToTensor()))
val_loader = DataLoader(CIFAR10('.', train=False, download=True, transform=ToTensor()))

run(vgg16(num_classes=10), train_loader, val_loader)
Configuration
You can pass multiple configuration files in YAML format to your Run, e.g.

run = Run(config_files=['config/default.yaml'])

Configuration values can then be accessed via e.g. run['model']['num_layers']. Dotted notation is also supported, e.g. run['model.num_layers']. When executing your training script, individual configuration values can be overwritten as follows:
python train.py with model.num_layers=4 model.use_dropout=True
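
For instance, with a config file like the following sketch (the keys and values are illustrative, not part of blowtorch's API), both access styles line up, and the command-line override above would replace the stored value:

# Contents of config/default.yaml (illustrative):
#
#   model:
#     num_layers: 2
#     use_dropout: false

run = Run(config_files=['config/default.yaml'])
print(run['model']['num_layers'])  # nested access -> 2
print(run['model.num_layers'])     # dotted access -> 2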
Run options
Run.run() takes the following options (a usage sketch follows the list):

- model: torch.nn.Module
- train_loader: torch.utils.data.DataLoader
- val_loader: torch.utils.data.DataLoader
- loggers: Optional[List[blowtorch.loggers.BaseLogger]] (list of loggers that subscribe to various logging events, see the logging section)
- max_epochs: int (default 1)
- use_gpu: bool (default True)
- gpu_id: int (default 0)
- resume_checkpoint: Optional[Union[str, pathlib.Path]] (path to a checkpoint directory to resume training from, default None)
- save_path: Union[str, pathlib.Path] (path to the directory that blowtorch will save logs and checkpoints to, default 'train_logs')
- run_name: Optional[str] (name associated with the run, randomly created if None, default None)
- optimize_metric: Optional[str] (train metric that will be used for optimization; picks the first returned one if None, default None)
- checkpoint_metric: Optional[str] (validation metric that will be used for checkpointing; picks the first returned one if None, default None)
- smaller_is_better: bool (default True)
- optimize_first: bool (whether optimization should occur during the first epoch, default False)
- detect_anomalies: bool (enable autograd anomaly detection, default False)
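
A call using several of these options might look like the following sketch. The values are illustrative, and 'loss' as a metric name is a hypothetical placeholder that depends on what your step functions return:

run(
    vgg16(num_classes=10),
    train_loader,
    val_loader,
    max_epochs=50,              # instead of the default 1
    use_gpu=True,
    gpu_id=0,
    save_path='train_logs',
    run_name='vgg16-cifar10',
    checkpoint_metric='loss',   # hypothetical metric name
    smaller_is_better=True,
)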
Logging
Blowtorch will create a folder named "[timestamp]-[name]-[sequential integer]" for each run inside the run.save_path directory. There it saves the run's configuration, metrics, a model summary, checkpoints, as well as source code. Additional loggers can be added through Run's loggers parameter:
- blowtorch.loggers.WandbLogger: logs to Weights & Biases
- blowtorch.loggers.TensorBoardLogger: logs to TensorBoard
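
For example, to attach one of these loggers to a run (constructor arguments are omitted here as an assumption; check each class's signature for required options such as a project name or log directory):

from blowtorch.loggers import TensorBoardLogger

run(
    vgg16(num_classes=10),
    train_loader,
    val_loader,
    loggers=[TensorBoardLogger()],  # constructor args omitted; see the class signature
)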
Custom loggers can be created by subclassing blowtorch.loggers.BaseLogger.
Decorators
Blowtorch uses the decorator syntax to specify parts of the training pipeline:
- @run.train_step, @run.val_step: Specify train/val steps with one or two functions. Arguments: batch, model, is_validate, device, epoch
- @run.train_epoch, @run.val_epoch: Specify a whole train/val epoch, in case more flexibility for iteration/optimization is required (see the sketch after this list). Arguments: data_loader, model, is_validate, optimizers
- @run.configure_optimizers: Return optimizers and learning rate schedulers. Can either return a single optimizer object or a dictionary with multiple optimizers/schedulers. Arguments: model
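
As an illustration of the epoch-level hooks, here is a minimal sketch of a custom train epoch. It assumes that optimizers arrives as whatever configure_optimizers returned (a single Adam in the example above) and that batches are iterated and optimized manually; check blowtorch's source for the exact contract:

@run.train_epoch
def train_epoch(data_loader, model, optimizers):
    optimizer = optimizers  # a single optimizer was returned by configure_optimizers
    total_loss = 0.0
    for x, y in data_loader:
        loss = (model(x) - y).square().mean()  # same illustrative loss as above
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)  # aggregate metric, mirroring the step API (assumption)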