
# tf-inputs

This package provides easy-to-write input pipelines for TensorFlow that automatically
integrate with the `tf.data` API.

## Overview

A quick, full example of a training script with an optimized input pipeline:

```python
import tensorflow as tf
import tf_inputs as tfi

# Recursively find all files inside the directory and parse them with `parse_fn`.
inputs = tfi.Input.from_directory(
    "/path/to/data_dir", parse_fn=tf.image.decode_png, batch_size=16,
    num_parallel_calls=4,
)

# Suppose `my_model_fn` builds the computational graph of some image model.
# Built Keras-style: calling the instance returns the iterator input tensor,
# and until this is done, no ops are added to the computational graph.
train_op, outputs = my_model_fn(inputs())

# Training loop.
with tf.Session().as_default():
    inputs.initialize()  # or `sess.run(inputs.initializer)` is fine too
    while True:
        try:
            inputs.run(train_op)  # replace `sess.run` with `inputs.run`
        except tf.errors.OutOfRangeError:
            break
```

You may still use `sess.run` if you prefer; `inputs.run` simply handles the
`feed_dict` passing that TensorFlow's feedable iterators and placeholders require. If
you need to pass an explicit session, use `inputs.run(ops, session=sess)`.
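As a rough illustration, a wrapper like `inputs.run` can be thought of as merging the pipeline's required feeds (such as a feedable-iterator handle) into any user-supplied `feed_dict` before delegating to the session. This is a hypothetical plain-Python sketch, not the library's actual internals:

```python
# Hypothetical sketch of an `inputs.run`-style wrapper: merge the feeds the
# pipeline always needs into the user's feed_dict, then delegate.
def make_run(session_run, pipeline_feeds):
    def run(fetches, feed_dict=None):
        merged = dict(pipeline_feeds)   # feeds the pipeline always needs
        if feed_dict:
            merged.update(feed_dict)    # user-supplied feeds take precedence
        return session_run(fetches, merged)
    return run
```

The point is that callers never have to remember the iterator handle themselves; it is injected on every call.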

## Installation

`tf-inputs` supports TensorFlow 1.13 and Python 3.7, and depends on no other
third-party Python modules. Make sure your favorite TensorFlow binary is installed
(e.g., `tensorflow`, `tensorflow-gpu`, or your own custom wheel built from source)
prior to installing `tf-inputs`.

```
pip install tf-inputs
```

## Switch between training and validation datasets

This can get quite messy with the raw `tf.data` API; see the
[documentation](https://www.tensorflow.org/guide/datasets#creating_an_iterator)
for details. `tf-inputs` handles it this way:

```python
train_inputs = tfi.Input.from_directory("/data/training", **options)
valid_inputs = tfi.Input.from_directory("/data/validation", **options)
inputs = tfi.TrainValidInput(train_inputs, valid_inputs)

...

with tf.Session().as_default():
    inputs.initialize()
    inputs.run([train_op, output_op])  # receives a training batch
    inputs.run(output_op, valid=True)  # receives a validation batch
```

If you do not have separate datasets for training and validation, you may use:

```python
inputs = tfi.TrainValidSplit(inputs, num_valid_examples)
```
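To make the split semantics concrete, here is an illustrative plain-Python sketch that holds out the first `num_valid_examples` elements for validation. This is only one plausible strategy; the split `tf-inputs` actually performs may differ:

```python
# Illustrative only: carve one dataset into validation and training parts by
# holding out the first `num_valid_examples` elements.
def train_valid_split(examples, num_valid_examples):
    valid = examples[:num_valid_examples]   # held-out validation examples
    train = examples[num_valid_examples:]   # everything else is training data
    return train, valid
```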

## Methods to read data

`tf-inputs` supports a variety of ways to read data besides `Input.from_directory`:

```python
# Provide the file paths yourself:
inputs = tfi.Input.from_file_paths(["data/file1.txt", "data/file2.txt"], **options)
```

```python
# Provide the `tf.data.Dataset` instance yourself (yielding single input elements):
inputs = tfi.Input.from_dataset(dataset, **options)
```

```python
# Same as above, but preventing any graph building a priori:
inputs = tfi.Input.from_dataset_fn(get_dataset, **options)
```

```python
# Lowest level: subclass `tfi.Input` and override `read_data` to return the dataset:
class MyInput(tfi.Input):
    def __init__(self, my_arg, my_kwarg="foo", **kwargs):
        super().__init__(**kwargs)
        self.my_arg = my_arg
        ...

    def read_data(self):
        return tf.data.Dataset.from_tensor_slices(list(range(self.my_arg)))
```
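The override contract is a classic template method: the base class owns batching and iteration, and subclasses only supply the raw data. A plain-Python stand-in (hypothetical names; the real base class builds a `tf.data.Dataset`, not lists) makes the pattern explicit:

```python
# Plain-Python stand-in for the `tfi.Input` override contract: the base class
# owns batching, subclasses only supply `read_data`.
class InputBase:
    def __init__(self, batch_size=1):
        self.batch_size = batch_size

    def read_data(self):
        raise NotImplementedError  # subclasses return the raw element stream

    def batches(self):
        data = list(self.read_data())
        return [data[i:i + self.batch_size]
                for i in range(0, len(data), self.batch_size)]

class RangeInput(InputBase):
    def __init__(self, n, **kwargs):
        super().__init__(**kwargs)
        self.n = n

    def read_data(self):
        return range(self.n)
```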

Usually there is no need for the lower-level methods. One common exception is when you
want to yield `(input, label)` pairs that live in different files. You may use
`tfi.Zip` for this, as long as the numbers of elements match:

```python
# Multi-task learning training input pipeline.
sentences_en = tfi.Input.from_directory("data/training/english")
sentences_fr = tfi.Input.from_directory("data/training/french")
sentiment_labels = tfi.Input.from_directory("data/training/labels")

inputs = tfi.Zip(sentences_en, sentences_fr, sentiment_labels)

def my_model(inputs, training=True):
    if training:
        x, y1, y2 = inputs
        ...
```
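The matching-length requirement matters because shorter streams would otherwise silently truncate the others. A hypothetical plain-Python sketch of that contract:

```python
# Sketch of a Zip-style contract in plain Python (hypothetical internals):
# refuse to pair up streams of unequal length instead of silently truncating.
def zip_inputs(*streams):
    lengths = {len(s) for s in streams}
    if len(lengths) > 1:
        raise ValueError("all Zip inputs must have the same number of elements")
    return list(zip(*streams))
```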

## Training over multiple epochs

Just catch the `tf.errors.OutOfRangeError` and restart the iterator:

```python
# Inside a `tf.Session`:
inputs.initialize()
epoch = 0
while epoch < max_epochs:
    try:
        inputs.run(train_op)
    except tf.errors.OutOfRangeError:
        inputs.initialize()
        epoch += 1
```
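The control flow of this loop can be checked in plain Python, with `StopIteration` standing in for `tf.errors.OutOfRangeError` and a fresh iterator standing in for re-initialization:

```python
# Pure-Python analogue of the multi-epoch loop above: on exhaustion,
# rebuild the iterator and count one completed epoch.
def run_epochs(make_iterator, max_epochs):
    steps = 0
    epoch = 0
    it = make_iterator()
    while epoch < max_epochs:
        try:
            next(it)   # stands in for `inputs.run(train_op)`
            steps += 1
        except StopIteration:
            it = make_iterator()   # stands in for `inputs.initialize()`
            epoch += 1
    return steps
```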

## Multiple elements per file

Just set the `flatten=True` flag with `Input.from_directory` or `Input.from_file_paths`:

```python
# Inputs split by an arbitrary delimiter in a text file:
inputs = tfi.Input.from_directory(
    'path/to/text/files', batch_size=32, flatten=True,
    parse_fn=lambda x: tf.string_split([x], delimiter='\n\n'),
)
```
```
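Conceptually, `flatten=True` turns the per-file lists of parsed elements into one continuous stream of examples. In plain Python (a hypothetical sketch, not the library's internals):

```python
# What `flatten=True` does conceptually: each file parses into several
# elements, and the per-file lists are flattened into a single stream.
def flatten_parsed_files(file_contents, delimiter="\n\n"):
    return [example
            for text in file_contents
            for example in text.split(delimiter)]
```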

