A simple library to manage TensorFlow experiments through git and reduce boilerplate. Compatible with tf 1.x

TfExperiment

Usage

This library relies on git to manage experiments. Each experiment should be a unique git branch, and the name of the experiment, if not given, will be the current git branch.
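As a rough illustration of the naming rule above, the sketch below resolves an experiment name from an explicit argument or, failing that, from the current git branch. The helper `resolve_experiment_name` and its `fallback` parameter are hypothetical and not part of tfExperiment, which may resolve the branch differently; only the `git rev-parse --abbrev-ref HEAD` command is standard git.

```python
import subprocess


def resolve_experiment_name(name=None, fallback="default"):
    """Return `name` if given, else the current git branch, else `fallback`.

    Hypothetical helper for illustration only; not tfExperiment's code.
    """
    if name:
        return name
    try:
        # Prints the short name of the checked-out branch ("HEAD" if detached).
        result = subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # git is not installed, or the working directory is not a repository.
        return fallback
    branch = result.stdout.decode().strip()
    return branch or fallback


print(resolve_experiment_name("baseline"))  # an explicit name takes precedence
print(resolve_experiment_name())            # current branch, or the fallback
```

Because the name follows the branch, checking out a new branch before a run is enough to keep the results of two experiments apart.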

import tfExperiment

experiment = tfExperiment.Experiment(finalizeGraph = False)
experiment.saveGraph()

# output
# > graph location ======================================>
# > tensorboard --logdir  output\experimentName\graph

# runTrainingCallback and runTestCallback are user-defined functions

# as a context manager
with experiment.trainingSession(epochs = 125, saveAfter = 2, testAfter = 2) as ts:
    ts.saveGraph() # function to save the graph
    ts.trainCallback = runTrainingCallback
    ts.testCallback = runTestCallback

# as function calls
experiment.train(runTrainingCallback)
experiment.test(runTestCallback)

API

__init__(name = None, finalizeGraph = False, location = os.path.join(os.getcwd(), 'output'))

  • name: string: Name of the experiment. If no name is provided, the name of the current git branch is used.

  • finalizeGraph: bool: Finalizes the graph. Note: this feature has not been tested much.

  • location: string: Absolute path where the experiment results are saved, in a folder with the same name as name.
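To make the location parameter concrete, here is a small sketch of how the output folders could be derived. The helper names `experiment_dir` and `graph_dir` are hypothetical; the layout is inferred from the default `location` argument and from the `output\experimentName\graph` path printed in the usage example, not taken from the library's source.

```python
import os


def experiment_dir(name, location=None):
    """Folder for one experiment: <location>/<name> (default location: ./output)."""
    if location is None:
        location = os.path.join(os.getcwd(), "output")
    return os.path.join(location, name)


def graph_dir(name, location=None):
    """Graph subfolder, mirroring the path shown in the usage output."""
    return os.path.join(experiment_dir(name, location), "graph")


# Build the TensorBoard command for an experiment called "experimentName".
print("tensorboard --logdir", graph_dir("experimentName"))
```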

train(trainCallback, epochs = 1, saveModelAfter = 2, saveGraph = False, testCallback = None, testAfter = 0)

Runs the training and validates/tests the model.

  • trainCallback: function: Function to run at each epoch. It should contain your loop with the training actions to execute for each batch. The callback can take two parameters: session (the current tf.Session) and env (the experiment environment, which gives access to functionality such as timer and dataSaver). If you use env, the parameter must be named exactly env.

  • epochs: integer: Number of epochs to run, that is, the number of times trainCallback will be called. Note: the experiment object keeps track of the number of epochs run so far, so if you call experiment.train again, the epoch number continues from the last epoch number.

  • saveModelAfter: integer: Save the model after n epochs have run. Only epochs of the current run are counted.

  • saveGraph: bool: Whether to save the graph during the current run.

  • testCallback: function: Function to call to test/validate the current network. Similar to trainCallback.

  • testAfter: integer: Test the model after n epochs have run. Only epochs of the current run are counted.
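The epoch bookkeeping described above can be sketched with a toy stand-in. `MiniExperiment` is hypothetical and does not use TensorFlow; it only mimics the documented behaviour: the epoch counter persists across calls, while saving and testing are counted per run. Two assumptions are made here: "after n epochs" is read as "every n epochs of the current run", and env is passed only to callbacks that declare a parameter with that exact name.

```python
import inspect


class MiniExperiment:
    """Toy stand-in for the documented epoch bookkeeping (not the real class)."""

    def __init__(self):
        self.currentEpoch = 0  # persists across train() calls
        self.log = []          # records events, for illustration only

    def _call(self, callback, session, env):
        # Pass env only if the callback declares a parameter named exactly "env".
        if "env" in inspect.signature(callback).parameters:
            return callback(session, env=env)
        return callback(session)

    def train(self, trainCallback, epochs=1, saveModelAfter=2,
              testCallback=None, testAfter=0):
        session, env = object(), {"experiment": self}  # placeholders
        for epochInRun in range(1, epochs + 1):
            self.currentEpoch += 1
            self._call(trainCallback, session, env)
            self.log.append(("train", self.currentEpoch))
            # Assumption: "after n epochs" means every n epochs of this run.
            if saveModelAfter and epochInRun % saveModelAfter == 0:
                self.log.append(("save", self.currentEpoch))
            if testCallback and testAfter and epochInRun % testAfter == 0:
                self._call(testCallback, session, env)
                self.log.append(("test", self.currentEpoch))


def runTrainingCallback(session, env):
    pass  # the per-batch training loop would go here


def runTestCallback(session):
    pass  # this callback does not ask for env, so none is passed


experiment = MiniExperiment()
experiment.train(runTrainingCallback, epochs=3,
                 testCallback=runTestCallback, testAfter=2)
experiment.train(runTrainingCallback, epochs=2)  # counter continues at epoch 4
print(experiment.currentEpoch)
print(experiment.log)
```

The second call starts counting its own epochs from one for the purposes of saving, which is why the save in that run lands on overall epoch 5 rather than 4.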

test(testCallback)

Runs the testing/validation of the model once

  • testCallback: function: Function to call to test/validate the current network. Similar to trainCallback.

env: Box object

The env object contains:

  • env.training.currentEpoch: integer: Number of epochs since the instance was initialized.
  • env.training.dataSavePath: path string: Path where the data will be saved if dataSaver is used during training.
  • env.training.dataSaver: dataSaver instance: dataSaver instance (initialized with env.training.dataSavePath) for saving data during training.
  • env.testing.dataSavePath: path string: Path where the data will be saved if dataSaver is used during testing.
  • env.testing.dataSaver: dataSaver instance: dataSaver instance (initialized with env.testing.dataSavePath) for saving data during testing.
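The shape of env can be pictured with a stand-in object. The sketch below uses `types.SimpleNamespace` instead of a real Box so that it runs without extra dependencies; the paths are made-up examples, and dataSaver is left as None because its interface is not described here.

```python
from types import SimpleNamespace

# Stand-in for the Box object: SimpleNamespace offers the same dotted access.
env = SimpleNamespace(
    training=SimpleNamespace(
        currentEpoch=3,
        dataSavePath="output/exp/training",  # illustrative path only
        dataSaver=None,                      # interface not documented here
    ),
    testing=SimpleNamespace(
        dataSavePath="output/exp/testing",   # illustrative path only
        dataSaver=None,
    ),
)


def runTrainingCallback(session, env):
    """Example callback that reads the epoch counter and save path from env."""
    return "epoch {} -> {}".format(
        env.training.currentEpoch, env.training.dataSavePath
    )


print(runTrainingCallback(session=None, env=env))
```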

Proposed New API

class TrainExperiment(Experiment):
    def __init__(self, constructor, *args, **kwargs):  # remaining parameters not yet specified
        # some config
        # self.nrTotEpochs
        # self.epochsToValidateAfter
        # ...
        pass

    # hooks, to be overridden as needed
    def beforeEpoch(self): pass
    def afterEpoch(self): pass

    def beforeSave(self): pass
    def afterSave(self): pass

    def beforeTest(self): pass
    def afterTest(self): pass

    def beforeIteration(self): pass
    def afterIteration(self): pass

    def train(self, session, data, dataProvider = None):
        return 0 # training loop per session

    def validate(self, session, data, dataProvider = None):
        return 0 # validation loop per session


experiment(TrainExperiment)
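One way such hooks could be driven is sketched below. The proposal only lists hook names, so everything else is an assumption: the `run` method, the `epoch` argument passed to the hooks, the call order, and this minimal `Experiment` base class are all hypothetical and exist only to show the template-method idea.

```python
class Experiment:
    """Hypothetical base class that drives the hooks; not the library's code."""

    def __init__(self, nrTotEpochs=1):
        self.nrTotEpochs = nrTotEpochs

    # Default hooks do nothing; subclasses override the ones they need.
    def beforeEpoch(self, epoch):
        pass

    def afterEpoch(self, epoch):
        pass

    def train(self, session, data, dataProvider=None):
        return 0  # placeholder for the training loop of one epoch

    def run(self, session=None, data=None):
        # Assumed order: beforeEpoch -> train -> afterEpoch, once per epoch.
        results = []
        for epoch in range(1, self.nrTotEpochs + 1):
            self.beforeEpoch(epoch)
            results.append(self.train(session, data))
            self.afterEpoch(epoch)
        return results


class TrainExperiment(Experiment):
    def __init__(self, nrTotEpochs=1):
        super().__init__(nrTotEpochs)
        self.events = []  # records hook calls so the order can be inspected

    def beforeEpoch(self, epoch):
        self.events.append("before {}".format(epoch))

    def afterEpoch(self, epoch):
        self.events.append("after {}".format(epoch))


exp = TrainExperiment(nrTotEpochs=2)
print(exp.run())
print(exp.events)
```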
