Delve lets you monitor PyTorch model layer saturation during training

Delve: Deep Live Visualization and Evaluation

License: MIT

Delve is a Python package for visualizing deep learning model training.

Playground: https://github.com/justinshenk/playground

Use Delve if you need a lightweight PyTorch or Keras extension that:

  • Plots live statistics of network layer inputs to TensorBoard or terminal
  • Performs spectral analysis to identify layer saturation for network pruning
  • Is easily extensible and configurable

Motivation

Designing a deep neural network involves optimizing over a wide range of parameters and hyperparameters. Delve allows you to visualize your layer saturation during training so you can grow and shrink layers as needed.

Demo

[Demo: live layer saturation (example_fc.gif)]

Getting Started

pip install delve

Layer Saturation

PyTorch

delve.CheckLayerSat can be configured as follows:

logging_dir (str)  : destination for summaries
modules (torch modules or list of modules) : layer-containing object
log_interval (int) : steps between writing summaries
stats (list of str): list of stats to collect

    supported stats are:
        lsat       : layer saturation
        bcov       : batch covariance
        eigendist  : eigenvalue distribution
        neigendist : normalized eigenvalue distribution
        spectrum   : top-N eigenvalues of covariance matrix
        spectral   : spectral analysis (eigendist, neigendist, and spectrum)

sat_method (str)   : method for calculating saturation: `cumvar99`, `simpson_di`, or `all`.
                        See https://github.com/justinshenk/playground for a comparison of how they work.
include_conv (bool): if False, only linear layers are included
verbose (bool)     : print saturation for every layer during training
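This README does not spell out how `cumvar99` turns a layer's statistics into a single saturation number. Going by the name, a plausible reading (an assumption, not Delve's code) is: sort the eigenvalues of the covariance matrix of a layer's outputs, count how many are needed to explain 99% of the total variance, and divide that count by the layer width. The sketch below implements that reading in plain Python; `cumvar99_saturation` is a hypothetical name, and the playground repository linked above remains the reference for how the methods actually compare.

```python
from typing import Sequence


def cumvar99_saturation(eigenvalues: Sequence[float], threshold: float = 0.99) -> float:
    """Hypothetical sketch of a cumvar99-style saturation score.

    Assumes `eigenvalues` come from the covariance matrix of one layer's
    outputs, so len(eigenvalues) equals the layer width.
    """
    if not eigenvalues:
        raise ValueError("need at least one eigenvalue")
    # Clamp tiny negative values caused by floating-point error; largest first.
    vals = sorted((max(v, 0.0) for v in eigenvalues), reverse=True)
    total = sum(vals)
    if total == 0.0:
        return 0.0  # no output variance at all: no directions are in use
    explained = 0.0
    for k, v in enumerate(vals, start=1):
        explained += v
        if explained / total >= threshold:
            # k leading directions explain `threshold` of the variance.
            return k / len(vals)
    return 1.0  # guards against rounding leaving the ratio just below threshold


# One dominant direction in a 4-unit layer: 1 of 4 directions is used.
print(cumvar99_saturation([1.0, 0.0, 0.0, 0.0]))
```

Under this reading, a value near 1.0 means the layer needs nearly all of its units to carry its variance, while a low value suggests the layer could be narrower.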

Pass either a PyTorch model or torch.nn.Linear layers to CheckLayerSat:

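from delve import CheckLayerSat

model = TwoLayerNet()  # your PyTorch network (a torch.nn.Module)
stats = CheckLayerSat('runs', model)  # logging directory and model to monitor

...  # set up train_loader

for i, (inputs, targets) in enumerate(train_loader):
    outputs = model(inputs)  # forward pass through the monitored layers
    ...  # compute the loss, call backward(), step the optimizer
    stats.saturation()  # output saturation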

Only fully-connected and convolutional layers are currently supported.

To log the saturation to console, call stats.saturation(). For example:

Regression - SixLayerNet - Hidden layer size 10                        loss=0.231825:  68%|████████████████████▎         | 1350/2000 [00:04<00:02, 289.30it/s]
linear1:  90%|█████████████████████████████████▎   | 90.0/100 [00:00<00:00, 453.47it/s]
linear2:  18%|██████▊                               | 18.0/100 [00:00<00:00, 90.68it/s]
linear3:  32%|███████████▊                         | 32.0/100 [00:00<00:00, 161.22it/s]
linear4:  32%|███████████▊                         | 32.0/100 [00:00<00:00, 161.24it/s]
linear5:  28%|██████████▎                          | 28.0/100 [00:00<00:00, 141.11it/s]
linear6:  90%|██████████████████████████████████▏   | 90.0/100 [00:01<00:00, 56.04it/s]

Keras

Two classes are provided in delve.kerascallback: CustomTensorBoard and SaturationLogger.

CustomTensorBoard takes the following parameters:

Argument           Description
log_dir            location for writing summaries
user_defined_freq  frequency for writing summaries
kwargs             passed to tf.keras.callbacks.TensorBoard

SaturationLogger takes three parameters:

Argument     Description
model        Keras model
input_data   data for passing through the model
print_freq   frequency for printing

Example usage:

from delve.kerascallback import CustomTensorBoard, SaturationLogger

...

# Tensorboard logging
tbCallBack = CustomTensorBoard(log_dir='./runs', user_defined_freq=1)

# Console logging
saturation_logger = SaturationLogger(model, input_data=input_x_train[:2], print_freq=1)

...

# Add callback to Keras `fit` method
model.fit(x_train, y_train,
          epochs=100,
          batch_size=128,
          callbacks=[saturation_logger]) # can also pass tbCallBack

Output:

Epoch 29/100
 128/1000 [==>...........................] - ETA: 0s - loss: 2.2783 - acc: 0.1406
dense_1  : %0.83 | dense_2  : %0.79 | dense_3  : %0.67 |

Optimize neural network topology

Ever wonder how big your fully-connected layers should be? Delve helps you visualize the effect of modifying the layer size on your layer saturation.

For example, see how modifying the hidden layer size of this network affects the saturation of the second layer but not the first. Across multiple runs, the saturation of the fully-connected "linear2" layer (light blue: 256 units wide; orange: 8 units wide) is sensitive to layer size:

[Figures: saturation (two panels)]

Log spectral analysis

Writes the top 5 eigenvalues of each layer to TensorBoard summaries:

# PyTorch-only; `stats` takes a list of stat names
stats = CheckLayerSat('runs', layers, stats=['spectrum'])
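The spectral stats in the configuration list (`bcov`, `spectrum`, `neigendist`) all start from a covariance matrix of a layer's outputs. The dependency-free sketch below illustrates that pipeline; it is not Delve's implementation, and every class and function name in it is hypothetical. It accumulates a covariance matrix batch by batch, computes the eigenvalues with the classical cyclic Jacobi method, and returns the largest five, matching the "top 5 eigenvalues" described above. It also assumes that `neigendist` means eigenvalues rescaled to sum to 1. In real code an optimized routine such as `numpy.linalg.eigvalsh` is the better choice for the eigenvalue step.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


class RunningCovariance:
    """Hypothetical accumulator for the covariance of one layer's outputs."""

    def __init__(self, width: int) -> None:
        self.width = width
        self.count = 0
        self.sums = [0.0] * width
        self.outer_sums = [[0.0] * width for _ in range(width)]

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        # Each row is one sample's activations; each column is one unit.
        for row in batch:
            self.count += 1
            for i in range(self.width):
                self.sums[i] += row[i]
                for j in range(self.width):
                    self.outer_sums[i][j] += row[i] * row[j]

    def covariance(self) -> Matrix:
        # Population covariance: E[x x^T] - E[x] E[x]^T
        mean = [s / self.count for s in self.sums]
        return [[self.outer_sums[i][j] / self.count - mean[i] * mean[j]
                 for j in range(self.width)]
                for i in range(self.width)]


def symmetric_eigenvalues(a: Matrix, max_sweeps: int = 50, tol: float = 1e-18) -> List[float]:
    """Eigenvalues of a symmetric matrix via cyclic Jacobi rotations, largest first."""
    a = [row[:] for row in a]  # work on a copy
    n = len(a)
    for _ in range(max_sweeps):
        off_diagonal = sum(a[i][j] ** 2 for i in range(n) for j in range(n) if i != j)
        if off_diagonal < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                apq = a[p][q]
                if apq == 0.0:
                    continue
                # Standard Jacobi rotation chosen so that the new a[p][q] is zero.
                theta = (a[q][q] - a[p][p]) / (2.0 * apq)
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                a[p][p] -= t * apq
                a[q][q] += t * apq
                a[p][q] = a[q][p] = 0.0
                for r in range(n):
                    if r != p and r != q:
                        arp, arq = a[r][p], a[r][q]
                        a[r][p] = a[p][r] = c * arp - s * arq
                        a[r][q] = a[q][r] = s * arp + c * arq
    return sorted((a[i][i] for i in range(n)), reverse=True)


def top_eigenvalues(cov: Matrix, n: int = 5) -> List[float]:
    """The `n` largest eigenvalues of the covariance matrix."""
    return symmetric_eigenvalues(cov)[:n]


def normalized_eigendist(cov: Matrix) -> List[float]:
    """Eigenvalues rescaled to sum to 1 (assumed meaning of `neigendist`)."""
    vals = [max(v, 0.0) for v in symmetric_eigenvalues(cov)]
    total = sum(vals)
    return [v / total for v in vals] if total > 0 else vals


# Two perfectly correlated units: all variance lies along one direction.
acc = RunningCovariance(width=2)
acc.update([[1.0, 1.0]])
acc.update([[-1.0, -1.0]])
print(top_eigenvalues(acc.covariance()))
```

Accumulating sums rather than storing activations keeps memory proportional to the square of the layer width, independent of how many batches are seen, which is why a running covariance is a natural fit for live monitoring.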

Other options

[Figure: spectrum]

Intrinsic dimensionality

View the intrinsic dimensionality of models in real time:

[Figure: intrinsic_dimensionality-layer2]

This comparison suggests that the 8-unit layer (light blue) is too saturated and that a larger layer is needed.

Why this name, Delve?

delve (verb):

  • to reach inside a receptacle and search for something
  • to carry on intensive and thorough research for data, information, or the like


Download files

Source Distribution

delve-0.1.19.tar.gz (15.0 kB)

Uploaded Source

Built Distribution

delve-0.1.19-py2.py3-none-any.whl (14.5 kB)

Uploaded Python 2 Python 3

File details

Details for the file delve-0.1.19.tar.gz.

File metadata

  • Download URL: delve-0.1.19.tar.gz
  • Size: 15.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.7.0 setuptools/40.8.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.2

File hashes

Hashes for delve-0.1.19.tar.gz
Algorithm Hash digest
SHA256 8e950306c75cf65a07435737257a4802aa65a2e96ab7129b3380046f2b181788
MD5 68b28fa5033a66f066efa84748096c57
BLAKE2b-256 1144efee0d21f91bf66e5c5cf08b7c6b8f6125a980fda3ff359f83f5ff733ac5

File details

Details for the file delve-0.1.19-py2.py3-none-any.whl.

File metadata

  • Download URL: delve-0.1.19-py2.py3-none-any.whl
  • Size: 14.5 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.7.0 setuptools/40.8.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.2

File hashes

Hashes for delve-0.1.19-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 3cbed8cabd647e4e26a5ebc66aa45c884113c4c4a5f057b1aa7f718f35146e7a
MD5 60f86355681331ce4d5a16749490cad9
BLAKE2b-256 32c04460ed2aa33f402b362644561e8bf8b2a98a3577340d2aab68c8cbda6155
