
Package to streamline reading and writing data in tfrecord files


easy_tfrecords

This package helps you read and write tfrecord files in an intuitive way that preserves the dtype and structure (shape) of your data.

Purpose:

The tfrecord format is a fast and powerful way to feed data to a TensorFlow model: TensorFlow's input pipeline can batch, shuffle and repeat the records across epochs with little extra code. The hard part of using tfrecord files is keeping the feature structure (keys, dtypes and shapes) consistent across the code that writes the records, the code that reads them and the code that fetches batches.

The easy_tfrecords module provides functions and classes that let you write to and read from tfrecord files in a straightforward, extensible way.

Features:

  • create tfrecord files
  • read from single or multiple tfrecord files
  • selectively read data from tfrecord files
  • examine the data structure of tfrecord files
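Examining a file starts with its container format, which is simple enough to walk without TensorFlow. The sketch below is not part of easy_tfrecords; it assumes the standard TFRecord framing, where each record is a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The helper names (crc32c, masked_crc, write_records, iter_records) are hypothetical, and payloads are treated as opaque bytes here, whereas TensorFlow writers usually store serialized tf.train.Example messages.

```python
import struct


def _make_crc32c_table(poly=0x82F63B78):
    """Lookup table for reflected CRC-32C (Castagnoli), the checksum TFRecord uses."""
    table = []
    for i in range(256):
        c = i
        for _ in range(8):
            c = (c >> 1) ^ poly if c & 1 else c >> 1
        table.append(c)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits, then add TFRecord's mask constant (mod 2**32)."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, payloads):
    """Write each bytes payload with length header and checksums."""
    with open(path, "wb") as f:
        for data in payloads:
            header = struct.pack("<Q", len(data))
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def iter_records(path):
    """Yield each payload in file order, verifying both checksums."""
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            (header_crc,) = struct.unpack("<I", f.read(4))
            if header_crc != masked_crc(header):
                raise ValueError("corrupt length header")
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if len(data) != length or data_crc != masked_crc(data):
                raise ValueError("corrupt record payload")
            yield data
```

Counting the records in a file, or checking it for corruption, is then a loop such as sum(1 for _ in iter_records(path)).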

Usage:

Writing

  • Load your data into Python however you normally would (Excel, pandas, CSV, MATLAB, etc.).
  • Reshape each feature array to shape=[N, x[, y[, z[, etc.]]]], where N is the number of examples (records), so all arrays written to one file should share the same N.
  • Pass the arrays to create_tfrecords as keyword arguments; each keyword becomes the key the feature is stored and read under.
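The layout rule above means record i holds element i of every keyword argument, which is why the arrays must agree on N. The following is a minimal sketch of that bookkeeping, assuming NumPy and a hypothetical per-value encoding: each slice is stored in NumPy's .npy format via np.save, which records dtype and shape. The encoding easy_tfrecords actually uses is not documented here and may differ; split_examples, encode_example and decode_example are illustrative names.

```python
import io

import numpy as np


def split_examples(**features):
    """Yield one {key: array} dict per example by slicing each array along axis 0."""
    if not features:
        raise ValueError("no features given")
    arrays = {key: np.asarray(value) for key, value in features.items()}
    counts = {key: arr.shape[0] for key, arr in arrays.items()}
    if len(set(counts.values())) != 1:
        raise ValueError("first dimensions differ: {}".format(counts))
    n = next(iter(counts.values()))
    for i in range(n):
        yield {key: arr[i] for key, arr in arrays.items()}


def encode_example(example):
    """Hypothetical encoding: serialize each value with np.save to keep dtype and shape."""
    encoded = {}
    for key, value in example.items():
        buf = io.BytesIO()
        np.save(buf, value, allow_pickle=False)
        encoded[key] = buf.getvalue()
    return encoded


def decode_example(encoded, keys=None):
    """Rebuild the arrays, optionally keeping only the requested keys."""
    keys = list(encoded) if keys is None else keys
    return {key: np.load(io.BytesIO(encoded[key]), allow_pickle=False) for key in keys}
```

With trainX of shape [3, 2, 4] and trainY of shape [3, 1], split_examples yields three records whose x is a (2, 4) int32 array and whose y is a (1,) float32 array.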

Reading

  • Create a reader object, passing your list of files (a single file is fine) and optionally batch_size and shuffle.
  • Pass keys, a list of the features to read from the files.
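Conceptually, the reader keeps only the requested keys, optionally shuffles the records, stacks them into batches along a new leading axis, and starts over when the data is exhausted. The generator below is a hypothetical pure-NumPy stand-in for that behaviour, not the easy_tfrecords implementation (which presumably builds a TensorFlow input pipeline). In this sketch a batch may straddle an epoch boundary.

```python
import random

import numpy as np


def batches(examples, keys, batch_size=1, shuffle=False, seed=None):
    """Endlessly yield {key: batch_array} dicts from a list of example dicts.

    Only the requested keys are kept, each batch stacks batch_size records
    along a new leading axis, and the record order is reshuffled every epoch
    when shuffle=True.
    """
    if not examples:
        raise ValueError("no examples to read")
    rng = random.Random(seed)
    order = list(range(len(examples)))
    pending = []
    while True:  # each pass over `order` is one epoch
        if shuffle:
            rng.shuffle(order)
        for i in order:
            pending.append({key: examples[i][key] for key in keys})
            if len(pending) == batch_size:
                yield {key: np.stack([ex[key] for ex in pending]) for key in keys}
                pending = []
```

With 9 records and batch_size=1, the tenth call returns the first record again, and a (2, 4) feature comes back with shape (1, 2, 4).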

Example Code:

import numpy as np
import tensorflow as tf

from easy_tfrecords import create_tfrecords, easy_tfrecords as records


# CREATE SOME TEST DATA
x      = np.array([[0, 0, 0, 0], [0, 0, 0, 0]], np.int32)
trainX = np.asarray( [x, x+1, x+2] )

y      = np.array([0.25], np.float32)
trainY = np.asarray( [y, y+1, y+2] )


# CREATE AND SAVE TO A FEW TFRECORDS FILES
create_tfrecords('tfr_1.tf', x=trainX, y=trainY)
create_tfrecords('tfr_2.tf', x=trainX+10, y=trainY+10)
create_tfrecords('tfr_3.tf', x=trainX+100, y=trainY+100, z=trainY+100)

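# INSTANTIATE THE RECORDS OBJECT
# read the files written above; keys selects which features to fetch
rec = records(files=['tfr_1.tf', 'tfr_3.tf', 'tfr_2.tf'],
  shuffle=False,
  batch_size=1,
  keys=['x', 'y'])

# maps each requested key to the tensor that yields its next batch
next_factory = rec.get_next_factory()

batch_x = next_factory['x']
batch_y = next_factory['y']

# tf.Session is the TensorFlow 1.x API
with tf.Session() as sess:

  # initialize the reader before fetching batches
  sess.run(rec.get_initializer())

  # 10 fetches over 9 records: the last fetch wraps around to the first record
  for n in range(10):
    print('------------')
    print('n => {}\n'.format(n))

    x_eval, y_eval = sess.run( [batch_x, batch_y] )
    print('x_eval=\n{}\n'.format(x_eval))
    print('y_eval=\n{}'.format(y_eval))

# leaving the with-block closes the session, so no sess.close() is needed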

Output:

------------
n => 0

x_eval=
[[[0 0 0 0]
  [0 0 0 0]]]

y_eval=
[[ 0.25]]
------------
n => 1

x_eval=
[[[1 1 1 1]
  [1 1 1 1]]]

y_eval=
[[ 1.25]]
------------
n => 2

x_eval=
[[[2 2 2 2]
  [2 2 2 2]]]

y_eval=
[[ 2.25]]
------------
n => 3

x_eval=
[[[100 100 100 100]
  [100 100 100 100]]]

y_eval=
[[ 100.25]]
------------
n => 4

x_eval=
[[[101 101 101 101]
  [101 101 101 101]]]

y_eval=
[[ 101.25]]
------------
n => 5

x_eval=
[[[102 102 102 102]
  [102 102 102 102]]]

y_eval=
[[ 102.25]]
------------
n => 6

x_eval=
[[[10 10 10 10]
  [10 10 10 10]]]

y_eval=
[[ 10.25]]
------------
n => 7

x_eval=
[[[11 11 11 11]
  [11 11 11 11]]]

y_eval=
[[ 11.25]]
------------
n => 8

x_eval=
[[[12 12 12 12]
  [12 12 12 12]]]

y_eval=
[[ 12.25]]
------------
n => 9

x_eval=
[[[0 0 0 0]
  [0 0 0 0]]]

y_eval=
[[ 0.25]]



Download files

Source Distribution

easy_tfrecords-0.1.0.tar.gz (5.1 kB)

Built Distribution

easy_tfrecords-0.1.0-py3-none-any.whl (5.3 kB)
