
Keras Toolkit

A collection of functions to help you easily train and run TensorFlow Keras models.

Get the complete API reference here.

Quickstart

Install the library:

pip install keras-toolkit

You can now use it:

import tensorflow as tf
import keras_toolkit as kt

# Assumes `paths` (image file paths), `labels` and BATCH_SIZE are defined by you.
# kt reduces the number of lines from ~100 to ~3:
strategy = kt.accelerator.auto_select(verbose=True)
decoder = kt.image.build_decoder(with_labels=True, target_size=(300, 300))
dtrain = kt.image.build_dataset(paths, labels, bsize=BATCH_SIZE, decode_fn=decoder)

# Build and compile the model inside the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([...])
    model.compile(...)

model.fit(dtrain, ...)

Usage

To automatically select an available accelerator (e.g. TPU, GPU, CPU) and run your model on it:

import tensorflow as tf
import keras_toolkit as kt

strategy = kt.accelerator.auto_select(verbose=True)

with strategy.scope():
    # your keras code here
    model = tf.keras.Sequential([...])
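Automatic selection usually comes down to trying accelerators in order of preference. The following is a minimal sketch in plain TensorFlow, not kt's actual implementation: it assumes a TPU > GPU > CPU priority, and the helpers `pick_accelerator` and `build_strategy` are hypothetical names. TensorFlow is imported lazily so the selection logic can be read and tested on its own; how the list of available device types is obtained is left open.

```python
from typing import Iterable

# Assumed order of preference (hypothetical; kt's actual logic may differ).
PRIORITY = ("TPU", "GPU", "CPU")


def pick_accelerator(available: Iterable[str]) -> str:
    """Return the most preferred device type found in `available` (case-insensitive)."""
    found = {kind.upper() for kind in available}
    for kind in PRIORITY:
        if kind in found:
            return kind
    return "CPU"  # TensorFlow can always fall back to the CPU


def build_strategy(kind: str):
    """Return a tf.distribute strategy for `kind`; TensorFlow is imported lazily."""
    import tensorflow as tf

    if kind == "TPU":
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    if kind == "GPU":
        # Replicates the model across all visible GPUs.
        return tf.distribute.MirroredStrategy()
    # The default strategy: a no-op scope that runs on a single device.
    return tf.distribute.get_strategy()
```

For example, `with build_strategy(pick_accelerator(["GPU", "CPU"])).scope():` would wrap model construction the same way `strategy.scope()` does in the kt snippet.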

To limit TensorFlow's GPU memory usage (e.g. to 2 GB):

import keras_toolkit as kt

# The limit is given in MB: 2 * 1024 MB = 2 GB
kt.accelerator.limit_gpu_memory(2 * 1024)
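For reference, a comparable cap can be set in plain TensorFlow with `tf.config.set_logical_device_configuration`, whose `memory_limit` is expressed in megabytes. This is a sketch of one possible mechanism, not kt's code; `gb_to_mb` and `limit_first_gpu` are hypothetical names, and only the first visible GPU is limited.

```python
def gb_to_mb(gigabytes: float) -> int:
    """Convert gigabytes to the megabyte value that `memory_limit` expects."""
    return int(gigabytes * 1024)


def limit_first_gpu(memory_mb: int) -> bool:
    """Cap the first visible GPU at `memory_mb` MB; return False if no GPU is found."""
    import tensorflow as tf  # imported lazily so gb_to_mb works without TensorFlow

    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        return False
    # Must run before the GPU is initialized (before any op touches it);
    # otherwise TensorFlow raises a RuntimeError.
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=memory_mb)],
    )
    return True
```

Calling `limit_first_gpu(gb_to_mb(2))` at the very start of a script would request the same 2 GB cap as the kt call above.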

To build an image dataset from a list of image paths and a list of labels (one label per path):

import keras_toolkit as kt

dtrain = kt.image.build_dataset(paths, labels)
# => <PrefetchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

# Fit your Keras model on the resulting tf.data.Dataset:
model.fit(dtrain, ...)
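A dataset like this is typically a `tf.data` pipeline: read each file, decode it, resize it, then batch and prefetch. The sketch below is a hypothetical plain-TensorFlow equivalent (`build_image_dataset` is not part of kt) and assumes JPEG inputs, a 256x256 target size and pixel values scaled to [0, 1]; kt's own decoder may differ on each of these points.

```python
def build_image_dataset(paths, labels=None, bsize=32, target_size=(256, 256)):
    """Hypothetical plain tf.data equivalent of kt.image.build_dataset (a sketch)."""
    import tensorflow as tf  # imported lazily; only needed when the function runs

    def decode(path):
        raw = tf.io.read_file(path)
        img = tf.image.decode_jpeg(raw, channels=3)  # assumes JPEG files
        img = tf.image.resize(img, target_size)      # float32, shape (H, W, 3)
        return img / 255.0                           # assumed [0, 1] scaling

    autotune = tf.data.AUTOTUNE
    if labels is None:
        dset = tf.data.Dataset.from_tensor_slices(paths)
        dset = dset.map(decode, num_parallel_calls=autotune)
    else:
        dset = tf.data.Dataset.from_tensor_slices((paths, labels))
        dset = dset.map(lambda p, y: (decode(p), y), num_parallel_calls=autotune)
    # Batching leaves the leading dimension as None; prefetch overlaps I/O with training.
    return dset.batch(bsize).prefetch(autotune)
```

On JPEG files with integer labels, this should produce element shapes and types comparable to the `PrefetchDataset` shown above.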

If you only pass a list of image paths, build_dataset creates a tf.data.Dataset without labels:

dtrain = kt.image.build_dataset(paths)
# => <PrefetchDataset shapes: (None, 256, 256, 3), types: tf.float32>
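The leading `None` in these shapes is the batch dimension. It is left unknown because, assuming the dataset is batched without `drop_remainder=True`, the final batch holds whatever examples are left over. The hypothetical helper below shows the resulting batch sizes:

```python
from typing import List


def batch_sizes(num_examples: int, bsize: int) -> List[int]:
    """Sizes of the batches produced when the remainder is kept as a smaller last batch."""
    if bsize <= 0:
        raise ValueError("bsize must be positive")
    full, rest = divmod(num_examples, bsize)
    return [bsize] * full + ([rest] if rest else [])


print(batch_sizes(150, 64))  # [64, 64, 22]
```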

You can also customize the dataset, e.g. the batch size, the image decoder, caching and augmentation:

import keras_toolkit as kt

# Example customization; the values are illustrative
img_decoder = kt.image.build_decoder(target_size=(512, 512))
augmenter = kt.image.build_augmenter()

dset = kt.image.build_dataset(
    paths, labels,
    decode_fn=img_decoder,
    bsize=64,
    cache="./cache_dir/",
    augment=augmenter,
    shuffle=False,
    random_state=42
)
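In the example above, `augment` receives the callable returned by `kt.image.build_augmenter()`. If you want to write your own augmentation, a minimal sketch might look like the following. It assumes that `augment` accepts any callable with the same `(image, label)` signature as the decoder output, which this page does not confirm; `make_flip_augmenter` is a hypothetical name, and random flips are just one possible choice.

```python
def make_flip_augmenter(with_labels: bool = True):
    """Hypothetical augmenter: random horizontal and vertical flips of a decoded image."""

    def flip(img):
        import tensorflow as tf  # imported lazily; only needed when the map is traced
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        return img

    def flip_with_label(img, label):
        # Augment the image only; the label passes through unchanged.
        return flip(img), label

    return flip_with_label if with_labels else flip
```

In a plain tf.data pipeline, `dset.map(make_flip_augmenter(), num_parallel_calls=tf.data.AUTOTUNE)` would apply it to a labeled dataset. Apply augmentation after any `cache()` step so that each epoch sees fresh random flips, and only on the training split.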

Acknowledgement



Download files

Source distribution: keras-toolkit-0.1.0rc6.tar.gz (5.4 kB)

Built distribution: keras_toolkit-0.1.0rc6-py3-none-any.whl (5.9 kB, Python 3)
