Intuitive training framework for PyTorch

Project description

blowtorch

Intuitive, high-level training framework for research and development. Abstracts away boilerplate normally associated with training and evaluating PyTorch models, without limiting flexibility. Blowtorch provides the following:

  • A way to specify training runs at a high level, without giving up fine-grained control over individual parts of the training
  • Automated checkpointing, logging and resuming of runs
  • Sacred-inspired configuration management
  • Reproducibility by keeping track of configuration, code and random state of each run

Installation

Make sure you have numpy and torch installed, then install with pip:

pip install --upgrade blowtorch

Example

from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import vgg16
from blowtorch import Run

run = Run(random_seed=123)

@run.train_step
@run.validate_step
def step(batch, model):
    x, y = batch
    y_hat = model(x)
    # cross-entropy is the appropriate loss for integer class labels
    loss = F.cross_entropy(y_hat, y)
    return loss

# called once the model has been moved to the desired device
@run.configure_optimizers
def configure_optimizers(model):
    return Adam(model.parameters())

train_loader = DataLoader(CIFAR10('.', train=True, download=True, transform=ToTensor()),
                          batch_size=32, shuffle=True)
val_loader = DataLoader(CIFAR10('.', train=False, download=True, transform=ToTensor()),
                        batch_size=32)

run(vgg16(num_classes=10), train_loader, val_loader)

Configuration

You can pass multiple configuration files in YAML format to your Run, e.g.

run = Run(config_files=['config/default.yaml'])

Configuration values can then be accessed via e.g. run['model']['num_layers']. Dotted notation is also supported, e.g. run['model.num_layers']. When executing your training script, individual configuration values can be overridden from the command line as follows:

python train.py with model.num_layers=4 model.use_dropout=True
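
For illustration, a minimal sketch of both access styles (the model section and its values below are hypothetical):

# Assuming config/default.yaml contains:
#
#   model:
#     num_layers: 4
#     use_dropout: false
#
from blowtorch import Run

run = Run(config_files=['config/default.yaml'])
print(run['model']['num_layers'])  # nested access  -> 4
print(run['model.num_layers'])     # dotted access  -> 4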

Run options

Run.run() takes the following options (a usage sketch follows the list):

  • model: torch.nn.Module
  • train_loader: torch.utils.data.DataLoader
  • val_loader: torch.utils.data.DataLoader
  • loggers: Optional[List[blowtorch.loggers.BaseLogger]] (list of loggers that subscribe to various logging events, see the logging section)
  • max_epochs: int (default 1)
  • use_gpu: bool (default True)
  • gpu_id: int (default 0)
  • resume_checkpoint: Optional[Union[str, pathlib.Path]] (Path to checkpoint directory to resume training from, default None)
  • save_path: Union[str, pathlib.Path] (Path to directory that blowtorch will save logs and checkpoints to, default 'train_logs')
  • run_name: Optional[str] (name associated with the run, randomly generated if None, default None)
  • optimize_metric: Optional[str] (train metric used for optimization; the first returned metric is picked if None, default None)
  • checkpoint_metric: Optional[str] (validation metric used for checkpointing; the first returned metric is picked if None, default None)
  • smaller_is_better: bool (default True)
  • optimize_first: bool (whether optimization should occur during the first epoch, default False)
  • detect_anomalies: bool (enable autograd anomaly detection, default False)
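
A sketch of a configured call, continuing the example above (the checkpoint metric name is hypothetical and depends on the metrics your step functions return):

run(
    vgg16(num_classes=10),
    train_loader,
    val_loader,
    max_epochs=50,
    use_gpu=True,
    save_path='train_logs',
    checkpoint_metric='loss',  # hypothetical; must match a metric returned by your validate step
    smaller_is_better=True,
)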

Logging

Blowtorch will create a folder with the name "[timestamp]-[name]-[sequential integer]" for each run inside the run.save_path directory. There it will save the run's configuration, metrics, a model summary, checkpoints, as well as source code. Additional loggers can be added through Run's loggers parameter:

  • blowtorch.loggers.WandbLogger: Logs to Weights & Biases
  • blowtorch.loggers.TensorBoardLogger: Logs to TensorBoard

Custom loggers can be created by subclassing blowtorch.loggers.BaseLogger.
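
For example, a TensorBoard logger can be attached when starting the run (a sketch; constructor arguments are omitted here and may be required, check the logger signatures):

from blowtorch.loggers import TensorBoardLogger

run(
    vgg16(num_classes=10),
    train_loader,
    val_loader,
    loggers=[TensorBoardLogger()],  # assumes the default constructor suffices
)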

Decorators

Blowtorch uses decorator syntax to specify the parts of the training pipeline (see the sketch after this list):

  • @run.train_step, @run.validate_step: Specify the train/validation steps, either with a single shared function (stack both decorators, as in the example above) or with two separate functions. Arguments: batch, model, is_validate, device, epoch
  • @run.train_epoch, @run.val_epoch: Specify a whole train/validation epoch, in case more flexibility over iteration/optimization is required. Arguments: data_loader, model, is_validate, optimizers
  • @run.configure_optimizers: Return optimizers and learning rate schedulers. Can either return a single optimizer object or a dictionary with multiple optimizers/schedulers. Arguments: model
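
A sketch with two separate step functions returning named metrics (returning a dictionary of metrics is an assumption consistent with the optimize_metric/checkpoint_metric options above; F is torch.nn.functional, imported as in the example):

@run.train_step
def train_step(batch, model):
    x, y = batch
    return {'loss': F.cross_entropy(model(x), y)}

@run.validate_step
def validate_step(batch, model):
    x, y = batch
    y_hat = model(x)
    return {
        'loss': F.cross_entropy(y_hat, y),
        'accuracy': (y_hat.argmax(dim=1) == y).float().mean(),
    }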


Download files

Source Distribution

blowtorch-0.5.4.tar.gz (16.8 kB)

Built Distribution

blowtorch-0.5.4-py3-none-any.whl (18.4 kB)

File details

Details for the file blowtorch-0.5.4.tar.gz.

File metadata

  • Download URL: blowtorch-0.5.4.tar.gz
  • Upload date:
  • Size: 16.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.1 CPython/3.8.11

File hashes

Hashes for blowtorch-0.5.4.tar.gz:

  • SHA256: 194f05c56d02484b912242b1f4e9fdd373d7f122a1980d9462349f139d3dd896
  • MD5: c1d1d9f778bf70040b49be141aca5b08
  • BLAKE2b-256: 597bd713bdaf37a3eb4fd8c2553329e592ed9f5c217f1e9e7257da9d18d5f412

File details

Details for the file blowtorch-0.5.4-py3-none-any.whl.

File metadata

  • Download URL: blowtorch-0.5.4-py3-none-any.whl
  • Upload date:
  • Size: 18.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.1 CPython/3.8.11

File hashes

Hashes for blowtorch-0.5.4-py3-none-any.whl:

  • SHA256: f06daa127649ff37df9a73e0989ddbc4b45290e84ea3ec5e763feab0ae939546
  • MD5: e4af547c2a179f318effd5d7ab6513cb
  • BLAKE2b-256: 62536d412e31bbdd3864ff0c845dc7b35ac7998a5017fab5deb8ef29d947c367
