Skip to main content

A lightweight library to ease the training and the debugging of deep neural networks with PyTorch. Data structures and paradigms.

Project description

Booster

A lightweight library to ease the training and the debugging of deep neural networks with PyTorch. Data structures and paradigms.

Data Structures

Diagnostic

A two level dictionary structure to store the model diagnostics. Compatible with Tensorboard datastructure.

Example:

from booster import Diagnostic

data = {
'loss' : {'nll' : [45., 58.], 'kl': [22., 18.]},
'info' : {'batch_size' : 16, 'runtime' : 0.01}
}

diagnostic = Diagnostic(data)

Aggregator

A module to compute the running average of the diagnostics.

from booster import Aggregator, Diagnostic

aggregator = Aggregator()
...
aggregator.initialize()
for x in data_loader:
  data = optimization_step(model, data)
  aggregator.update(data)

summmary = aggregator.data # summary is an instance of Diagnostic
summmary = summary.to('cpu')

The output is a Diagnostic object and can easily be logged to Tensorboard.

# log to tensorboard
writer = SummaryWriter(log_dir="...")
summary.log(writer, global_step)

Evaluator

The Evaluator computes a forward pass through the model, the loss and additional Diagnostics.

from booster.evaluation import Classification
model = Classifier()
evaluator = Classification(categories=10)

# evaluate model
data = next(iter(loader))
loss, diagnostics, output = evaluator(model, data)

Pipeline: model + evaluator

The pipeline fuses the model forward pass with the evaluator and can be wrapped into a custom Dataparallel class that handles the diagnostics.

from booster import Pipeline, DataParallelPipeline

# fuse model + evaluator
pipeline = Pipeline(model, evaluator)

# wrap as DataParallel
parallel_pipeline = DataParallelPipeline(pipeline, device_ids=device_ids)

# evaluate model on multiple devices and gather loss and diagnostics
data = next(iter(loader))
loss, diagnostics, output = parallel_pipeline(data) 

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

booster-pytorch-0.0.2.tar.gz (14.2 kB view hashes)

Uploaded Source

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page