This package helps users do distributed training with TensorFlow on their Spark clusters.

Spark TensorFlow Distributor


This package requires Python 3.6+, tensorflow>=2.1.0 and pyspark>=3.0.0 to run. To install spark-tensorflow-distributor, run:

pip install spark-tensorflow-distributor
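To check whether an existing environment meets the version requirements above, something like the following sketch can compare version strings. It is not part of the package, and the helper names are hypothetical. It compares only the leading numeric components and ignores pre-release suffixes such as rc1, so it treats 3.0.0rc1 the same as 3.0.0.

```python
import re
import sys

# Minimum versions taken from the requirements stated above.
MINIMUM_VERSIONS = {"tensorflow": "2.1.0", "pyspark": "3.0.0"}
MINIMUM_PYTHON = (3, 6)


def version_tuple(version):
    """Turn '2.1.0' or '3.0.0rc1' into (2, 1, 0); any non-numeric suffix is ignored."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    parts = [int(p) for p in match.group(0).split(".")]
    # Pad to three components so '3.0' compares equal to '3.0.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed, required):
    """True if the installed version is at least the required version."""
    return version_tuple(installed) >= version_tuple(required)


def check_python(version_info=sys.version_info):
    """True if the interpreter is at least Python 3.6."""
    return tuple(version_info[:2]) >= MINIMUM_PYTHON
```

In a session where both libraries are installed, meets_minimum(pyspark.__version__, MINIMUM_VERSIONS["pyspark"]) and the equivalent call with tensorflow.__version__ would perform the check.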

The installation does not install PySpark because, for most users, PySpark is already installed. TensorFlow is not installed either, so that users can choose between the regular and CPU-only builds with pip install tensorflow or pip install tensorflow-cpu. If you do not have PySpark installed, you can install it directly:

pip install "pyspark>=3.0.0"
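Because neither dependency is pulled in automatically, a quick way to see which one is missing is to look the modules up without importing them. The sketch below is a hypothetical helper, not part of the package; the pip commands it suggests are the ones given above.

```python
import importlib.util

# Dependencies that spark-tensorflow-distributor does not install itself,
# mapped to the install commands suggested above.
REQUIRED_MODULES = {
    "pyspark": 'pip install "pyspark>=3.0.0"',
    "tensorflow": "pip install tensorflow  (or: pip install tensorflow-cpu)",
}


def missing_dependencies(required=REQUIRED_MODULES):
    """Return {module: suggested pip command} for every module that cannot be found."""
    missing = {}
    for module, command in required.items():
        # find_spec locates a top-level module without importing it,
        # so this check avoids TensorFlow's import cost.
        if importlib.util.find_spec(module) is None:
            missing[module] = command
    return missing


if __name__ == "__main__":
    for module, command in missing_dependencies().items():
        print("%s is not installed; try: %s" % (module, command))
```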

Note also that many features of this package require Spark custom resource scheduling for GPUs to be configured on your cluster. See the Spark documentation on custom resource scheduling for setup instructions.
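As a rough illustration of that setup, the sketch below builds the application-level Spark 3.0 resource properties spark.executor.resource.gpu.amount, spark.executor.resource.gpu.discoveryScript and spark.task.resource.gpu.amount, and turns them into --conf arguments for spark-submit or the pyspark shell. The helper names and the discovery-script path are hypothetical placeholders. Only whole GPUs per task are handled, cluster-manager-specific settings (such as worker-level properties on a standalone cluster) are not covered, and the Spark documentation remains the reference.

```python
# Hypothetical placeholder: point this at the GPU discovery script on your cluster.
DISCOVERY_SCRIPT = "/opt/spark/scripts/getGpusResources.sh"


def gpu_resource_conf(gpus_per_executor, gpus_per_task, discovery_script=DISCOVERY_SCRIPT):
    """Build Spark resource-scheduling properties for tasks that each use whole GPUs."""
    if gpus_per_executor < 1 or gpus_per_task < 1:
        raise ValueError("GPU amounts must be positive integers in this sketch")
    if gpus_per_task > gpus_per_executor:
        raise ValueError("a task cannot request more GPUs than its executor has")
    return {
        "spark.executor.resource.gpu.amount": str(gpus_per_executor),
        "spark.executor.resource.gpu.discoveryScript": discovery_script,
        "spark.task.resource.gpu.amount": str(gpus_per_task),
    }


def to_conf_args(conf):
    """Turn a property dict into ['--conf', 'key=value', ...] in sorted key order."""
    args = []
    for key in sorted(conf):
        args.extend(["--conf", "%s=%s" % (key, conf[key])])
    return args


if __name__ == "__main__":
    # Example: executors with 4 GPUs each, one GPU per task.
    print(" ".join(["pyspark"] + to_conf_args(gpu_resource_conf(4, 1))))
```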

Running Tests

For integration tests, first build the master and worker images and then run the test script.

docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7

For linting, run the following.


To use the autoformatter, run the following.

yapf --recursive --in-place spark_tensorflow_distributor


Run the following example code in the PySpark shell:

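from spark_tensorflow_distributor import MirroredStrategyRunner

# Adapted from
def train():
    import tensorflow as tf
    import uuid

    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

    def make_datasets():
        # A unique file name per task avoids clashes when several tasks
        # download MNIST into the same Keras cache directory at once.
        (mnist_images, mnist_labels), _ = \
            tf.keras.datasets.mnist.load_data(path=str(uuid.uuid4()) + 'mnist.npz')

        dataset = tf.data.Dataset.from_tensor_slices((
            tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
            tf.cast(mnist_labels, tf.int64)))
        dataset = dataset.repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
        return dataset

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            # Flatten the feature maps so the Dense layers output one prediction per image.
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'])
        return model

    train_datasets = make_datasets()
    # Shard the input by elements so each worker trains on a different portion of the data.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)

# num_slots is the number of GPUs (or CPU-only tasks) used for training;
# 8 is only an example value.
MirroredStrategyRunner(num_slots=8).run(train)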
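The example sets auto_shard_policy to AutoShardPolicy.DATA. According to the tf.data documentation, under this policy every worker reads the whole dataset and discards the elements that are not assigned to it. The plain-Python sketch below illustrates the modulo assignment rule documented for tf.data.Dataset.shard; it does not use TensorFlow, the helper names are hypothetical, and it does not model where tf.data actually inserts the shard operation in the pipeline.

```python
def shard_indices(num_elements, num_workers, worker_index):
    """Return the element indices one worker keeps under a modulo sharding rule."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    # Same rule as Dataset.shard(num_shards, index): keep element i if i % num_shards == index.
    return [i for i in range(num_elements) if i % num_workers == worker_index]


def check_partition(num_elements, num_workers):
    """True if the workers' shards are disjoint and together cover every element once."""
    seen = []
    for worker in range(num_workers):
        seen.extend(shard_indices(num_elements, num_workers, worker))
    return sorted(seen) == list(range(num_elements))


if __name__ == "__main__":
    # With 10 elements and 3 workers, worker 0 keeps indices 0, 3, 6, 9.
    print(shard_indices(10, 3, 0))
    print(check_partition(10, 3))
```

Each worker still iterates over the full input and throws most of it away, which is why this policy trades extra reading for an even split of elements.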


Download files

Source Distribution

spark_tensorflow_distributor-1.0.0.tar.gz (9.2 kB)
