A deep learning project template for TensorFlow

Gil-Galad is a deep learning project template for TensorFlow, with utilities for tf.data.Dataset handling and hyperparameter optimization.

Getting Started

Fork the repository, navigate to your local copy, and follow the quick start checklist below. Data should be split into training, validation, and test sets, stored as .tfrecords files in separate directories. Utilities for converting NumPy arrays to .tfrecords files are included in the utils.py module.
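For reference, a minimal sketch of such a conversion using the TF1-era record writer is shown below; the actual helpers in utils.py may differ in signature and feature layout.

import numpy as np
import tensorflow as tf

def numpy_to_tfrecords(array, filepath):
    # Write each example in `array` to `filepath` as a single bytes feature.
    # (Illustrative only; the utils.py helpers may use a different layout.)
    with tf.python_io.TFRecordWriter(filepath) as writer:
        for example in array:
            feature = {
                'data': tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[example.astype(np.float32).tobytes()]))
            }
            proto = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(proto.SerializeToString())

numpy_to_tfrecords(np.random.rand(16, 32, 32, 3), 'data/train/train.tfrecords')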

Quick Start Checklist

  • Build neural network models in models.py
    • Note: hyperparameters that will be optimized later need to be specified. See the Model Selection section below.
  • Connect models, define the loss, evaluation metrics, optimizer, etc. in graph.py
    • Note: hyperparameters that will be optimized later need to be specified. See the Model Selection section below.
  • Specify input data shapes in the data_shapes dictionary
    • Note: data shapes should not include the batch size; that information is passed to the sampler instead
    • Additional note: the order of entries in data_shapes should correspond to the order in which you retrieve the tensors in the graph
      • In main.py: data_shapes = {'lowres': (32, 32, 3), 'highres': (128, 128, 3)}
      • In graph.py: lowres, highres = data.get_batch()
  • Create a DataSampler object with the filepaths to your train/valid/test sets and the data shapes dictionary (see the sketch after this list).
  • Define the hyperparameters to optimize and their corresponding domain ranges in a dictionary
  • Pass your graph object and hyperparameter dictionary to Sherpa via gilgalad.opt.bayesian_optimization
  • (Optional) Add custom plotting functionality in plotting.py. Send any matplotlib figure to TensorBoard via tfplot.
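Putting the checklist together, a minimal main.py might look like the sketch below. The Graph and Model signatures match the code in the Model Selection section; the DataSampler constructor arguments and import path are assumptions.

from gilgalad.data import DataSampler  # import path assumed
from models import Model
from graph import Graph

# Shapes exclude the batch size; key order matches get_batch() in graph.py.
data_shapes = {'lowres': (32, 32, 3), 'highres': (128, 128, 3)}

# Constructor arguments are assumptions based on the checklist.
sampler = DataSampler(
    train='data/train/',
    valid='data/valid/',
    test='data/test/',
    data_shapes=data_shapes
)

graph = Graph(
    network=Model(name='network'),
    sampler=sampler,
    logdir='logs/',
    ckptdir='checkpoints/'
)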

Model Selection

Gil-Galad performs model selection with Sherpa's Bayesian optimization suite, which builds on scikit-learn's Gaussian process module. Bayesian optimization places a prior distribution over functions, specified by a mean and a kernel function; here the function being modeled is a surrogate objective whose predictor variables are the model hyperparameters. This prior is updated via Bayes' rule after each trial run, where the independent variables specify the model configuration and the dependent variable is that model's evaluation on the validation set.
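In symbols (a standard Gaussian process formulation; this notation is not taken from the Sherpa documentation), the surrogate f over hyperparameters \theta receives a GP prior and is conditioned on the completed trials:

    f(\theta) \sim \mathcal{GP}\big(m(\theta),\, k(\theta, \theta')\big), \qquad p(f \mid \mathcal{D}_n) \propto p(\mathcal{D}_n \mid f)\, p(f)

where \mathcal{D}_n = \{(\theta_i, v_i)\}_{i=1}^{n} collects each trial's hyperparameters \theta_i and validation score v_i.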

With Gil-Galad, we specify which hyperparameters to optimize by passing a parameter dictionary to the graph class, while also defining default values to fall back on during graph and model construction, as follows:

Graph-level

import tensorflow as tf


class Graph(BaseGraph):

    def __init__(self, network, sampler, logdir=None, ckptdir=None):

        self.network = network
        self.data = sampler
        self.logdir = logdir
        self.ckptdir = ckptdir

        self.build_graph()

    def build_graph(self, params=None):

        tf.reset_default_graph()
        self.data.initialize()

        # Tensors arrive in the order given by the data_shapes dictionary.
        self.x, self.y, self.z = self.data.get_batch()

        self.y_ = self.network(self.x, params=params)

        self.loss = tf.losses.mean_squared_error(self.y, self.y_)

        # Track training iterations for checkpointing and summaries.
        self.global_step = tf.train.get_or_create_global_step()

        # Run pending update ops (e.g. batch norm statistics) with each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = tf.train.AdamOptimizer(
                # Fall back to the default rate when no trial params are given.
                learning_rate=params['lr'] if params else 0.001
            )

            self.update = opt.minimize(
                loss=self.loss,
                var_list=self.network.vars,
                global_step=self.global_step
            )

Model-level

class Model(BaseModel):

    def __init__(self, name):
        self.name = name

    def __call__(self, x, params=None):
        # Scope variables under the model name so they can be collected
        # later (e.g. as self.network.vars in the graph).
        with tf.variable_scope(self.name):
            # Trial parameters override the hard-coded defaults.
            y = conv_2d(
                x=x,
                filters=params['filters'] if params else 64,
                kernel_size=params['kernel_size'] if params else 3,
                strides=2,
                activation=params['activation'] if params else 'relu'
            )

            return y
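During a study, Sherpa proposes a parameter setting for each trial and the graph is rebuilt with it. A rebuild with hand-picked (hypothetical) values would look like this; the optimization routine below performs the equivalent step internally:

graph.build_graph(params={
    'lr': 3e-4,
    'filters': 128,
    'kernel_size': 5,
    'activation': 'prelu'
})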

We then define each hyperparameter's domain type and range in a dictionary. This dictionary accompanies the graph object as an argument to the Bayesian optimization function.

import gilgalad as gg

hyperparameters = {
    'Discrete':
        {'filters': [64, 128],
         'kernel_size': [3, 5]},
    'Continuous':
        {'lr': [1e-5, 1e-3]},
    'Choice':
        {'activation': ['relu', 'prelu']}
}

best_model = gg.opt.bayesian_optimization(
    graph=graph,
    params=hyperparameters,
    max_trials=50
)
