A lightweight library to ease the training and debugging of deep neural networks with PyTorch: data structures and paradigms.
Booster
Data Structures
Diagnostic
A two-level dictionary structure for storing model diagnostics, compatible with the TensorBoard data structure.
Example:
from booster import Diagnostic
data = {
    'loss': {'nll': [45., 58.], 'kl': [22., 18.]},
    'info': {'batch_size': 16, 'runtime': 0.01},
}
diagnostic = Diagnostic(data)
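The layout is described as compatible with TensorBoard. One plausible reading, assumed here, is that each (group, metric) pair maps to a slash-separated TensorBoard tag such as loss/nll. The hypothetical helper flatten_diagnostics below (not part of booster) shows that mapping on the example dictionary:

```python
from typing import Any, Dict


def flatten_diagnostics(data: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Flatten {group: {key: value}} into {"group/key": value}.

    Hypothetical helper: illustrates the two-level layout, not booster's API.
    """
    flat = {}
    for group, metrics in data.items():
        for key, value in metrics.items():
            flat[f"{group}/{key}"] = value
    return flat


data = {
    'loss': {'nll': [45., 58.], 'kl': [22., 18.]},
    'info': {'batch_size': 16, 'runtime': 0.01},
}
print(flatten_diagnostics(data))
```

The outer keys act as groups and the inner keys as metric names, so the structure never needs more than two lookups to reach a value.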
Aggregator
A module to compute the running average of the diagnostics.
from booster import Aggregator, Diagnostic
aggregator = Aggregator()
...
aggregator.initialize()
for x in data_loader:
    # optimization_step is user-defined and returns the step's diagnostics
    data = optimization_step(model, x)
    aggregator.update(data)
summary = aggregator.data  # summary is an instance of Diagnostic
summary = summary.to('cpu')
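The running average behind the Aggregator can be sketched in plain Python. RunningAverage below is a hypothetical stand-in that only mirrors the initialize/update/data names used above; it assumes one scalar per metric per step and is not booster's implementation.

```python
from collections import defaultdict
from typing import Dict


class RunningAverage:
    """Hypothetical aggregator: averages scalar diagnostics of the form
    {group: {key: float}} over all calls to update()."""

    def __init__(self) -> None:
        self.initialize()

    def initialize(self) -> None:
        # Reset the running sums and counts, e.g. at the start of an epoch.
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)

    def update(self, data: Dict[str, Dict[str, float]]) -> None:
        # Accumulate each (group, key) scalar separately.
        for group, metrics in data.items():
            for key, value in metrics.items():
                self._sums[(group, key)] += float(value)
                self._counts[(group, key)] += 1

    @property
    def data(self) -> Dict[str, Dict[str, float]]:
        # Mean of each metric, re-nested as {group: {key: mean}}.
        summary: Dict[str, Dict[str, float]] = defaultdict(dict)
        for (group, key), total in self._sums.items():
            summary[group][key] = total / self._counts[(group, key)]
        return dict(summary)


aggregator = RunningAverage()
for step_diagnostics in [{'loss': {'nll': 4.0}}, {'loss': {'nll': 2.0}}]:
    aggregator.update(step_diagnostics)
print(aggregator.data)  # {'loss': {'nll': 3.0}}
```

This takes an unweighted mean over steps; when batches differ in size (for example a smaller final batch), weighting each update by its batch size would be a natural refinement.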
The summary returned by aggregator.data is a Diagnostic object and can be logged directly to TensorBoard.
# log to TensorBoard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="...")
summary.log(writer, global_step)  # global_step: the current training step
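SummaryWriter.add_scalar(tag, scalar_value, global_step) is PyTorch's call for logging a single scalar. The sketch below shows what a log-style method might do with a two-level summary, assuming one scalar per metric and slash-separated tags. log_summary and PrintWriter are hypothetical; the writer is duck-typed so the sketch runs without TensorBoard installed.

```python
from typing import Dict, Protocol


class ScalarWriter(Protocol):
    # Matches the add_scalar(tag, scalar_value, global_step) call of
    # torch.utils.tensorboard.SummaryWriter.
    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: ...


def log_summary(summary: Dict[str, Dict[str, float]],
                writer: ScalarWriter, global_step: int) -> None:
    """Hypothetical helper: write each {group: {key: value}} entry as scalar 'group/key'."""
    for group, metrics in summary.items():
        for key, value in metrics.items():
            writer.add_scalar(f"{group}/{key}", float(value), global_step)


class PrintWriter:
    """Stand-in writer that prints instead of writing event files."""

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        print(f"step {global_step}: {tag} = {scalar_value}")


log_summary({'loss': {'nll': 3.0, 'kl': 2.0}}, PrintWriter(), global_step=1)
```

With PyTorch and TensorBoard installed, a torch.utils.tensorboard.SummaryWriter can be passed instead of PrintWriter.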
Evaluator
The Evaluator runs a forward pass through the model and computes the loss and additional diagnostics.
from booster.evaluation import Classification
model = Classifier()  # your own classification model (an nn.Module)
evaluator = Classification(categories=10)
# evaluate the model on one batch
data = next(iter(loader))
loss, diagnostics, output = evaluator(model, data)
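To make the (loss, diagnostics, output) contract concrete, here is a minimal sketch of a classification evaluator, assuming data is an (inputs, targets) batch and that the loss is cross-entropy. ClassificationSketch is hypothetical and is not booster's Classification class; the diagnostic group and key names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassificationSketch:
    """Hypothetical evaluator returning (loss, diagnostics, output)."""

    def __call__(self, model: nn.Module, data):
        x, y = data
        logits = model(x)                    # forward pass
        loss = F.cross_entropy(logits, y)    # mean loss over the batch
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        diagnostics = {
            'loss': {'nll': loss.detach()},
            'info': {'accuracy': accuracy, 'batch_size': x.size(0)},
        }
        output = {'logits': logits}
        return loss, diagnostics, output


torch.manual_seed(0)
model = nn.Linear(4, 10)                     # toy 10-class classifier
data = (torch.randn(8, 4), torch.randint(0, 10, (8,)))
loss, diagnostics, output = ClassificationSketch()(model, data)
loss.backward()                              # the returned loss is still differentiable
```

Returning the loss with its graph intact lets the caller run backward(), while the diagnostics hold detached values that are safe to aggregate.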
Pipeline: model + evaluator
The Pipeline fuses the model's forward pass with the evaluator. It can be wrapped in DataParallelPipeline, a custom DataParallel class that also handles the diagnostics.
from booster import Pipeline, DataParallelPipeline
# fuse model + evaluator
pipeline = Pipeline(model, evaluator)
# wrap as DataParallel
parallel_pipeline = DataParallelPipeline(pipeline, device_ids=device_ids)
# evaluate model on multiple devices and gather loss and diagnostics
data = next(iter(loader))
loss, diagnostics, output = parallel_pipeline(data)
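Fusing the model and evaluator into one nn.Module is what lets data parallelism cover the loss computation too: a wrapper that replicates the module runs the evaluator on every device. PipelineSketch and simple_evaluator below are hypothetical illustrations of that idea, not booster's Pipeline.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PipelineSketch(nn.Module):
    """Hypothetical fusion of a model and an evaluator into one nn.Module."""

    def __init__(self, model: nn.Module, evaluator):
        super().__init__()
        self.model = model          # registered as a submodule, so .to()/.parameters() see it
        self.evaluator = evaluator

    def forward(self, data):
        # The evaluator runs inside forward(), so any replica computes the loss too.
        return self.evaluator(self.model, data)


def simple_evaluator(model, data):
    # Minimal evaluator: forward pass, cross-entropy loss, one diagnostic.
    x, y = data
    logits = model(x)
    loss = F.cross_entropy(logits, y)
    return loss, {'loss': {'nll': loss.detach()}}, logits


pipeline = PipelineSketch(nn.Linear(4, 3), simple_evaluator)
data = (torch.randn(5, 4), torch.randint(0, 3, (5,)))
loss, diagnostics, output = pipeline(data)
```

Under plain nn.DataParallel, the scalar loss from each replica is gathered into a vector with one entry per device, so losses and diagnostics still need to be reduced afterwards; per the description above, DataParallelPipeline is the class that handles the diagnostics.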
Source Distribution
File details
Details for the file booster-pytorch-0.0.2.tar.gz
File metadata
- Download URL: booster-pytorch-0.0.2.tar.gz
- Upload date:
- Size: 14.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.7.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | b773b4afb42db509a8417157ee85c2af7f240ab09bec5e958a7ea69b601c1ff9
MD5 | 8c42604367c2e01e9c43fa3e3029adad
BLAKE2b-256 | 0d964725f89b367c3649dac69af0de4ee2d230cd695d85038cae02d9d2cc039a