
# tf-inputs

This package provides easy-to-write input pipelines for TensorFlow that automatically
integrate with the `tf.data` API.

## Overview

A quick, full example of a training script with an optimized input pipeline:

```python
import tensorflow as tf
import tf_inputs as tfi

# Recursively find all files inside the directory and parse them with `parse_fn`.
inputs = tfi.Input.from_directory(
    "/path/to/data_dir", parse_fn=tf.image.decode_png, batch_size=16,
)

# Supposing `my_model_fn` builds the computational graph of some image model.
# Built Keras style -- calling the instance returns the iterator input tensor,
# and until this is done, no ops are added to the computational graph.
train_op, outputs = my_model_fn(inputs())

# Training loop.
with tf.Session().as_default():
    inputs.initialize()  # passing an explicit session is fine too
    while True:
        try:
            inputs.run(train_op)
        except tf.errors.OutOfRangeError:
            break
```

You may still use `sess.run` if you prefer, though `inputs.run` wraps it to automatically
handle `feed_dict` passing for TensorFlow's feedable iterators and placeholders. If you
need to pass an explicit session, you may also use `inputs.run(..., session=sess)`.
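What "automatically handling `feed_dict` passing" means can be pictured with a small pure-Python sketch. The class and attribute names below are hypothetical illustrations, not the actual `tf-inputs` internals: the wrapper merges its own feeds (such as a feedable-iterator handle) into whatever `feed_dict` the caller supplies, then delegates to the session.

```python
class FeedDictInjector:
    """Toy stand-in for an object that wraps a session's `run` and silently
    merges its own feeds (e.g. a feedable-iterator handle) into `feed_dict`."""

    def __init__(self, session, own_feeds):
        self.session = session
        self.own_feeds = own_feeds  # e.g. {handle_placeholder: handle_string}

    def run(self, fetches, feed_dict=None, **kwargs):
        merged = dict(self.own_feeds)
        merged.update(feed_dict or {})  # the caller's feeds take precedence
        return self.session.run(fetches, feed_dict=merged, **kwargs)


class FakeSession:
    """Records the feed_dict it receives, for demonstration only."""

    def run(self, fetches, feed_dict=None):
        self.last_feed = feed_dict
        return fetches


sess = FakeSession()
injector = FeedDictInjector(sess, {"iterator_handle": "train-handle"})
injector.run("loss_op", feed_dict={"x": 1})
# sess.last_feed now contains both the injected handle and the caller's feed.
```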

## Installation

`tf-inputs` supports TensorFlow 1.13 and Python 3.7, and depends on no other third-party
Python modules. Make sure to have your favorite TensorFlow binary installed (i.e., `tensorflow`,
`tensorflow-gpu`, or your own custom wheel built from source) prior to installing:

```shell
pip install tf-inputs
```

## Switch between training and validation datasets

This can get quite messy with the `tf.data` API. `tf-inputs` handles it this way:

```python
train_inputs = tfi.Input.from_directory("/data/training", **options)
valid_inputs = tfi.Input.from_directory("/data/validation", **options)
inputs = tfi.TrainValidInput(train_inputs, valid_inputs)

with tf.Session().as_default():
    inputs.initialize()
    inputs.run([train_op, output_op])  # receives a training batch
    inputs.run(output_op, valid=True)  # receives a validation batch
```

If you do not have separate datasets for training and validation, you may use:

```python
inputs = tfi.TrainValidSplit(inputs, num_valid_examples)
```
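Conceptually, this holds out a fixed number of examples from a single dataset for validation. The sketch below is a pure-Python illustration of that idea only; the real `TrainValidSplit` operates on `tf.data` pipelines, and which examples it holds out is an implementation detail:

```python
def train_valid_split(examples, num_valid_examples):
    """Hold out the first `num_valid_examples` items for validation.
    Pure-Python illustration, not the tf-inputs implementation."""
    valid = examples[:num_valid_examples]
    train = examples[num_valid_examples:]
    return train, valid


train, valid = train_valid_split(list(range(10)), num_valid_examples=2)
# train keeps the remaining 8 examples; valid gets the first 2.
```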

## Methods to read data

`tf-inputs` supports a variety of ways to read data besides `Input.from_directory`:

```python
# Provide the file paths yourself:
inputs = tfi.Input.from_file_paths(["data/file1.txt", "data/file2.txt"], **options)

# Provide the `tf.data.Dataset` instance yourself (yielding single input elements):
inputs = tfi.Input.from_dataset(dataset, **options)

# Same as above, but preventing any graph building a priori:
inputs = tfi.Input.from_dataset_fn(get_dataset, **options)

# Lowest level: subclass `tfi.Input` and override `read_data` to return the dataset:
class MyInput(tfi.Input):
    def __init__(self, my_arg, my_kwarg="foo", **kwargs):
        self.my_arg = my_arg
        self.my_kwarg = my_kwarg
        super().__init__(**kwargs)

    def read_data(self):
        ...  # build and return a `tf.data.Dataset` here
```
Usually there is no need to use the lower level methods. One common case is when the
user wishes to yield `(input, label)` pairs that live in different files. You may
use `tfi.Zip` for this, as long as the numbers of elements match:

```python
# Multi-task learning training input pipeline.
sentences_en = tfi.Input.from_directory("data/training/english")
sentences_fr = tfi.Input.from_directory("data/training/french")
sentiment_labels = tfi.Input.from_directory("data/training/labels")

inputs = tfi.Zip(sentences_en, sentences_fr, sentiment_labels)

def my_model(inputs, training=True):
    if training:
        x, y1, y2 = inputs
```
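Why the element counts must match is easy to picture with Python's built-in `zip` (plain Python, unrelated to `tfi.Zip`'s internals): zipping silently truncates to the shortest input, so a mismatch would drop examples without any error.

```python
english = ["good morning", "thank you", "see you"]
french = ["bonjour", "merci"]  # one translation missing

pairs = list(zip(english, french))
# zip stops at the shorter input, so the third English sentence is
# silently dropped -- mismatched datasets are a quiet bug hazard.
```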

## Training over multiple epochs

Just catch the `tf.errors.OutOfRangeError` and restart the iterator:

```python
# Inside a `tf.Session`:
epoch = 0
while epoch < max_epochs:
    try:
        inputs.run(train_op)
    except tf.errors.OutOfRangeError:
        inputs.initialize()  # restart the iterator
        epoch += 1
```
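The control flow is the same as catching `StopIteration` on a plain Python iterator. Here is that analogy in pure Python (illustrative only, no TensorFlow involved):

```python
data = [1, 2, 3]
max_epochs = 2
epoch = 0
seen = []

iterator = iter(data)
while epoch < max_epochs:
    try:
        seen.append(next(iterator))
    except StopIteration:
        iterator = iter(data)  # "restart the iterator"
        epoch += 1
# After the loop, `data` has been consumed exactly `max_epochs` times.
```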

## Multiple elements per file

Just set the `flatten=True` flag with `Input.from_directory` or `Input.from_file_paths`:

```python
# Inputs split by an arbitrary delimiter in a text file:
inputs = tfi.Input.from_directory(
    'path/to/text/files', batch_size=32, flatten=True,
    parse_fn=lambda x: tf.string_split(x, delimiter='\n\n'),
)
```
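What `flatten=True` accomplishes can be pictured in pure Python: each file parses into several records, and the nested per-file structure is flattened into one stream of elements. The file names and contents below are made up for illustration:

```python
files = {
    "a.txt": "first record\n\nsecond record",
    "b.txt": "third record",
}

# Each file yields multiple elements (the `parse_fn` step), and the nested
# lists are flattened into one stream (the `flatten=True` step).
records = []
for contents in files.values():
    records.extend(contents.split("\n\n"))
```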
