Skip to main content
Help us improve PyPI by participating in user testing. All experience levels needed!

Interactions between Dask and Tensorflow

Project description

Start TensorFlow clusters from Dask


Given a Dask cluster

from dask.distributed import Client
client = Client('scheduler-address:8786')

Get a TensorFlow cluster, specifying groups by name

from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=2, worker=4)

>>> tf_spec
{'worker': ['', '',
            '', ''],
 'ps': ['', '']}

This creates a tensorflow.train.Server on each Dask worker and sets up a Queue for data transfer on each worker. These are accessible directly as tensorflow_server and tensorflow_queue attributes on the workers.

More Complex Workflow

Typically then we set up long running Dask tasks that get these servers and participate in general TensorFlow compuations.

from dask.distributed import worker_client

def ps_function(self):
    with worker_client() as c:
        tf_server = c.worker.tensorflow_server

ps_tasks = [client.submit(ps_function, workers=worker, pure=False)
            for worker in dask_spec['ps']]

def worker_function(self):
    with worker_client() as c:
        tf_server = c.worker.tensorflow_server

        # ... use tensorflow as desired ...

worker_tasks = [client.submit(worker_function, workers=worker, pure=False)
                for worker in dask_spec['worker']]

One simple and flexible approach is to have these functions block on queues and feed them data from dask arrays, dataframes, etc.

def worker_function(self):
    with worker_client() as c:
        tf_server = c.worker.tensorflow_server
        queue = c.worker.tensorflow_queue

        while not stopping_condition():
            batch = queue.get()
            # train with batch

And then dump blocks of numpy and pandas dataframes to these queues

from distributed.worker_client import get_worker
def dump_batch(batch):
    worker = get_worker()

import dask.dataframe as dd
df = dd.read_csv('hdfs:///path/to/*.csv')
# clean up dataframe as necessary
partitions = df.to_delayed()  # delayed pandas dataframes, partitions)

Project details

Release history Release notifications

This version
History Node


History Node


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Filename, size & hash SHA256 hash help File type Python version Upload date
dask_tensorflow-0.0.2-py2.py3-none-any.whl (5.3 kB) Copy SHA256 hash SHA256 Wheel py2.py3 Jan 10, 2018
dask-tensorflow-0.0.2.tar.gz (4.8 kB) Copy SHA256 hash SHA256 Source None Jan 10, 2018

Supported by

Elastic Elastic Search Pingdom Pingdom Monitoring Google Google BigQuery Sentry Sentry Error logging CloudAMQP CloudAMQP RabbitMQ AWS AWS Cloud computing Fastly Fastly CDN DigiCert DigiCert EV certificate StatusPage StatusPage Status page