
DeepTrain


License: MIT

Full knowledge and control of the train state.

Features

DeepTrain is founded on control and introspection: full knowledge and manipulation of the train state.

Train Loop

  • Resumability: interrupt-protection, can pause mid-training
  • Tracking & reproducibility: save & load model, train state, random seeds, and hyperparameter info
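
The resumability and reproducibility features rest on checkpointing the full train state, not just model weights. As an illustration of the idea (not DeepTrain's actual internals; all names here are hypothetical), a checkpoint can capture weights, epoch, hyperparameters, and both Python and NumPy RNG states, so a resumed run draws the same random numbers:

```python
import pickle
import random
import numpy as np

def save_train_state(path, weights, epoch, hyperparams):
    # Persist everything needed to resume: weights, epoch, hyperparams,
    # and both RNG states so a resumed run draws the same random numbers
    np.save(path + "_weights.npy", weights)
    state = {
        "epoch": epoch,
        "hyperparams": hyperparams,
        "py_rng": random.getstate(),
        "np_rng": np.random.get_state(),
    }
    with open(path + "_state.pkl", "wb") as f:
        pickle.dump(state, f)

def load_train_state(path):
    # Restore weights and rewind both RNGs to the checkpointed point
    with open(path + "_state.pkl", "rb") as f:
        state = pickle.load(f)
    random.setstate(state["py_rng"])
    np.random.set_state(state["np_rng"])
    weights = np.load(path + "_weights.npy")
    return weights, state["epoch"], state["hyperparams"]
```

Restoring RNG state alongside weights is what makes an interrupted run bit-reproducible from its last checkpoint.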

Data Pipeline

  • Flexible batch_size: can differ from that of the loaded files; batches are split or combined as needed (ex)
  • Faster SSD loading: load larger batches to maximize read-speed utility
  • Stateful timeseries: splits a batch into windows, and calls reset_states() (RNNs) after the last window (ex)
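
The flexible-batching and windowing ideas above can be sketched with NumPy (illustrative helpers, not DeepTrain's API; `rebatch` and `windows` are hypothetical names):

```python
import numpy as np

def rebatch(file_batches, batch_size):
    # Combine data loaded from one or more files, then re-slice to the
    # target batch_size; a real pipeline would carry any remainder over
    data = np.concatenate(file_batches, axis=0)
    for i in range(len(data) // batch_size):
        yield data[i * batch_size:(i + 1) * batch_size]

def windows(batch, window_size):
    # Split a timeseries batch (samples, timesteps, features) into
    # consecutive windows along the time axis; the caller would invoke
    # reset_states() (RNNs) after the last window
    for i in range(batch.shape[1] // window_size):
        yield batch[:, i * window_size:(i + 1) * window_size]
```

For example, four files of 32 samples each re-slice into one batch of 128, while a single 32-sample file splits into two batches of 16.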

Introspection & Utilities

  • Model: auto descriptive naming (ex); gradients, weights, activations visuals (ex)
  • Train state: image log of key attributes for easy reference (ex); batches marked with "set nums" - know what's being fit and when
  • Algorithms, preprocessing, calibration: tools for inspecting & manipulating data and models

Complete list

When is DeepTrain suitable (and not)?

Training a few models thoroughly: closely tracking model and train attributes to debug performance and inform next steps.

DeepTrain is not for models that take under an hour to train, or for training hundreds of models at once.

What does DeepTrain do?

Abstracts away boilerplate train loop and data loading code, without making it a black box. Code is written intuitively and fully documented. Everything about the train state can be inspected via dedicated attributes: which batch is being fit and when, how long until an epoch ends, intermediate metrics, and so on.

DeepTrain is not a "wrapper" around TF; while currently only supporting TF, fitting and data logic is framework-agnostic.

How it works

  1. We define tg = TrainGenerator(**configs),
  2. then call tg.train().
  3. get_data() is called, returning data & labels,
  4. which are fed to model.fit(), returning metrics,
  5. which are then printed and recorded.
  6. The loop repeats, or validate() is called.

Once validate() finishes, training may checkpoint, and train() is called again. Internally, data is loaded via DataGenerator.load_data() (using, e.g., np.load).

That's the high-level overview; details here. Callbacks & other behavior can be configured for every stage of training.
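
The six steps above can be sketched as a minimal loop (illustrative names and stub metrics, not DeepTrain's actual internals):

```python
def train(model, train_batches, val_batches, epochs):
    # Illustrative sketch of the flow above; names are not DeepTrain's API
    history = []
    for epoch in range(epochs):
        for x, y in train_batches:               # 3. get_data(): data & labels
            metrics = model.fit(x, y)            # 4. fed to model.fit() -> metrics
            print(f"epoch {epoch} loss={metrics['loss']}")
            history.append(metrics)              # 5. printed, recorded
        validate(model, val_batches)             # 6. loop repeats, then validate()
        checkpoint(model, epoch)                 # training may checkpoint
    return history

def validate(model, val_batches):
    # Evaluate without updating weights
    losses = [model.evaluate(x, y)["loss"] for x, y in val_batches]
    return sum(losses) / max(len(losses), 1)

def checkpoint(model, epoch):
    # Placeholder: a real implementation would save model & train state
    pass
```

Here `model.fit` is assumed to return a metrics dict per batch for simplicity; Keras returns a History object, and DeepTrain handles the bookkeeping itself.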

Examples

  • MNIST AutoEncoder
  • Timeseries Classification
  • Health Monitoring
  • Tracking Weights
  • Reproducibility
  • Flexible batch_size

Installation

pip install deeptrain (without data; see how to run examples), or clone the repository

Quickstart

To run, DeepTrain requires (1) a compiled model and (2) data directories (train & val). Below is a minimal example.

Checkpointing, visualizing, callbacks & more can be accomplished via additional arguments; see Basic and Advanced examples. Also see Recommended Usage.

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

# Compiled model: 16-dim input, 10-class softmax output
ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

# Data generators pointed at the train & val directories
dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")
tg  = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")

tg.train()
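
The quickstart assumes .npy data already on disk. A hypothetical data-prep sketch generating random inputs matching the model's shapes follows; the file layout (one .npy per "set", plus a labels.npy dict keyed by set name) is an assumption, so check DeepTrain's docs for the exact format it expects:

```python
import os
import numpy as np

def make_dummy_data(path, n_files=4, batch_size=32):
    # Random inputs matching the model above: (32, 16) floats and one-hot
    # labels over 10 classes. File layout here is an assumption, not
    # necessarily what DeepTrain's DataGenerator expects.
    os.makedirs(path, exist_ok=True)
    labels = {}
    for i in range(1, n_files + 1):
        x = np.random.randn(batch_size, 16).astype("float32")
        y = np.eye(10)[np.random.randint(0, 10, batch_size)]
        np.save(os.path.join(path, f"batch32_{i}.npy"), x)
        labels[str(i)] = y
    np.save(os.path.join(path, "labels.npy"), labels)

make_dummy_data("data/train")
make_dummy_data("data/val")
```

Note that labels.npy stores a pickled dict, so loading it back requires np.load(..., allow_pickle=True).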

In future releases

  • MetaTrainer: direct support for dynamic model recompiling with changing hyperparameters, and optimizing thereof
  • PyTorch support
