
Manage training results, weights and data flow of your TensorFlow models

Project description

MLPipe-Trainer

Manage your data pipeline and TensorFlow & Keras models with MLPipe. It is NOT another "wrapper" around TensorFlow; rather, it adds utilities to set up an environment that controls the data flow and manages trained models (weights & results) with the help of MongoDB.

>> pip install mlpipe-trainer

Setup - install MongoDB

The MongoDB database is used to store trained models, including their weights and results. MLPipe also implements a data reader for MongoDB (essentially a data generator, as you know and love it from Keras). Currently, this is the only implemented data reader that works "out of the box".
Follow the instructions on the MongoDB website for installation e.g. for Linux: https://docs.mongodb.com/manual/administration/install-on-linux/

Code Examples

Config

# The config specifies the connections (e.g. to a MongoDB on localhost)
# used for saving trained models to MongoDB and for fetching training data
from mlpipe.utils import Config
Config.add_config('./path_to/config.ini')

Each connection config consists of the following fields in the .ini file:

[example_mongo_db_connection]
db_type=MongoDB
url=localhost
port=27017
user=read_write
pwd=rw
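To illustrate how the section name and its fields fit together, here is a minimal sketch that parses such a file with Python's standard configparser module and builds a standard mongodb:// connection string from one section. The helpers read_connections and mongo_uri are hypothetical; this is not how MLPipe's Config reads the file internally, only an illustration of the format. The section name appears to be the connection name that the examples below refer to.

```python
import configparser
from urllib.parse import quote_plus

# Hypothetical example text with the same fields as the section above.
EXAMPLE_INI = """
[example_mongo_db_connection]
db_type=MongoDB
url=localhost
port=27017
user=read_write
pwd=rw
"""


def read_connections(ini_text):
    """Parse every section into a plain dict, keyed by section (connection) name."""
    parser = configparser.ConfigParser()
    parser.read_string(ini_text)
    connections = {}
    for name in parser.sections():
        section = parser[name]
        connections[name] = {
            "db_type": section.get("db_type"),
            "url": section.get("url"),
            "port": section.getint("port"),  # stored as text, converted to int
            "user": section.get("user"),
            "pwd": section.get("pwd"),
        }
    return connections


def mongo_uri(conn):
    """Build a mongodb:// URI; user and password are URL-escaped."""
    return "mongodb://{}:{}@{}:{}".format(
        quote_plus(conn["user"]), quote_plus(conn["pwd"]), conn["url"], conn["port"]
    )


connections = read_connections(EXAMPLE_INI)
print(mongo_uri(connections["example_mongo_db_connection"]))
# mongodb://read_write:rw@localhost:27017
```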

Data Pipeline

from mlpipe.processors.i_processor import IPreProcessor
from mlpipe.data_reader.mongodb import MongoDBGenerator

class PreProcessData(IPreProcessor):
    def process(self, raw_data, input_data, ground_truth, piped_params=None):
        # Process raw_data to output input_data and ground_truth
        # which will be the input for the model
        ...
        return raw_data, input_data, ground_truth, piped_params

train_data = [...]  # consists of MongoDB ObjectIds that are used for training
processors = [PreProcessData()]  # Chain of processors (in our case it's just one)
# Generator that can be used e.g. with keras' fit_generator()
train_gen = MongoDBGenerator(
    ("connection_name", "cifar10", "train"),  # specify data source from a MongoDB
    train_data,
    batch_size=128,
    processors=processors
)
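Because every processor receives and returns the same four values (raw_data, input_data, ground_truth, piped_params), processors can be chained. The sketch below shows one way such a chain can be applied in list order, each output tuple feeding the next processor. It is an assumption for illustration, not MLPipe's internal code: run_processors, NormalizeImage and OneHotLabel are hypothetical, the processors are plain classes (in MLPipe you would subclass IPreProcessor), and input_data and ground_truth are assumed to start as None.

```python
def run_processors(processors, raw_data, piped_params=None):
    """Apply processors in list order, passing each output tuple to the next one."""
    input_data, ground_truth = None, None  # assumed starting values
    for processor in processors:
        raw_data, input_data, ground_truth, piped_params = processor.process(
            raw_data, input_data, ground_truth, piped_params
        )
    return input_data, ground_truth, piped_params


class NormalizeImage:
    """Hypothetical processor: scale pixel values from 0..255 to 0..1."""

    def process(self, raw_data, input_data, ground_truth, piped_params=None):
        input_data = [px / 255.0 for px in raw_data["img"]]
        return raw_data, input_data, ground_truth, piped_params


class OneHotLabel:
    """Hypothetical processor: turn an integer class id into a one-hot list."""

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def process(self, raw_data, input_data, ground_truth, piped_params=None):
        ground_truth = [0] * self.num_classes
        ground_truth[raw_data["label"]] = 1
        return raw_data, input_data, ground_truth, piped_params


# A fake record standing in for one document fetched from MongoDB.
record = {"img": [0, 255, 51], "label": 2}
x, y, _ = run_processors([NormalizeImage(), OneHotLabel(num_classes=3)], record)
print(x, y)  # [0.0, 1.0, 0.2] [0, 0, 1]
```

Keeping each processor small (one transformation each) makes the chain easy to reorder or reuse between training and validation generators.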

Data generators inherit from tf.keras.utils.Sequence. Check out the TensorFlow documentation of Sequence to find out how you can write your own generators (e.g. for data sources other than MongoDB).
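A minimal sketch of such a custom generator, using a pair of in-memory Python lists as a stand-in for another data source. ListGenerator is a hypothetical name and not part of MLPipe. It implements the two methods a Sequence must provide, __len__ (batches per epoch) and __getitem__ (one batch by index), plus the optional on_epoch_end hook that Keras calls after each epoch. The fallback to object only exists so the sketch runs without TensorFlow installed.

```python
import math
import random

try:
    from tensorflow.keras.utils import Sequence
except ImportError:  # lets the sketch run without TensorFlow
    Sequence = object


class ListGenerator(Sequence):
    """Hypothetical generator serving (inputs, targets) batches from two lists."""

    def __init__(self, inputs, targets, batch_size=32, shuffle=True, seed=0):
        super().__init__()
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same length")
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)  # seeded for reproducible shuffling
        self.indices = list(range(len(inputs)))
        if shuffle:
            self.rng.shuffle(self.indices)

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.inputs) / self.batch_size)

    def __getitem__(self, idx):
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        x = [self.inputs[i] for i in batch_idx]
        y = [self.targets[i] for i in batch_idx]
        # For a real Keras model you would usually return numpy arrays here.
        return x, y

    def on_epoch_end(self):
        # Reshuffle so that every epoch sees the samples in a new order.
        if self.shuffle:
            self.rng.shuffle(self.indices)


gen = ListGenerator(list(range(10)), [i % 2 for i in range(10)], batch_size=4, shuffle=False)
print(len(gen))  # 3
print(gen[2])    # ([8, 9], [0, 1])
```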

Model

As long as you end up with a Keras (tensorflow.keras) model, there are no restrictions on this step.

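from tensorflow.keras import optimizers
from tensorflow.keras.layers import Conv2D, Dense
from tensorflow.keras.models import Sequential

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=(32, 32, 3)))
...
model.add(Dense(10, activation='softmax'))

# lr/decay are the argument names older tf.keras releases accept; newer TensorFlow
# versions expect learning_rate (and a learning-rate schedule instead of decay)
opt = optimizers.RMSprop(lr=0.0001, decay=1e-6)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=["accuracy"])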

Training and Callbacks

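from mlpipe.callbacks import SaveToMongoDB

save_to_mongodb_cb = SaveToMongoDB(("localhost_mongo_db", "models"), "test", model)

# val_gen is a second generator, created like train_gen but over validation ObjectIds.
# fit_generator() is deprecated since TensorFlow 2.1; model.fit() accepts the same
# generator as its x argument.
model.fit_generator(
    generator=train_gen,
    validation_data=val_gen,
    epochs=10,
    verbose=1,
    callbacks=[save_to_mongodb_cb],
    initial_epoch=0,
)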

SaveToMongoDB is a custom Keras callback, as described in the TensorFlow documentation. Again, feel free to create custom callbacks for your specific needs.
If you train batch by batch instead of using fit_generator(), e.g. with a native TensorFlow model, you can still loop over the generator. Just remember to call the callback methods yourself at the corresponding steps, e.g. on_batch_end().
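The manual-loop case described above can be sketched as follows, under stated assumptions. LossHistory is a hypothetical custom callback that averages batch losses per epoch, and train_manually is a hypothetical helper that iterates any Sequence-style generator (anything with __len__ and __getitem__) and calls the standard Keras hook names (on_train_begin, on_epoch_begin, on_batch_begin, on_batch_end, on_epoch_end, on_train_end) at the matching points. train_step stands in for whatever native TensorFlow code performs one update and is assumed to return the batch loss. None of this is MLPipe code, and the fallback Callback class only exists so the sketch runs without TensorFlow.

```python
try:
    from tensorflow.keras.callbacks import Callback
except ImportError:  # minimal stand-in with the hook names used below
    class Callback:
        def on_train_begin(self, logs=None): pass
        def on_train_end(self, logs=None): pass
        def on_epoch_begin(self, epoch, logs=None): pass
        def on_epoch_end(self, epoch, logs=None): pass
        def on_batch_begin(self, batch, logs=None): pass
        def on_batch_end(self, batch, logs=None): pass


class LossHistory(Callback):
    """Hypothetical callback: stores the mean batch loss of every epoch."""

    def on_train_begin(self, logs=None):
        self.epoch_losses = []

    def on_epoch_begin(self, epoch, logs=None):
        self._batch_losses = []

    def on_batch_end(self, batch, logs=None):
        self._batch_losses.append(logs["loss"])

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_losses.append(sum(self._batch_losses) / len(self._batch_losses))


def train_manually(train_step, generator, epochs, callbacks):
    """Loop over a Sequence-style generator and invoke the callback hooks by hand."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(len(generator)):
            x, y = generator[batch]
            for cb in callbacks:
                cb.on_batch_begin(batch)
            loss = train_step(x, y)  # hypothetical: one update, returns the loss
            for cb in callbacks:
                cb.on_batch_end(batch, logs={"loss": loss})
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={})
        if hasattr(generator, "on_epoch_end"):
            generator.on_epoch_end()  # let a Sequence reshuffle between epochs
    for cb in callbacks:
        cb.on_train_end()


# Toy usage: a plain list of (x, y) batches and a fake step whose "loss" is sum(x).
batches = [([1, 2], [0, 1]), ([3, 4], [1, 0]), ([5], [1])]
history = LossHistory()
train_manually(lambda x, y: float(sum(x)), batches, epochs=2, callbacks=[history])
print(history.epoch_losses)  # [5.0, 5.0]
```

Calling the hooks in the same order Keras uses means callbacks written for fit_generator() can be reused unchanged in a custom loop.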

A full CIFAR-10 example can be found in the example folder of the project repository.

Road Map

  • Generate and host MkDocs documentation
  • Add tests
  • Set up CI



Download files


Source Distribution

mlpipe-trainer-0.4.8.tar.gz (16.9 kB)

Uploaded Source

Built Distribution


mlpipe_trainer-0.4.8-py3-none-any.whl (22.0 kB)

Uploaded Python 3

File details

Details for the file mlpipe-trainer-0.4.8.tar.gz.

File metadata

  • Download URL: mlpipe-trainer-0.4.8.tar.gz
  • Size: 16.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.2

File hashes

Hashes for mlpipe-trainer-0.4.8.tar.gz
Algorithm Hash digest
SHA256 be912bab94ee24c12d7464d15be59f24fc619535879b9524a77f975b4104ad6c
MD5 3963b2b1acb4e10ed6e8fb39222bc603
BLAKE2b-256 9bcc9fd8359e05794db6bb2776785aa6cdace700bac77628d04c012af52d9ea9


File details

Details for the file mlpipe_trainer-0.4.8-py3-none-any.whl.

File metadata

  • Download URL: mlpipe_trainer-0.4.8-py3-none-any.whl
  • Size: 22.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.2

File hashes

Hashes for mlpipe_trainer-0.4.8-py3-none-any.whl
Algorithm Hash digest
SHA256 53a48eaa834d67f9d7ae04e9d044cba0dadac5752473e31a48b1fd58ad61814f
MD5 65f45a942291e074c537e6a0c9ac21d6
BLAKE2b-256 333a0244d6d940ff09b006df76075243568b0e0e7085e3e282e1229e4f69b478

