Seamless support for Spark DataFrames in Keras .fit() and .predict()

Project description

Spark-Keras Integration

This package enables seamless integration of PySpark DataFrames with Keras models, allowing users to efficiently train and predict using distributed data.

Quickstart Guide

Setting Up the PySpark DataFrame

Start by creating a PySpark DataFrame as shown below:

import pandas as pd
from pyspark.sql import SparkSession
import tensorflow as tf

# Initialize a Spark session
spark = SparkSession.builder.appName("spark_keras").getOrCreate()

# Create a PySpark DataFrame
spark_df = spark.createDataFrame(
    pd.DataFrame({
        "feature1": tf.random.normal([100]).numpy().tolist(),
        "label1": tf.random.normal([100]).numpy().tolist(),
        "partition_id": [0 for _ in range(100)]
    })
)
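For larger datasets you will typically spread rows across several partitions rather than using a single constant id. As a minimal, hypothetical sketch (nr_partitions and the round-robin scheme are illustrative choices, not part of the package API), the partition_id column could be assigned like this:

```python
# Hypothetical round-robin assignment of partition ids: values must stay
# within the range 0 .. nr_partitions - 1.
nr_partitions = 4
n_rows = 100
partition_ids = [i % nr_partitions for i in range(n_rows)]

# Every partition id falls in the expected range.
assert set(partition_ids) == set(range(nr_partitions))
```

This list can then be used as the "partition_id" column when building the pandas DataFrame above.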

Training and Predicting with KerasSparkModel

You can fit and predict with KerasSparkModel using the standard Keras API:

import tensorflow as tf
from keras_spark.models import KerasSparkModel as Model

# Define the Keras model
input_layer = tf.keras.Input(shape=[1], name="feature1")
output_layer = tf.keras.layers.Dense(1, name="label1")(input_layer)
model = Model(input_layer, output_layer)
model.compile(optimizer="adam", loss="mean_squared_error")

# Train the model using the PySpark DataFrame
model.fit(spark_df, batch_size=10, epochs=100, partition_col="partition_id")

# Use Spark for distributed scoring on the PySpark DataFrame
predictions = model.predict(spark_df).select("model_output.label1")
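Conceptually, each predicted row gains a model_output struct keyed by the output-layer names. In plain Python terms (a hypothetical dict standing in for a Spark Row, with made-up values), accessing a prediction looks like this:

```python
# A hypothetical predicted row as a plain dict: .predict() adds a
# "model_output" struct with one entry per named Keras output layer.
row = {"feature1": 0.5, "model_output": {"label1": 0.48}}

# Accessing a specific output mirrors the column path "model_output.label1".
prediction = row["model_output"]["label1"]
```

The same key-based access applies however many named outputs the model has.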

Important Considerations

  1. Naming Conventions:

    • Input Names: Each Keras input must have a specified name that corresponds to the respective PySpark DataFrame column.
    • Output Names: Output names are inferred from the Keras output layers and must match the PySpark column names if using .fit().
  2. Data Type Compatibility:

    • Ensure that the data types of Keras inputs and the corresponding PySpark columns are compatible.
  3. Partitioning Requirements:

    • The PySpark DataFrame must include a partition column (passed via partition_col), with integer values ranging from 0 to nr_partitions - 1.
    • Choose nr_partitions carefully to ensure that the Spark driver can handle the workload.
    • Parallel processing of partitions is handled using .interleave(), with the degree of parallelism set by num_parallel_calls.
  4. Prediction Output:

    • The .predict() method generates an additional struct column named model_output.
    • To access specific outputs, reference them using their keys, e.g., model_output.label1.
  5. Keras Version:

    • This package is compatible with Keras version 3.0 and above.
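To make the naming convention in point 1 concrete, here is a hedged, pure-Python illustration of how DataFrame columns are matched to model inputs and outputs by name (the row values and name lists are made up for illustration, not taken from the package):

```python
# Columns of one DataFrame row, keyed by column name (illustrative values).
row = {"feature1": 0.5, "label1": 1.2, "partition_id": 0}

# Names taken from the Keras model: Input(name=...) for inputs, and the
# output-layer names for outputs.
input_names = ["feature1"]
output_names = ["label1"]

# Matching by name yields the feature and target dicts fed to the model;
# columns such as partition_id are simply not selected.
features = {name: row[name] for name in input_names}
targets = {name: row[name] for name in output_names}
```

If a Keras input name has no matching column (or the dtypes disagree), this name-based lookup is where training would fail.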

This integration empowers users to leverage distributed data processing with PySpark while taking full advantage of Keras's deep learning capabilities.

Download files

Download the file for your platform. If you're not sure which to choose, see the Python Packaging User Guide on installing packages.

Source Distribution

keras_spark-0.99.tar.gz (12.6 kB)

Uploaded: Source

Built Distributions

keras_spark-0.99-py3-none-any.whl (13.4 kB)

Uploaded: Python 3

keras_spark-0.99-py2.py3-none-any.whl (13.4 kB)

Uploaded: Python 2, Python 3

File details

Details for the file keras_spark-0.99.tar.gz.

File metadata

  • Download URL: keras_spark-0.99.tar.gz
  • Upload date:
  • Size: 12.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.18

File hashes

Hashes for keras_spark-0.99.tar.gz
Algorithm Hash digest
SHA256 d46b509e3287cfdfee9cf34aa80d07e8968a6be50bc1b799cc1fb3972d0d7eb8
MD5 57bb68f0d20edf7a368a358150c63ca4
BLAKE2b-256 e6982c7e4b72731326c06cf85c7974423ee96dfccf0d8c5a6c4c4cd3fa8f956f


File details

Details for the file keras_spark-0.99-py3-none-any.whl.

File metadata

  • Download URL: keras_spark-0.99-py3-none-any.whl
  • Upload date:
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.18

File hashes

Hashes for keras_spark-0.99-py3-none-any.whl
Algorithm Hash digest
SHA256 8c9a66f7e1c05c38fa4b50c218f68a70ddf4d165010332436d0738fca302bc74
MD5 6f87e19fc5b3dfa255fbb1eb6421d65c
BLAKE2b-256 9a9a53847281a9a2eb5a509239f05eff05ea692c859e297cfa4ec6a75adf444f


File details

Details for the file keras_spark-0.99-py2.py3-none-any.whl.

File metadata

  • Download URL: keras_spark-0.99-py2.py3-none-any.whl
  • Upload date:
  • Size: 13.4 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.18

File hashes

Hashes for keras_spark-0.99-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 911f26ac02aff5f7195a25912f24a0811fbe981397e8a221459b76d33e492a77
MD5 3ed2abccdadc9c7c737d324eac9c3e74
BLAKE2b-256 e1b7998f21365b14b41fbaeb484bdb37dcddcf83e981129e84e673ae23ed10b1

