
Data loading orchestrator for separating loading and training into different processes, sharing data one-to-many and connected via TCP.


tensorsocket

Share PyTorch tensors over ZMQ sockets
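At the byte level, sharing a batch over a socket means serializing it on one side and rebuilding it on the other. The stdlib-only sketch below illustrates that round trip with hypothetical helpers `send_batch` and `recv_batch`; it is not tensorsocket's implementation. ZMQ delivers whole messages on its own, so here a 4-byte length prefix on a raw `socket.socketpair()` stands in for that framing, and plain lists stand in for tensors.

```python
import pickle
import socket
import struct

# Hypothetical helpers, not part of tensorsocket. A stream socket has no
# message boundaries, so each payload gets a 4-byte big-endian length prefix.


def send_batch(sock: socket.socket, batch) -> None:
    payload = pickle.dumps(batch)
    sock.sendall(struct.pack("!I", len(payload)) + payload)


def _recv_exact(sock: socket.socket, n: int) -> bytes:
    # recv() may return fewer bytes than asked for, so loop until we have n.
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("socket closed mid-message")
        buf.extend(chunk)
    return bytes(buf)


def recv_batch(sock: socket.socket):
    (length,) = struct.unpack("!I", _recv_exact(sock, 4))
    return pickle.loads(_recv_exact(sock, length))


if __name__ == "__main__":
    producer_end, consumer_end = socket.socketpair()
    # Plain lists stand in for the (input, target) tensors of a real batch.
    send_batch(producer_end, ([1.0, 2.0], [0, 1]))
    print(recv_batch(consumer_end))  # ([1.0, 2.0], [0, 1])
    producer_end.close()
    consumer_end.close()
```

pickle is only appropriate between trusted processes, since unpickling untrusted data can execute arbitrary code.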

Installation

Install tensorsocket as a module. It should be installed within a project such as RAD/data-sharing.

From source

From the root of the tensorsocket directory, install it with pip:

$ pip install .

From PyPI

$ pip install tensorsocket

Usage

tensorsocket works by exposing batches of data, represented as PyTorch tensors, on sockets that training processes can connect to. This reduces redundant data loading during collocated tasks such as hyper-parameter tuning, where several training jobs would otherwise each load the same data. Training with tensorsocket follows a producer-consumer model: the producer wraps an arbitrary data loader object, as the example code below shows, and one iterates over the producer in the same nested epoch-batch loops used to iterate over a data loader.
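The one-to-many idea can be sketched in a single process with threads and queues. `ToyBroadcastProducer` and `consume` are hypothetical names, and the queues stand in for sockets; the sketch only shows that each batch is loaded once and handed to every consumer, while the producer is driven by the same epoch-batch loop as a data loader.

```python
import queue
import threading


class ToyBroadcastProducer:
    """Hypothetical in-process stand-in for a producer: iterates a data
    loader once per epoch and hands every batch to all consumers."""

    def __init__(self, data_loader, num_consumers):
        self.data_loader = data_loader
        # One bounded queue per consumer; the bound gives simple back-pressure.
        self.queues = [queue.Queue(maxsize=4) for _ in range(num_consumers)]

    def __iter__(self):
        for batch in self.data_loader:  # each batch is loaded only once...
            for q in self.queues:       # ...and fanned out to every consumer
                q.put(batch)
            yield batch

    def close(self):
        for q in self.queues:
            q.put(None)  # sentinel: no more batches


def consume(q, seen):
    # Collect batches until the producer signals the end with None.
    while (batch := q.get()) is not None:
        seen.append(batch)


if __name__ == "__main__":
    loader = [[0, 1], [2, 3], [4, 5]]  # stands in for a DataLoader
    producer = ToyBroadcastProducer(loader, num_consumers=3)
    results = [[] for _ in producer.queues]
    workers = [threading.Thread(target=consume, args=(q, r))
               for q, r in zip(producer.queues, results)]
    for w in workers:
        w.start()
    for _ in range(2):  # two epochs, as in the producer loop
        for _ in producer:
            pass
    producer.close()
    for w in workers:
        w.join()
    print([len(r) for r in results])  # [6, 6, 6]
```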

The use of tensorsocket relies on two classes, TensorProducer and TensorConsumer. The TensorProducer can be used as is; the TensorConsumer, however, needs to be embedded in a class that exposes the same interface as a PyTorch data loader, as in the SharedLoader example class below:

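# shared_data_loader.py

from tensorsocket import TensorConsumer


class SharedLoader:
    """Wraps a TensorConsumer so it can be iterated like a PyTorch data loader."""

    def __init__(self, port="5556", ack_port="5557"):
        # Use the same ports the TensorProducer was started with.
        self.consumer = TensorConsumer(port, ack_port)
        self.counter = 0

    def __iter__(self):
        # Reset the batch count at the start of each epoch.
        self.counter = 0
        return self

    def __next__(self):
        # Hand out at most len(self) batches per epoch.
        if self.counter < len(self):
            self.counter += 1
            return next(self.consumer)
        raise StopIteration

    def __len__(self):
        # Number of batches per epoch, as reported by the consumer.
        return len(self.consumer)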
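The wrapper's per-epoch behavior can be checked without a running producer by injecting a stand-in consumer. In this sketch `FakeConsumer` and `InjectedSharedLoader` are hypothetical: the fake only mimics the two operations SharedLoader relies on, `next()` and `len()`, and the loader repeats SharedLoader's logic with the consumer passed in rather than constructed from ports.

```python
class FakeConsumer:
    """Hypothetical stand-in for TensorConsumer: an endless stream of
    batches whose len() reports the number of batches per epoch."""

    def __init__(self, batches_per_epoch):
        self.batches_per_epoch = batches_per_epoch
        self.sent = 0

    def __next__(self):
        self.sent += 1
        return ("input", self.sent)  # placeholder batch

    def __len__(self):
        return self.batches_per_epoch


class InjectedSharedLoader:
    """Same iteration logic as SharedLoader, but the consumer is passed in."""

    def __init__(self, consumer):
        self.consumer = consumer
        self.counter = 0

    def __iter__(self):
        self.counter = 0  # restart the per-epoch count
        return self

    def __next__(self):
        if self.counter < len(self):
            self.counter += 1
            return next(self.consumer)
        raise StopIteration  # end of epoch; the stream itself continues

    def __len__(self):
        return len(self.consumer)


if __name__ == "__main__":
    loader = InjectedSharedLoader(FakeConsumer(batches_per_epoch=3))
    for epoch in range(2):
        print(epoch, [step for _, step in loader])
    # 0 [1, 2, 3]
    # 1 [4, 5, 6]
```

The counter only bounds each epoch; batches come from one continuous stream, so the second epoch picks up where the first stopped.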

Using the TensorProducer requires next to no implementation beyond the original data loader:

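# producer.py

from torch.utils.data import DataLoader
from tensorsocket import TensorProducer

# `dataset` and `epochs` are placeholders for your own dataset and epoch count.
data_loader = DataLoader(dataset)

producer = TensorProducer(data_loader, port="5556", ack_port="5557")

for _ in range(epochs):
    # Iterating the producer drives the data loader and exposes each batch to consumers.
    for _ in producer:
        pass
producer.join()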

With the SharedLoader class wrapping a TensorConsumer, a training script can easily be changed to fetch batches from the shared loader instead of the process-specific data loader that each collocated training job would otherwise create.

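# consumer.py (or train.py)
from ... import SharedLoader

...

if not use_shared_loader:
    # Process-specific loader: this training job loads its own data.
    data_loader = create_loader(...)
else:
    # Shared loader: batches come from the running TensorProducer.
    data_loader = SharedLoader()

...

for batch_idx, (inputs, target) in enumerate(data_loader):
    output = model(inputs)
    ...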
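One way to set `use_shared_loader` is a command-line flag. The `--use-shared-loader` flag and the `select_loader` helper are hypothetical; the loaders are built through factory callables so that only the selected one is created, which also lets the sketch run without tensorsocket or PyTorch installed.

```python
import argparse


def build_arg_parser():
    parser = argparse.ArgumentParser(description="training script (sketch)")
    # Hypothetical flag: the README only shows a use_shared_loader boolean.
    parser.add_argument(
        "--use-shared-loader",
        action="store_true",
        help="read batches from a running producer instead of a local loader",
    )
    return parser


def select_loader(args, create_local, create_shared):
    """Build only the requested loader. Factories are passed in so this
    sketch does not import tensorsocket or torch."""
    return create_shared() if args.use_shared_loader else create_local()


if __name__ == "__main__":
    args = build_arg_parser().parse_args(["--use-shared-loader"])
    loader = select_loader(args, create_local=lambda: "local",
                           create_shared=lambda: "shared")
    print(loader)  # shared
```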

Features



Download files


Source Distribution

tensorsocket-0.0.3.tar.gz (16.4 kB)

Built Distribution

tensorsocket-0.0.3-py3-none-any.whl (8.8 kB)

File details

Details for the file tensorsocket-0.0.3.tar.gz.

File metadata

  • Download URL: tensorsocket-0.0.3.tar.gz
  • Upload date:
  • Size: 16.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.31.0

File hashes

Hashes for tensorsocket-0.0.3.tar.gz
Algorithm Hash digest
SHA256 50c61ea131bd7515bd22ac461b492c82daa17c58140aa5a511233504566de791
MD5 8d357448a4ae4393df8ed551a653e519
BLAKE2b-256 127c9ca7171bedbe1f745658f8c6844f3a5a33a2767c2c74a0eca63fe3920570


File details

Details for the file tensorsocket-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: tensorsocket-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.31.0

File hashes

Hashes for tensorsocket-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 2c8bd23eeea0a023fc569e8706d84687d9e7d85ddb9696200367f0f7da49dd8c
MD5 f38cc7feb56d4873593c21ff5051ad43
BLAKE2b-256 a35440c35ecc7bd5499181d71f4b357d390992499c7b77b7f36c91458bab0447

