
Dataclasses for TensorFlow


tf-dataclass

Supports Python dataclass containers as inputs and outputs of callable TensorFlow 2 graphs.

Install

Make sure that tensorflow>=2.0.0 or tensorflow-gpu>=2.0.0 is installed.

$ pip install tf-dataclass

Why

The TensorFlow 2 autograph function (@tf.function) supports only nested structures of Python tuples as inputs and outputs (outputs can also be Python dictionaries). This is inconvenient once we go beyond small hello-world cases, because we have to juggle unstructured armfuls of tensors. This small package fills that gap by letting @tf.function-decorated functions take and return Pythonic dataclass instances.

Examples of usage:

1. Sequential features

import tensorflow as tf
import tf_dataclass

# Batch of sequential features of different length
@tf_dataclass.dataclass
class Sequential:
    feature: tf.Tensor  # shape = [batch, length, channels],    dtype = tf.float32
    length: tf.Tensor   # shape = [batch],                      dtype = tf.int32

# Initialize a batch of two sequences of lengths 6 and 4
input = Sequential(
    feature = tf.random.normal(shape=[2, 6, 3]),
    length = tf.constant([6, 4], dtype=tf.int32),
)
    
# Define a convolution operator with a stride such that length -> length // stride
@tf_dataclass.function
def convolution(input: Sequential, filters: tf.Tensor, stride: int) -> Sequential:
    return Sequential(
        # tf.nn.conv1d requires an explicit padding argument
        feature = tf.nn.conv1d(input.feature, filters, stride, padding="SAME"),
        length = tf.math.floordiv(input.length, stride),
    )

# Output is an instance of Sequential with lengths 3 and 2 due to convolution stride = 2
output = convolution(
    input = input,
    filters = tf.random.normal(shape=[1, 3, 7]),
    stride = 2,
)
assert isinstance(output, Sequential)
print(output.length) # -> tf.Tensor([3 2], shape=(2,), dtype=int32)

2. Minibatch as a data transfer object:

import tensorflow as tf
import tf_dataclass

@tf_dataclass.dataclass
class DataBatch:
    image: tf.Tensor            # shape = [batch, height, width, channels], dtype = tf.float32
    label: tf.Tensor            # shape = [batch],                          dtype = tf.int32
    image_file_path: tf.Tensor  # shape = [batch],                          dtype = tf.string
    dataset_name: tf.Tensor     # shape = [batch],                          dtype = tf.string
    ...
    
@tf_dataclass.function
def train_step(input: DataBatch) -> None:
    ...

3. Containerized outputs:

import tensorflow as tf
import tf_dataclass

@tf_dataclass.dataclass
class ModelOutput:
    loss_value: tf.Tensor   # shape = [batch],  dtype = tf.float32
    label: tf.Tensor        # shape = [batch],  dtype = tf.int32
    prediction: tf.Tensor   # shape = [batch],  dtype = tf.int32
    ...

    @property
    def mean_loss(self) -> tf.Tensor: # shape = [],  dtype = tf.float32
        return tf.reduce_mean(self.loss_value)

    @property
    def num_true_predictions(self) -> tf.Tensor: # shape = [],  dtype = tf.int32
        return tf.reduce_sum(tf.cast(self.label == self.prediction, dtype=tf.int32))

    @property
    def num_false_predictions(self) -> tf.Tensor: # shape = [],  dtype = tf.int32
        return tf.reduce_sum(tf.cast(self.label != self.prediction, dtype=tf.int32))

    ...

@tf_dataclass.function
def get_loss(...) -> ModelOutput:
    ...

Such containers can be merged across datasets and workers.
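As a minimal sketch of what such merging might look like, assuming the containers behave like standard dataclasses (so dataclasses.fields works on them), a hypothetical helper merge_containers can combine two instances of the same class field by field with a caller-supplied function. With tensors that function would typically be a concatenation along the batch axis; plain lists stand in for tensors here so the sketch runs without TensorFlow. merge_containers is not part of the tf_dataclass API.

```python
import dataclasses
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def merge_containers(a: T, b: T, combine: Callable[[Any, Any], Any]) -> T:
    """Hypothetical helper: merge two instances of the same dataclass field by field.

    Nested dataclass fields are merged recursively; every other field is
    handed to `combine` (e.g. a concatenation along the batch axis).
    """
    if type(a) is not type(b):
        raise TypeError(f"cannot merge {type(a).__name__} with {type(b).__name__}")
    merged = {}
    for field in dataclasses.fields(a):
        x, y = getattr(a, field.name), getattr(b, field.name)
        if dataclasses.is_dataclass(x):
            merged[field.name] = merge_containers(x, y, combine)
        else:
            merged[field.name] = combine(x, y)
    return type(a)(**merged)


# Plain-list stand-in for a tensor container such as ModelOutput.
@dataclasses.dataclass
class Output:
    loss_value: list
    label: list
    prediction: list

    @property
    def num_true_predictions(self) -> int:
        return sum(y_true == y_pred for y_true, y_pred in zip(self.label, self.prediction))


worker_1 = Output(loss_value=[0.5, 0.25], label=[1, 0], prediction=[1, 1])
worker_2 = Output(loss_value=[0.75], label=[2], prediction=[2])

# With TensorFlow tensors: combine=lambda x, y: tf.concat([x, y], axis=0)
total = merge_containers(worker_1, worker_2, combine=lambda x, y: x + y)
print(total.label)                 # [1, 0, 2]
print(total.num_true_predictions)  # 2
```

Because derived quantities are properties, they are recomputed over the merged batch instead of being averaged per worker.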

4. Easy TensorFlow shape and dtype runtime verification:

import tensorflow as tf
import tf_dataclass

@tf_dataclass.dataclass
class Sequential:
    feature: tf.Tensor  # shape = [batch, length, channels], dtype = tf.float32
    length: tf.Tensor   # shape = [batch],                   dtype = tf.int32

    def __post_init__(self):
        # Verify feature
        assert self.feature.dtype == tf.float32
        assert len(self.feature.shape) == 3
        
        # Verify length
        assert self.length.dtype == tf.int32
        assert len(self.length.shape) == 1
        
        # Verify batch size
        # Static-shape check: reliable only in eager mode, but it adds no ops to the graph
        assert self.feature.shape[0] == self.length.shape[0]

    @property
    def batch_size(self) -> tf.Tensor: # shape = [], dtype = tf.int32
        return tf.shape(self.feature)[0]

Other features:

  • Supports hierarchical composition of dataclasses.
  • Supports inheritance, including multiple inheritance (for free from the standard dataclasses module).
  • Syntax highlighting, autocomplete, and refactoring in your IDE.
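Since inheritance comes from the standard dataclasses module, its usual rules presumably carry over, assuming tf_dataclass.dataclass keeps the standard field semantics. One detail worth knowing when containers are flattened field by field: with multiple inheritance, base-class fields are collected in reverse MRO order, followed by the subclass's own fields. The sketch below uses plain dataclasses.dataclass with list fields standing in for tensors; all class names are hypothetical.

```python
import dataclasses
from dataclasses import dataclass


@dataclass
class HasImage:
    image: list


@dataclass
class HasLabel:
    label: list


# Multiple inheritance: base fields are gathered in reverse MRO order
# (HasLabel before HasImage), then the subclass's own fields follow.
@dataclass
class LabeledImageBatch(HasImage, HasLabel):
    weight: list


# Hierarchical composition: a field that is itself a dataclass.
@dataclass
class TrainBatch:
    data: LabeledImageBatch
    step: int


print([f.name for f in dataclasses.fields(LabeledImageBatch)])  # ['label', 'image', 'weight']
batch = TrainBatch(data=LabeledImageBatch(label=[1], image=[[0.0]], weight=[1.0]), step=0)
```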

Usage

  1. Import dataclass and function from tf_dataclass
from tf_dataclass import dataclass, function
  2. It is mandatory to use return type hints for the function decorated with @function. For example,
import tensorflow as tf
from typing import Tuple

@dataclass
class MyDataclass:
    ...

@function
def my_func(...) -> Tuple[tf.Tensor, MyDataclass]:
    ...
    return some_tensor, my_dataclass_instance
  3. Type hints for the arguments are optional but recommended.

  4. Positional arguments are not currently supported:

For example, given

@function
def my_graph_func(x: ..., y: ...) -> ... :
    ...

call

my_graph_func(x=x, y=y)

and not

my_graph_func(x, y)
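Positional-argument support is listed under Future plans. One possible approach, sketched here as an assumption and not as how tf_dataclass works, is to bind each call to the original Python function's signature with inspect.signature(...).bind and then forward every argument by keyword to a wrapper that only accepts keywords. The names with_positional_args and keyword_only are hypothetical.

```python
import functools
import inspect
from typing import Any, Callable


def with_positional_args(python_fn: Callable[..., Any],
                         keyword_only_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical adapter: accept positional calls, forward them by keyword.

    Assumes python_fn has no positional-only, *args or **kwargs parameters.
    """
    signature = inspect.signature(python_fn)

    @functools.wraps(python_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = signature.bind(*args, **kwargs)  # raises TypeError for an invalid call
        bound.apply_defaults()                   # pass default values explicitly as keywords
        return keyword_only_fn(**bound.arguments)

    return wrapper


def my_graph_func(x, y, scale=1):
    return (x + y) * scale


# Stand-in for a wrapper that, like @function today, accepts keyword arguments only.
def keyword_only(**kwargs):
    return my_graph_func(**kwargs)


call = with_positional_args(my_graph_func, keyword_only)
print(call(1, 2))             # 3
print(call(1, y=2, scale=3))  # 9
```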

Known Problems

  1. IDE autocomplete is currently not well supported, for example in PyCharm. Workaround: use the following imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from tf_dataclass import dataclass

in each *.py file where dataclass is used.

Under the hood

Dataclasses and their nested structures are simply converted into nested Python tuples and back. The given function is wrapped so that all of its inputs and outputs are nested tuples, and then @tf.function is applied. Afterwards, the resulting graph function is wrapped back into dataclass form. At graph-creation time, the type hints serve as templates for packing and unpacking the dataclass arguments.
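To make this description concrete, here is a minimal pure-Python sketch of the round trip. It assumes plain dataclasses and fixed-length Tuple[...] hints; flatten, unflatten and tuple_wrapped are hypothetical names, not the package's internals, and compile_fn is a stand-in for tf.function so the sketch runs without TensorFlow. It also shows why the return type hint is mandatory: without it there is no template for rebuilding the outputs.

```python
import dataclasses
import typing
from typing import Any, Callable, Tuple


def flatten(value: Any) -> Any:
    """Recursively turn dataclass instances (and tuples of them) into nested tuples."""
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        return tuple(flatten(getattr(value, f.name)) for f in dataclasses.fields(value))
    if isinstance(value, tuple):
        return tuple(flatten(v) for v in value)
    return value  # a leaf such as a tf.Tensor


def unflatten(template: Any, value: Any) -> Any:
    """Rebuild dataclass instances from nested tuples, using a type hint as the template."""
    if dataclasses.is_dataclass(template):
        hints = typing.get_type_hints(template)
        return template(**{
            f.name: unflatten(hints[f.name], v)
            for f, v in zip(dataclasses.fields(template), value)
        })
    if typing.get_origin(template) is tuple:  # fixed-length Tuple[...] hints only
        return tuple(unflatten(t, v) for t, v in zip(typing.get_args(template), value))
    return value  # leaf or unannotated argument: passed through unchanged


def tuple_wrapped(fn: Callable[..., Any], compile_fn: Callable = lambda f: f) -> Callable[..., Any]:
    """Wrap fn so that compile_fn (a stand-in for tf.function) only sees nested tuples."""
    hints = typing.get_type_hints(fn)
    return_hint = hints.pop("return")  # KeyError if the return type hint is missing

    def tuple_fn(**flat_kwargs: Any) -> Any:
        kwargs = {k: unflatten(hints.get(k, Any), v) for k, v in flat_kwargs.items()}
        return flatten(fn(**kwargs))

    graph_fn = compile_fn(tuple_fn)

    def wrapper(**kwargs: Any) -> Any:  # keyword arguments only
        flat_out = graph_fn(**{k: flatten(v) for k, v in kwargs.items()})
        return unflatten(return_hint, flat_out)

    return wrapper


@dataclasses.dataclass
class Pair:
    left: list
    right: list


def swap(pair: Pair, tag: str) -> Tuple[str, Pair]:
    return tag, Pair(left=pair.right, right=pair.left)


swapped = tuple_wrapped(swap)  # with TensorFlow, presumably tuple_wrapped(swap, tf.function)
print(flatten(Pair(left=[1], right=[2])))                # ([1], [2])
print(swapped(pair=Pair(left=[1], right=[2]), tag="t"))  # ('t', Pair(left=[2], right=[1]))
```

Keeping the wrapper keyword-only mirrors the current restriction on positional arguments: argument names are what connect each flattened input back to its type hint.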

Future plans

  1. Support tf.cond, tf.case, tf.switch_case, tf.while_loop, tf.Optional, and tf.data.Iterator.
  2. Support positional arguments.
  3. Conversion to tf.nest structures.


Source Distribution

tf-dataclass-0.1.3.tar.gz (12.2 kB)

