

tfseqestimator

Sequence estimators for TensorFlow.

Available estimators

  • FullSequenceClassifier: one class for a whole sequence
  • FullSequenceRegressor: one value for a whole sequence
  • SequenceItemsClassifier: one class for each sequence item
  • SequenceItemsRegressor: one value for each sequence item
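The two families differ only in the granularity of their output. Below is a minimal NumPy sketch of one common way to get each kind of output from an RNN. It assumes an RNN that emits one vector per time step over a zero-padded batch with known sequence lengths. This is a standard technique, not necessarily how tfseqestimator implements it, and the function names are hypothetical. A full-sequence estimator keeps only the output at each sequence's last valid step. A sequence-items estimator keeps every step and masks out the padding.

```python
import numpy as np


def last_relevant_output(outputs, lengths):
    """Pick the output at the last non-padded step of each sequence.

    outputs: array of shape [batch, max_time, units]
    lengths: int array of shape [batch] with true lengths (each >= 1)
    Returns shape [batch, units]: one vector per whole sequence, as a
    full-sequence classifier/regressor head would consume.
    """
    batch_idx = np.arange(outputs.shape[0])
    return outputs[batch_idx, lengths - 1]


def per_item_outputs(outputs, lengths):
    """Keep one output per step plus a mask that hides padded steps.

    Returns (outputs, mask) where mask has shape [batch, max_time] and is
    True only for real sequence items, as a per-item head would need
    when computing its loss.
    """
    max_time = outputs.shape[1]
    mask = np.arange(max_time)[None, :] < lengths[:, None]
    return outputs, mask


# Toy batch: 2 sequences padded to 3 steps, 2 output units per step.
outputs = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
lengths = np.array([3, 1])

full = last_relevant_output(outputs, lengths)     # [[4, 5], [6, 7]]
items, mask = per_item_outputs(outputs, lengths)  # mask: [[T, T, T], [T, F, F]]
```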

Usage

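import tensorflow as tf
import tensorflow.contrib.feature_column as contrib_features
from tfseqestimator import FullSequenceClassifier, RnnType

# Sequence feature: a variable-length list of hashed tokens per example
token_sequence = contrib_features.sequence_categorical_column_with_hash_bucket(...)
# Dense embedding of each token (embedding_column lives in tf.feature_column)
token_emb = tf.feature_column.embedding_column(categorical_column=token_sequence, ...)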

estimator = FullSequenceClassifier(
    sequence_feature_columns=[token_emb],
    rnn_type=RnnType.REGULAR_STACKED_LSTM,
    rnn_layers=[32, 16]
)

# Input builders: implement each input_fn to return your own data
def input_fn_train():  # returns (features, labels)
    pass

estimator.train(input_fn=input_fn_train, steps=100)


def input_fn_eval():  # returns (features, labels)
    pass

metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)


def input_fn_predict():  # returns (features, None)
    pass

predictions = estimator.predict(input_fn=input_fn_predict)
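The sequence categorical columns in tf.contrib.feature_column are typically fed a SparseTensor. This means an input function usually has to turn variable-length token lists into sparse form. Below is a minimal, framework-free sketch of that conversion. It assumes string tokens, and the helper name `to_sparse_components` is hypothetical. The resulting triple can be wrapped as `tf.SparseTensor(indices, values, dense_shape)`.

```python
def to_sparse_components(sequences):
    """Convert a batch of variable-length token lists into SparseTensor parts.

    sequences: list of lists of tokens, e.g. [["a", "b"], ["c"]]
    Returns (indices, values, dense_shape):
      indices     -- [row, position] pairs, one per real token
      values      -- the tokens in row-major order
      dense_shape -- [batch_size, longest_sequence_length]
    Padding positions are simply absent from indices/values.
    """
    indices, values = [], []
    for row, tokens in enumerate(sequences):
        for position, token in enumerate(tokens):
            indices.append([row, position])
            values.append(token)
    max_len = max((len(tokens) for tokens in sequences), default=0)
    dense_shape = [len(sequences), max_len]
    return indices, values, dense_shape


batch = [["the", "cat", "sat"], ["hello"]]
indices, values, dense_shape = to_sparse_components(batch)
# indices     -> [[0, 0], [0, 1], [0, 2], [1, 0]]
# values      -> ['the', 'cat', 'sat', 'hello']
# dense_shape -> [2, 3]
```

Inside an input_fn, the SparseTensor would go into the features dict under the same key that was passed to `sequence_categorical_column_with_hash_bucket` (elided in the example above). For example, `{'tokens': sparse}` would match a hypothetical key `'tokens'`.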

Download files


Files for tfseqestimator, version 2.1.1

  • tfseqestimator-2.1.1.tar.gz (30.1 kB), file type: Source, Python version: None
