
Booster

A lightweight library to ease the training and debugging of deep neural networks with PyTorch, built around a few data structures and paradigms.

Data Structures

Diagnostic

A two-level dictionary structure for storing model diagnostics, compatible with the TensorBoard data structure.

Example:

from booster import Diagnostic

data = {
    'loss': {'nll': [45., 58.], 'kl': [22., 18.]},
    'info': {'batch_size': 16, 'runtime': 0.01},
}

diagnostic = Diagnostic(data)
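A two-level dictionary maps naturally onto TensorBoard's slash-separated tags. The sketch below uses a hypothetical helper, flatten_diagnostics (not part of booster), to show one way the {group: {key: value}} layout could be turned into tags such as loss/nll; scalar entries could then be passed to SummaryWriter.add_scalar(tag, value, global_step). How Diagnostic itself names its tags is an assumption here.

```python
from typing import Any, Dict


def flatten_diagnostics(data: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper (not part of booster): turn a two-level
    {group: {key: value}} dictionary into TensorBoard-style 'group/key' tags."""
    flat = {}
    for group, values in data.items():
        for key, value in values.items():
            flat[f"{group}/{key}"] = value
    return flat


data = {
    'loss': {'nll': [45., 58.], 'kl': [22., 18.]},
    'info': {'batch_size': 16, 'runtime': 0.01},
}

# Prints one "tag value" pair per line, e.g. "loss/nll [45.0, 58.0]".
for tag, value in flatten_diagnostics(data).items():
    print(tag, value)
```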

Aggregator

A module to compute the running average of the diagnostics.

from booster import Aggregator, Diagnostic

aggregator = Aggregator()
...
aggregator.initialize()
for x in data_loader:
    # optimization_step is your own training step; it returns the step's diagnostics
    diagnostics = optimization_step(model, x)
    aggregator.update(diagnostics)

summary = aggregator.data  # summary is an instance of Diagnostic
summary = summary.to('cpu')

The output is a Diagnostic object and can be logged directly to TensorBoard:

from torch.utils.tensorboard import SummaryWriter

# log to TensorBoard
writer = SummaryWriter(log_dir="...")
summary.log(writer, global_step)
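To make the idea of a running average over diagnostics concrete, here is a minimal sketch of an aggregator with the same initialize / update / data shape as the example above. RunningAverage is a hypothetical stand-in, not booster's Aggregator: it assumes diagnostics are plain two-level dictionaries of scalars and averages each key over the steps in which it appeared.

```python
from collections import defaultdict
from typing import DefaultDict, Dict

Diagnostics = Dict[str, Dict[str, float]]


class RunningAverage:
    """Hypothetical stand-in for booster's Aggregator: keeps a running mean
    of every scalar in a two-level {group: {key: value}} dictionary."""

    def __init__(self) -> None:
        self.initialize()

    def initialize(self) -> None:
        # Reset per-key sums and update counts (e.g. at the start of an epoch).
        self._sums: DefaultDict[str, DefaultDict[str, float]] = defaultdict(lambda: defaultdict(float))
        self._counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict(lambda: defaultdict(int))

    def update(self, diagnostics: Diagnostics) -> None:
        # Accumulate each scalar; keys missing from a step are not counted for it.
        for group, values in diagnostics.items():
            for key, value in values.items():
                self._sums[group][key] += float(value)
                self._counts[group][key] += 1

    @property
    def data(self) -> Diagnostics:
        # Unweighted mean over the updates seen since the last initialize().
        return {
            group: {key: total / self._counts[group][key] for key, total in values.items()}
            for group, values in self._sums.items()
        }


aggregator = RunningAverage()
for step in [{'loss': {'nll': 2.0}}, {'loss': {'nll': 4.0}}]:
    aggregator.update(step)
print(aggregator.data)  # {'loss': {'nll': 3.0}}
```

This sketch weights every step equally; if batches differ in size, weighting each update by its batch size would give the exact per-sample mean instead.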

Evaluator

The Evaluator performs a forward pass through the model and computes the loss along with additional diagnostics.

from booster.evaluation import Classification

model = Classifier()  # your own nn.Module
evaluator = Classification(categories=10)

# evaluate model
data = next(iter(loader))  # loader is your DataLoader
loss, diagnostics, output = evaluator(model, data)
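The evaluator contract above, a callable taking (model, data) and returning (loss, diagnostics, output), can be sketched for classification as follows. ClassificationEvaluator is hypothetical and only illustrates the pattern: it assumes data is an (inputs, labels) pair, uses cross-entropy as the loss, and ignores the categories argument of booster's Classification.

```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassificationEvaluator:
    """Hypothetical sketch (not booster's Classification): runs the model,
    computes the cross-entropy loss and returns (loss, diagnostics, output)."""

    def __call__(self, model: nn.Module, data: Tuple[torch.Tensor, torch.Tensor]):
        x, y = data
        logits = model(x)                  # forward pass
        loss = F.cross_entropy(logits, y)  # mean over the batch
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        diagnostics = {
            'loss': {'nll': loss.detach()},
            'info': {'accuracy': accuracy, 'batch_size': x.size(0)},
        }
        return loss, diagnostics, logits


torch.manual_seed(0)
model = nn.Linear(4, 10)  # stand-in for Classifier()
data = (torch.randn(8, 4), torch.randint(0, 10, (8,)))
loss, diagnostics, output = ClassificationEvaluator()(model, data)
loss.backward()  # the returned loss is differentiable
```

Detaching the loss inside the diagnostics keeps logged values from holding on to the autograd graph, while the returned loss remains usable for backpropagation.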

Pipeline: model + evaluator

The Pipeline fuses the model's forward pass with the evaluator. It can be wrapped in a custom DataParallel class that gathers the loss and diagnostics across devices.

from booster import Pipeline, DataParallelPipeline

# fuse model + evaluator
pipeline = Pipeline(model, evaluator)

# wrap as DataParallel (device_ids: list of GPU ids, e.g. [0, 1])
parallel_pipeline = DataParallelPipeline(pipeline, device_ids=device_ids)

# evaluate model on multiple devices and gather loss and diagnostics
data = next(iter(loader))
loss, diagnostics, output = parallel_pipeline(data) 

Download files

Source Distribution

booster-pytorch-0.0.2.tar.gz (14.2 kB)

File details

Details for the file booster-pytorch-0.0.2.tar.gz.

File metadata

  • Download URL: booster-pytorch-0.0.2.tar.gz
  • Upload date:
  • Size: 14.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.7.4

File hashes

Hashes for booster-pytorch-0.0.2.tar.gz

  • SHA256: b773b4afb42db509a8417157ee85c2af7f240ab09bec5e958a7ea69b601c1ff9
  • MD5: 8c42604367c2e01e9c43fa3e3029adad
  • BLAKE2b-256: 0d964725f89b367c3649dac69af0de4ee2d230cd695d85038cae02d9d2cc039a
