
Interactions between Dask and Tensorflow


Start TensorFlow clusters from Dask


Given a Dask cluster, connect a client to its scheduler:

from dask.distributed import Client
client = Client('scheduler-address:8786')

Start a TensorFlow cluster on it, specifying the number of tasks in each job by name:

from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=2, worker=4)

>>> tf_spec
{'worker': ['', '',
            '', ''],
 'ps': ['', '']}

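This creates a tensorflow.train.Server on each Dask worker and sets up a Queue for data transfer on each worker. Both are accessible directly as the tensorflow_server and tensorflow_queue attributes of the worker. The second return value, dask_spec, uses the same job names but lists the addresses of the Dask workers running each job.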
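A TensorFlow cluster spec identifies each task by its job name and its position in that job's address list. The sketch below uses a hypothetical helper, `task_assignments`, which is not part of dask-tensorflow, to show that bookkeeping for a spec shaped like `tf_spec`. The addresses are made up and stand in for the real ones `start_tensorflow` returns.

```python
from typing import Dict, List, Tuple


def task_assignments(spec: Dict[str, List[str]]) -> Dict[str, Tuple[str, int]]:
    """Map each 'host:port' address to its (job_name, task_index).

    Hypothetical helper: a task's index is its position in the job's list.
    """
    assignments = {}
    for job_name, addresses in spec.items():
        for task_index, address in enumerate(addresses):
            assignments[address] = (job_name, task_index)
    return assignments


# Made-up addresses standing in for the ones start_tensorflow returns
example_spec = {
    'worker': ['10.0.0.1:2222', '10.0.0.2:2222', '10.0.0.3:2222', '10.0.0.4:2222'],
    'ps': ['10.0.0.5:2222', '10.0.0.6:2222'],
}

for address, (job, index) in task_assignments(example_spec).items():
    print(address, '->', job, index)
```

Each resulting pair corresponds to the `job_name` and `task_index` arguments that `tf.train.Server` accepts alongside a cluster spec.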

More Complex Workflow

Typically, we then launch long-running Dask tasks on these workers. Each task retrieves its worker's TensorFlow server and takes part in the distributed TensorFlow computation.

from dask.distributed import get_worker, worker_client

def ps_function():
    # worker_client() secedes this long-running task from the worker's thread pool
    with worker_client():
        tf_server = get_worker().tensorflow_server
        tf_server.join()  # a parameter server blocks here until shutdown

ps_tasks = [client.submit(ps_function, workers=worker, pure=False)
            for worker in dask_spec['ps']]

def worker_function():
    with worker_client():
        tf_server = get_worker().tensorflow_server

        # ... use tensorflow as desired ...

worker_tasks = [client.submit(worker_function, workers=worker, pure=False)
                for worker in dask_spec['worker']]
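Both submissions above pass `pure=False`. As we understand Dask's `Client.submit`, a pure call (the default) is keyed by a hash of the function and its arguments, and the `workers=` restriction does not enter that key. Submitting the same argument-free function once per address would therefore collapse into a single task. The toy model below illustrates the difference. It is a hypothetical stand-in, not Dask's actual `tokenize`; `toy_task_key` and the addresses are made up.

```python
import hashlib
import uuid


def toy_task_key(func, args=(), pure=True):
    """Toy model of choosing a task key (not Dask's real tokenize)."""
    name = func.__name__
    if pure:
        # Deterministic: the same function and arguments give the same key
        digest = hashlib.sha256(repr((name, args)).encode()).hexdigest()
        return f'{name}-{digest}'
    # Impure: a fresh random key on every call
    return f'{name}-{uuid.uuid4()}'


def ps_function():
    pass  # stand-in for the long-running function above


# Made-up Dask worker addresses; note the address never enters the key
ps_addresses = ['tcp://10.0.0.5:40001', 'tcp://10.0.0.6:40001']
pure_keys = {toy_task_key(ps_function) for _ in ps_addresses}
impure_keys = {toy_task_key(ps_function, pure=False) for _ in ps_addresses}
print(len(pure_keys), len(impure_keys))  # 1 2
```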

One simple and flexible approach is to have these functions block on their worker's queue, and to feed that queue with data from Dask arrays, dataframes, and other collections.

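from dask.distributed import get_worker, worker_client

def worker_function():
    with worker_client():
        worker = get_worker()
        tf_server = worker.tensorflow_server
        queue = worker.tensorflow_queue

        while not stopping_condition():  # placeholder: define your own check
            batch = queue.get()  # blocks until a batch arrives
            # train with batch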
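The loop above leaves `stopping_condition()` undefined. One common choice, assumed here rather than taken from dask-tensorflow, is a sentinel: whoever feeds the queue puts a special marker after the last batch, and the consumer stops when it sees that marker. The sketch below uses only the standard library, with a background thread playing the long-running worker task. `STOP`, `consume`, and `train_step` are hypothetical names.

```python
import queue
import threading

STOP = object()  # hypothetical sentinel meaning "no more batches"


def consume(q, train_step):
    """Pull batches until the sentinel arrives; return the number trained on."""
    n_batches = 0
    while True:
        batch = q.get()  # blocks until a producer puts something
        if batch is STOP:
            break
        train_step(batch)
        n_batches += 1
    return n_batches


# Local demo: a plain queue.Queue stands in for worker.tensorflow_queue
q = queue.Queue()
seen = []  # train_step here just records each batch
result = {}
consumer = threading.Thread(target=lambda: result.update(n=consume(q, seen.append)))
consumer.start()
for batch in ([1, 2], [3, 4], [5]):
    q.put(batch)
q.put(STOP)  # the producer signals the end of the data
consumer.join()
print(result['n'], seen)  # 3 [[1, 2], [3, 4], [5]]
```

If several producers feed one queue, the consumer needs to know how many end markers to expect. Otherwise the first `STOP` would end training before the other producers finish.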

Then push blocks of data, such as NumPy arrays or pandas DataFrames, onto these queues:

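import dask.dataframe as dd
from dask.distributed import get_worker

def dump_batch(batch):
    # Runs on a Dask worker: push one partition onto that worker's queue
    worker = get_worker()
    worker.tensorflow_queue.put(batch)

df = dd.read_csv('hdfs:///path/to/*.csv')
# clean up dataframe as necessary
partitions = df.to_delayed()  # delayed pandas dataframes, one per partition
futures = client.compute(partitions)  # materialize the partitions on the cluster
# Send each partition to a TensorFlow worker's queue (a side effect, so pure=False)
done = client.map(dump_batch, futures, workers=dask_spec['worker'], pure=False)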
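Unless a submission is restricted with `workers=`, the Dask scheduler decides which worker, and therefore which `tensorflow_queue`, receives each partition. You may want every TensorFlow worker to get an even share. One option, which is our assumption and not something dask-tensorflow provides, is to deal the partitions out round-robin over `dask_spec['worker']`. The planning step is plain Python. `assign_round_robin` is a hypothetical helper, and the addresses and partitions are placeholders.

```python
from collections import defaultdict
from itertools import cycle


def assign_round_robin(partitions, addresses):
    """Deal partitions out to worker addresses in turn (hypothetical helper)."""
    plan = defaultdict(list)
    # zip stops when the partitions run out; cycle repeats the addresses
    for address, partition in zip(cycle(addresses), partitions):
        plan[address].append(partition)
    return dict(plan)


# Made-up Dask worker addresses and string stand-ins for the partitions
addresses = ['tcp://10.0.0.1:40001', 'tcp://10.0.0.2:40001']
plan = assign_round_robin(['p0', 'p1', 'p2', 'p3', 'p4'], addresses)
print(plan)
# {'tcp://10.0.0.1:40001': ['p0', 'p2', 'p4'], 'tcp://10.0.0.2:40001': ['p1', 'p3']}
```

With a real client, each planned partition could then be submitted as `client.submit(dump_batch, partition, workers=address, pure=False)`, where `partition` is a future for that partition, for example one obtained from `client.compute`.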


Download files


Files for dask-tensorflow, version 0.0.2

Filename                                      Size     File type   Python version
dask_tensorflow-0.0.2-py2.py3-none-any.whl    5.3 kB   Wheel       py2.py3
dask-tensorflow-0.0.2.tar.gz                  4.8 kB   Source      None
