
Interactions between Dask and TensorFlow


Start TensorFlow clusters from Dask

Example

Given a Dask cluster

from dask.distributed import Client
client = Client('scheduler-address:8786')

Get a TensorFlow cluster, specifying groups by name

from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=2, worker=4)

>>> tf_spec
{'worker': ['192.168.1.100:2222', '192.168.1.101:2222',
            '192.168.1.102:2222', '192.168.1.103:2222'],
 'ps': ['192.168.1.104:2222', '192.168.1.105:2222']}
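For reference, dask_spec groups the Dask workers' own addresses under the same names. A sketch of its shape (the addresses here are illustrative, not real output):

>>> dask_spec
{'worker': ['tcp://192.168.1.100:33001', 'tcp://192.168.1.101:33002',
            'tcp://192.168.1.102:33003', 'tcp://192.168.1.103:33004'],
 'ps': ['tcp://192.168.1.104:33005', 'tcp://192.168.1.105:33006']}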

This creates a tensorflow.train.Server on each Dask worker and sets up a Queue for data transfer on each worker. These are accessible directly as tensorflow_server and tensorflow_queue attributes on the workers.
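As a quick sanity check, client.run executes a function on every worker, so we can confirm each one now carries a server. A minimal sketch, assuming a TF 1.x-style tensorflow.train.Server:

from dask.distributed import get_worker

def check_tensorflow():
    # Report the role this worker's TensorFlow server was started with
    worker = get_worker()
    return worker.tensorflow_server.server_def.job_name   # 'ps' or 'worker'

client.run(check_tensorflow)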

More Complex Workflow

Typically we then set up long-running Dask tasks that get these servers and participate in general TensorFlow computations.

from dask.distributed import get_worker, worker_client

def ps_function():
    # Long-running task: worker_client secedes it from the worker's
    # thread pool so it does not block other work
    with worker_client():
        tf_server = get_worker().tensorflow_server
        tf_server.join()   # serve parameters forever

ps_tasks = [client.submit(ps_function, workers=worker, pure=False)
            for worker in dask_spec['ps']]

def worker_function():
    with worker_client():
        tf_server = get_worker().tensorflow_server

        # ... use tensorflow as desired ...

worker_tasks = [client.submit(worker_function, workers=worker, pure=False)
                for worker in dask_spec['worker']]
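As a concrete direction for the "use tensorflow as desired" step, here is a minimal sketch of the classic between-graph replication setup, meant to live inside worker_function (assuming TF 1.x; tf_spec is the dict returned by start_tensorflow above):

import tensorflow as tf

# Describe the full cluster and let TensorFlow place variables on the ps tasks
cluster = tf.train.ClusterSpec(tf_spec)
with tf.device(tf.train.replica_device_setter(cluster=cluster)):
    weights = tf.Variable(tf.zeros([10]))

# A session backed by this worker's in-process server
sess = tf.Session(tf_server.target)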

One simple and flexible approach is to have these functions block on queues while we feed the queues data from dask arrays, dataframes, and so on.

def worker_function():
    with worker_client():
        worker = get_worker()
        tf_server = worker.tensorflow_server
        queue = worker.tensorflow_queue

        while not stopping_condition():   # user-defined stopping criterion
            batch = queue.get()
            # train with batch
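Filling in the training step, a minimal sketch assuming TF 1.x graph mode and that each queued batch is a numpy array of shape (n, 10); both the model and the batch shape are illustrative assumptions:

import tensorflow as tf
from dask.distributed import get_worker, worker_client

def worker_function():
    with worker_client():
        worker = get_worker()
        tf_server = worker.tensorflow_server
        queue = worker.tensorflow_queue

        # A toy model: one weight vector, squared-error loss
        x = tf.placeholder(tf.float32, shape=[None, 10])
        w = tf.Variable(tf.zeros([10, 1]))
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - 1.0))
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

        # Run the session against this worker's in-process server
        with tf.Session(tf_server.target) as sess:
            sess.run(tf.global_variables_initializer())
            while not stopping_condition():   # user-defined, as above
                batch = queue.get()
                sess.run(train_op, feed_dict={x: batch})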

And then we dump blocks of numpy arrays or pandas dataframes into these queues

from dask.distributed import get_worker

def dump_batch(batch):
    # Runs on a Dask worker: push the batch onto that worker's local queue
    worker = get_worker()
    worker.tensorflow_queue.put(batch)


import dask.dataframe as dd
df = dd.read_csv('hdfs:///path/to/*.csv')
# clean up dataframe as necessary
partitions = df.to_delayed()            # delayed pandas dataframes
futures = client.compute(partitions)    # materialize each partition as a future
# send batches only to the hosts running worker_function, where they are consumed
client.map(dump_batch, futures, workers=dask_spec['worker'], pure=False)
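The same pattern works for dask arrays; a sketch with a hypothetical random array:

import dask.array as da

x = da.random.random((100000, 10), chunks=(1000, 10))
blocks = x.to_delayed().flatten().tolist()   # delayed numpy blocks
block_futures = client.compute(blocks)
client.map(dump_batch, block_futures, workers=dask_spec['worker'], pure=False)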
