
Interactions between Dask and TensorFlow


Start TensorFlow clusters from Dask

Example

Given a Dask cluster

from dask.distributed import Client
client = Client('scheduler-address:8786')
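For local experimentation one can instead start an in-process cluster; a minimal sketch (LocalCluster ships with dask.distributed, and six workers here is an arbitrary choice leaving room for the ps=2, worker=4 split below):

from dask.distributed import Client, LocalCluster

# Spin up an in-process scheduler and workers for testing;
# six workers accommodates the ps=2, worker=4 split used below
cluster = LocalCluster(n_workers=6)
client = Client(cluster)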

Get a TensorFlow cluster, specifying groups by name

from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=2, worker=4)

>>> tf_spec
{'worker': ['192.168.1.100:2222', '192.168.1.101:2222',
            '192.168.1.102:2222', '192.168.1.103:2222'],
 'ps': ['192.168.1.104:2222', '192.168.1.105:2222']}
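The returned tf_spec already has the shape that tf.train.ClusterSpec accepts (a dictionary mapping job names to address lists), so as a sketch, a task can build a cluster spec directly from it:

import tensorflow as tf

# tf_spec maps job names to address lists, the same shape
# that tf.train.ClusterSpec accepts
cluster_spec = tf.train.ClusterSpec(tf_spec)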

This creates a tensorflow.train.Server and a queue for data transfer on each Dask worker. These are accessible directly as the tensorflow_server and tensorflow_queue attributes on the workers.
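As a quick sanity check, one can peek at these attributes from the client. A small sketch using client.run, which executes a function on every worker; the .target attribute of a tensorflow.train.Server is its gRPC endpoint:

from dask.distributed import get_worker

def server_target():
    # Each worker now holds a tensorflow.train.Server;
    # .target is the gRPC endpoint a tf.Session can attach to
    return get_worker().tensorflow_server.target

client.run(server_target)  # {worker-address: 'grpc://...', ...}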

More Complex Workflow

Typically we then set up long-running Dask tasks that get these servers and participate in general TensorFlow computations.

from dask.distributed import worker_client, get_worker

def ps_function():
    # worker_client() secedes this long-running task from the worker's
    # thread pool; get_worker() exposes the attributes that
    # start_tensorflow attached
    with worker_client():
        tf_server = get_worker().tensorflow_server
        tf_server.join()

ps_tasks = [client.submit(ps_function, workers=worker, pure=False)
            for worker in dask_spec['ps']]

def worker_function():
    with worker_client():
        tf_server = get_worker().tensorflow_server

        # ... use tensorflow as desired ...

worker_tasks = [client.submit(worker_function, workers=worker, pure=False)
                for worker in dask_spec['worker']]
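Here workers= pins each task to the Dask worker that hosts the matching TensorFlow server, and pure=False keeps Dask from deduplicating the otherwise-identical calls, so exactly one task runs per address.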

One simple and flexible approach is to have these functions block on queues and feed them data from Dask arrays, dataframes, etc.

def worker_function():
    with worker_client():
        worker = get_worker()
        tf_server = worker.tensorflow_server
        queue = worker.tensorflow_queue

        while not stopping_condition():  # user-defined stopping rule
            batch = queue.get()
            # train with batch
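To make this concrete, here is a minimal sketch of such a loop against the TensorFlow 1.x graph API (the era of this package). The averaging "model", the float-batch assumption, and the num_batches stopping rule are placeholders for real training code, not part of dask-tensorflow:

import numpy as np
import tensorflow as tf
from dask.distributed import worker_client, get_worker

def worker_function(num_batches=100):  # hypothetical stopping rule
    with worker_client():
        worker = get_worker()
        tf_server = worker.tensorflow_server
        queue = worker.tensorflow_queue

        # Stand-in "model": average a 1-D float batch; real code
        # would build its training graph here instead
        x = tf.placeholder(tf.float32, shape=[None])
        mean = tf.reduce_mean(x)

        # Attach a session to the TensorFlow server on this worker
        with tf.Session(tf_server.target) as sess:
            for _ in range(num_batches):
                batch = queue.get()  # blocks until dump_batch provides data
                values = np.asarray(batch, dtype='float32').ravel()
                sess.run(mean, feed_dict={x: values})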

We can then dump blocks of numpy arrays and pandas dataframes onto these queues

from dask.distributed import get_worker
def dump_batch(batch):
    worker = get_worker()
    worker.tensorflow_queue.put(batch)


import dask.dataframe as dd
df = dd.read_csv('hdfs:///path/to/*.csv')
# clean up dataframe as necessary
partitions = df.to_delayed()  # delayed pandas dataframes
futures = client.map(dump_batch, partitions)  # hold a reference; released futures may be cancelled
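Note that each queue is local to its worker, so dump_batch feeds whichever worker a partition happens to land on. If the parameter servers should not receive data, one can restrict placement with client.map(dump_batch, partitions, workers=dask_spec['worker']).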
