Seamless support for Spark datasets in Keras .fit() and .predict()
Project description
Spark-Keras Integration
This package enables seamless integration of PySpark DataFrames with Keras models, allowing users to efficiently train and predict using distributed data.
Quickstart Guide
Setting Up the PySpark DataFrame
Start by creating a PySpark DataFrame as shown below:
import pandas as pd
from pyspark.sql import SparkSession
import tensorflow as tf
# Initialize a Spark session
spark = SparkSession.builder.appName("spark_keras").getOrCreate()
# Create a PySpark DataFrame
spark_df = spark.createDataFrame(
    pd.DataFrame({
        "feature1": tf.random.normal([100]).numpy().tolist(),
        "label1": tf.random.normal([100]).numpy().tolist(),
        "partition_id": [0 for _ in range(100)],
    })
)
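The toy DataFrame above is built from an in-memory pandas frame; in practice you would more likely start from existing Spark data. As a minimal sketch, assuming a hypothetical Parquet dataset that already contains the expected columns:
# Hypothetical path; any Spark data source with the expected columns works
spark_df = spark.read.parquet("/data/keras_spark_training_data")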
Training and Predicting with KerasSparkModel
You can fit and predict with KerasSparkModel using the standard Keras API:
import tensorflow as tf
from keras_spark.models import KerasSparkModel as Model
# Define the Keras model
input_layer = tf.keras.Input(shape=[1], name="feature1")
output_layer = tf.keras.layers.Dense(1, name="label1")(input_layer)
model = Model(input_layer, output_layer)
model.compile("adam", "mean_squared_error")
# Train the model using the PySpark DataFrame
model.fit(spark_df, batch_size=10, epochs=100, partition_col="partition_id")
# Use Spark for distributed scoring on the PySpark DataFrame
predictions = model.predict(spark_df).select("model_output.label1")
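Since .predict() returns an ordinary PySpark DataFrame, the scored data can be inspected or persisted with standard Spark operations. A small sketch (the output path below is hypothetical):
# Preview a few scored rows on the driver
predictions.show(5, truncate=False)
# Persist the scored DataFrame with standard Spark I/O
predictions.write.mode("overwrite").parquet("/tmp/keras_spark_predictions")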
Important Considerations
- Naming Conventions:
  - Input Names: Each Keras input must have a specified name that corresponds to the respective PySpark DataFrame column.
  - Output Names: Output names are inferred from the Keras output layers and must match the PySpark column names when using .fit(). (See the multi-input sketch after this list.)
- Data Type Compatibility:
  - Ensure that the data types of the Keras inputs and the corresponding PySpark columns are compatible.
- Partitioning Requirements:
  - The PySpark DataFrame must include a partition_id column, with values ranging from 0 to nr_partitions.
  - Choose nr_partitions carefully to ensure that the Spark driver can handle the workload.
  - Parallel processing of partitions is handled using .interleave(), with the degree of parallelism set by num_parallel_calls.
  - (One way to derive partition_id and cast columns to matching types is sketched after this list.)
- Prediction Output:
  - The .predict() method generates an additional struct column named model_output.
  - To access specific outputs, reference them by key, e.g., model_output.label1.
- Keras Version:
  - This package is compatible with Keras version 3.0 and above.
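To illustrate the naming rules above, here is a minimal sketch of a model with two named inputs and one named output. It assumes the DataFrame has columns feature1, feature2 and label1 (feature2 is hypothetical and not part of the quickstart), and that KerasSparkModel accepts the standard functional-API constructor shown in the quickstart:
import tensorflow as tf
from keras_spark.models import KerasSparkModel as Model

# Each input name must match a PySpark DataFrame column
feature1 = tf.keras.Input(shape=[1], name="feature1")
feature2 = tf.keras.Input(shape=[1], name="feature2")  # hypothetical extra feature column
merged = tf.keras.layers.Concatenate()([feature1, feature2])
# The output layer name must match the label column when calling .fit()
output = tf.keras.layers.Dense(1, name="label1")(merged)

model = Model([feature1, feature2], output)
model.compile("adam", "mean_squared_error")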
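One way to satisfy the partitioning and data-type requirements is to cast the relevant columns and derive partition_id before calling .fit(). This is only a sketch using standard PySpark functions; the value of nr_partitions is illustrative and should be tuned to your cluster:
from pyspark.sql import functions as F

nr_partitions = 4  # illustrative; choose so the Spark driver can handle the workload

spark_df = (
    spark_df
    # Cast columns to types compatible with the Keras inputs/outputs
    .withColumn("feature1", F.col("feature1").cast("float"))
    .withColumn("label1", F.col("label1").cast("float"))
    # Assign each row a random partition in [0, nr_partitions)
    .withColumn("partition_id", F.floor(F.rand(seed=42) * nr_partitions).cast("int"))
)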
This integration empowers users to leverage distributed data processing with PySpark while taking full advantage of Keras's deep learning capabilities.
Project details
Download files
Source Distribution
Built Distributions
File details
Details for the file keras_spark-0.99.tar.gz.
File metadata
- Download URL: keras_spark-0.99.tar.gz
- Upload date:
- Size: 12.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | d46b509e3287cfdfee9cf34aa80d07e8968a6be50bc1b799cc1fb3972d0d7eb8
MD5 | 57bb68f0d20edf7a368a358150c63ca4
BLAKE2b-256 | e6982c7e4b72731326c06cf85c7974423ee96dfccf0d8c5a6c4c4cd3fa8f956f
File details
Details for the file keras_spark-0.99-py3-none-any.whl.
File metadata
- Download URL: keras_spark-0.99-py3-none-any.whl
- Upload date:
- Size: 13.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8c9a66f7e1c05c38fa4b50c218f68a70ddf4d165010332436d0738fca302bc74
MD5 | 6f87e19fc5b3dfa255fbb1eb6421d65c
BLAKE2b-256 | 9a9a53847281a9a2eb5a509239f05eff05ea692c859e297cfa4ec6a75adf444f
File details
Details for the file keras_spark-0.99-py2.py3-none-any.whl.
File metadata
- Download URL: keras_spark-0.99-py2.py3-none-any.whl
- Upload date:
- Size: 13.4 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | 911f26ac02aff5f7195a25912f24a0811fbe981397e8a221459b76d33e492a77
MD5 | 3ed2abccdadc9c7c737d324eac9c3e74
BLAKE2b-256 | e1b7998f21365b14b41fbaeb484bdb37dcddcf83e981129e84e673ae23ed10b1