Intuitive training framework for PyTorch
blowtorch
Intuitive, high-level training framework for research and development. It abstracts away boilerplate normally associated with training and evaluating PyTorch models, without limiting your flexibility. Blowtorch provides the following:
- A way to specify training runs at a high level without giving up fine-grained control over individual parts of training
- Automated checkpointing, logging and resuming of runs
- Configuration management inspired by Sacred
- Reproducibility by keeping track of configuration, code and random state of each run
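To picture what tracking the random state involves, here is a hypothetical sketch that is not blowtorch's code. It seeds the generator at the start of a run, as `Run(random_seed=123)` suggests, and snapshots the generator state so that a resumed run continues the same random sequence. The helper names `seed_run`, `checkpoint_rng` and `resume_rng` are invented for illustration, and only Python's `random` module is covered. A real PyTorch setup would also have to seed and save the `numpy` and `torch` generators.

```python
import random


def seed_run(seed: int) -> dict:
    """Hypothetical helper: seed the global RNG and return a record of the seed."""
    random.seed(seed)
    return {"random_seed": seed}


def checkpoint_rng() -> tuple:
    """Snapshot the current RNG state, e.g. while writing a checkpoint."""
    return random.getstate()


def resume_rng(state: tuple) -> None:
    """Restore a saved RNG state so sampling continues exactly where it left off."""
    random.setstate(state)


if __name__ == "__main__":
    seed_run(123)
    state = checkpoint_rng()
    expected = [random.random() for _ in range(3)]

    resume_rng(state)  # simulate restarting from the checkpoint
    assert [random.random() for _ in range(3)] == expected
```

Storing the seed in the run record makes a fresh run repeatable. Saving the full generator state additionally makes a resumed run match an uninterrupted one.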
Installation
Make sure you have numpy and torch installed, then install with pip:
pip install --upgrade blowtorch
Minimal working example
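import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import vgg16
from torchvision.datasets import ImageNet
from blowtorch import Run

run = Run(random_seed=123)

# the same function is registered as both the training and the validation step
@run.train_step
@run.validate_step
def step(batch, model):
    x, y = batch
    y_hat = model(x)  # class logits, shape (batch_size, 1000)
    # y holds integer class labels, so use cross-entropy instead of a squared error
    loss = F.cross_entropy(y_hat, y)
    return loss

@run.configure_optimizers
def configure_optimizers(model):
    return Adam(model.parameters())

# setup data: ImageNet images differ in size, so resize and crop them to
# 224x224 tensors before batching. torchvision's ImageNet does not download
# the data; the dataset archives must already be in the root directory.
preprocess = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
train_loader = DataLoader(ImageNet('.', split='train', transform=preprocess), batch_size=4)
val_loader = DataLoader(ImageNet('.', split='val', transform=preprocess), batch_size=4)

run(vgg16(), train_loader, val_loader)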
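Stacking `@run.train_step` and `@run.validate_step` registers one function for two roles. The registration pattern itself can be shown with a small stand-in class. `MiniRun` is hypothetical, is not part of blowtorch, and performs no training. It only records each decorated function under a role name and returns the function unchanged, which is what allows the decorators to be stacked.

```python
from typing import Callable, Dict


class MiniRun:
    """Hypothetical stand-in that records decorated functions by role."""

    def __init__(self) -> None:
        self.hooks: Dict[str, Callable] = {}

    def _register(self, role: str, fn: Callable) -> Callable:
        self.hooks[role] = fn
        return fn  # returning fn unchanged lets decorators be stacked

    def train_step(self, fn: Callable) -> Callable:
        return self._register("train_step", fn)

    def validate_step(self, fn: Callable) -> Callable:
        return self._register("validate_step", fn)

    def configure_optimizers(self, fn: Callable) -> Callable:
        return self._register("configure_optimizers", fn)


run = MiniRun()


@run.train_step
@run.validate_step
def step(batch, model):
    x, y = batch
    return model(x) - y
```

After the module runs, `run.hooks["train_step"]` and `run.hooks["validate_step"]` refer to the same `step` function. A training loop could then look up each role and call it with the current batch and model.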
Configuration
You can pass multiple configuration files in YAML format to your Run, e.g.
run = Run(config_files=['config/default.yaml'])
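The text does not say how values from several files are combined. A common convention is that later files override earlier ones key by key, but this is an assumption and the blowtorch docs do not confirm it. The sketch below applies that convention as a recursive merge on plain dictionaries that stand in for parsed YAML. Reading the actual files would require a YAML parser such as PyYAML, which is left out to keep the example self-contained. `merge_configs` is a hypothetical helper.

```python
from copy import deepcopy
from typing import Any, Dict


def merge_configs(*configs: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge dicts; values from later configs win (assumed convention)."""
    result: Dict[str, Any] = {}
    for config in configs:
        _merge_into(result, config)
    return result


def _merge_into(target: Dict[str, Any], source: Dict[str, Any]) -> None:
    for key, value in source.items():
        if isinstance(value, dict) and isinstance(target.get(key), dict):
            _merge_into(target[key], value)  # descend into nested sections
        else:
            target[key] = deepcopy(value)  # copy so the inputs are never mutated
```

For example, merging `{"model": {"num_layers": 2, "use_dropout": False}}` with `{"model": {"num_layers": 4}}` keeps `use_dropout` from the first dict and takes `num_layers` from the second.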
Configuration values can then be accessed via, e.g., run['model']['num_layers']. When executing your training script, individual configuration values can be overridden as follows:
python train.py with model.num_layers=4 model.use_dropout=True
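The `with key=value` syntax follows Sacred. The sketch below shows one way such dotted overrides could become nested updates, but it is not blowtorch's or Sacred's parser, whose value-parsing rules may differ. Here `ast.literal_eval` interprets values such as `4` or `True`, and any value that is not a Python literal is kept as a string. `apply_overrides` and `parse_value` are hypothetical names.

```python
import ast
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    """Interpret '4' as int, 'True' as bool, etc.; fall back to the raw string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'a.b.c=value' overrides to config in place and return it."""
    for item in overrides:
        dotted_key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = dotted_key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # create missing sections on the way down
        node[leaf] = parse_value(raw_value)
    return config
```

Applying `["model.num_layers=4", "model.use_dropout=True"]` to `{"model": {"num_layers": 2, "use_dropout": False}}` sets `config['model']['num_layers']` to `4` and `use_dropout` to `True`. This matches the `run['model']['num_layers']` style of access described above.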
Download files
Source distribution: blowtorch-0.2.2.tar.gz (13.0 kB)
Built distribution: blowtorch-0.2.2-py3-none-any.whl (15.8 kB)
Hashes for blowtorch-0.2.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | de2e37c92e02fc00fa152d1cafb4400f78a096432cb651e8a1947ef0943a9b8b
MD5 | 05ddb3f273300fc6eee19867d0aef87d
BLAKE2b-256 | e3c210c9a810c2a679b502385dc5fe87b48dfc2f37ca8739033401ac438662a5