SparkML base classes for Transformers and Estimators

Project description

This document shows how to build a custom Estimator and Transformer using the base classes in this package, and how to integrate them into SparkML Pipelines. For background on SparkML Pipeline concepts and on using the existing Estimators and Transformers in the SparkML module, please refer to the Spark ML Pipelines documentation.

Installation

pip install -U sparkml-base-classes

Build a custom Transformer

In this section we build a Transformer that adds a constant value to a column, updating the column’s values in place.

import pyspark.sql.functions as F
from pyspark import keyword_only
from sparkml_base_classes import TransformerBaseClass


class AdditionColumnTransformer(TransformerBaseClass):

    @keyword_only
    def __init__(self, column_name=None, value=None):
        # the base class stores the keyword arguments and exposes
        # them as self._column_name and self._value
        super().__init__()

    def _transform(self, ddf):
        self._logger.info("AdditionColumn transform with column {self._column_name}")
        ddf = ddf.withColumn(self._column_name, F.col(self._column_name) + self._value)
        return ddf
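
As a quick sanity check, the Transformer can also be applied on its own, outside a Pipeline. Below is a minimal sketch, assuming a running SparkSession; transform() is inherited from the base class and delegates to _transform():

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

ddf = spark.createDataFrame(
    [[1], [2], [3]],
    ["foo"],
)

# adds the constant 2 to every value of the "foo" column
AdditionColumnTransformer(column_name="foo", value=2).transform(ddf).show()

+---+
|foo|
+---+
|  3|
|  4|
|  5|
+---+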

Build a custom Estimator

In this section we build an Estimator that normalizes the values of a column by their mean. An Estimator’s _fit method must return a Transformer because using an Estimator consists of two steps:

  1. Fitting the estimator.

    This step uses the _fit method to calculate some value(s) from the DataFrame and return a Transformer that stores the calculated value(s) and uses them in its _transform method. In this example the Estimator calculates the mean and returns a Transformer that divides the column by this mean value.

  2. Transforming the DataFrame.

    Once the Estimator has been fitted and a Transformer has been returned, we use that Transformer to transform the DataFrame. In this case the Transformer divides the specified column by the mean and returns the transformed DataFrame.

import pyspark.sql.functions as F
from pyspark import keyword_only
from sparkml_base_classes import EstimatorBaseClass, TransformerBaseClass

class MeanNormalizerTransformer(TransformerBaseClass):

    @keyword_only
    def __init__(self, column_name=None, mean=None):
        # the base class exposes the keyword arguments as
        # self._column_name and self._mean
        super().__init__()

    def _transform(self, ddf):
        # add your transformation logic here
        self._logger.info("MeanNormalizer transform")
        ddf = ddf.withColumn(self._column_name, F.col(self._column_name) / self._mean)
        return ddf

class MeanNormalizerEstimator(EstimatorBaseClass):

    @keyword_only
    def __init__(self, column_name=None):
        super().__init__()

    def _fit(self, ddf):
        # add your fitting logic here
        self._logger.info("MeanNormalizer fit")
        mean, = ddf.agg(F.mean(self._column_name)).head()
        return MeanNormalizerTransformer(
            column_name=self._column_name,
            mean=mean
        )
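
The Estimator can likewise be exercised on its own before wiring it into a Pipeline. Below is a minimal sketch, assuming a running SparkSession; fit() is inherited from the base class, delegates to _fit(), and returns the MeanNormalizerTransformer:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

ddf = spark.createDataFrame(
    [[1], [2], [3]],
    ["foo"],
)

# _fit computes the mean of "foo" (2.0) and returns a Transformer
normalizer = MeanNormalizerEstimator(column_name="foo").fit(ddf)
normalizer.transform(ddf).show()

+---+
|foo|
+---+
|0.5|
|1.0|
|1.5|
+---+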

Build the Pipeline

In this section we build a Pipeline containing our custom Transformer and Estimator. We first instantiate both classes and then add them as stages to the Pipeline.

from pyspark.ml import Pipeline

addition_column_transformer = AdditionColumnTransformer(column_name="foo", value=2)
mean_normalizer_estimator = MeanNormalizerEstimator(column_name="foo")
my_pipeline = Pipeline(stages=[addition_column_transformer, mean_normalizer_estimator])

Fit the Pipeline and transform the DataFrame

In this section we fit the created Pipeline to a DataFrame and then use the fitted Pipeline (a PipelineModel in SparkML terms) to transform a DataFrame. A Pipeline’s fit method produces a PipelineModel, which is itself a Transformer and can later be used to transform any DataFrame. Please refer to the Spark ML Pipelines documentation for an in-depth description.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

ddf = spark.createDataFrame(
    [[1], [2], [3]],
    ["foo"],
)

# the returned object is of PipelineModel type
my_fitted_pipeline = my_pipeline.fit(ddf)
my_fitted_pipeline.transform(ddf).show()

+----+
| foo|
+----+
|0.75|
| 1.0|
|1.25|
+----+
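
These values follow from the two stages: the AdditionColumnTransformer turns [1, 2, 3] into [3, 4, 5], the MeanNormalizerEstimator computes their mean (4.0) while fitting, and the returned Transformer divides each value by that mean, yielding [0.75, 1.0, 1.25].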

Save and load fitted Pipeline

In the previous section we transformed the DataFrame immediately after fitting the Pipeline. In this section we save the fitted Pipeline to disk and load it back, decoupling the fitting of the Pipeline from the transforming of the DataFrame.

from pyspark.ml import PipelineModel
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

ddf = spark.createDataFrame(
    [[8], [10], [12]],
    ["foo"],
)

my_fitted_pipeline.save('my_fitted_pipeline.pipeline')
my_fitted_pipeline = PipelineModel.load('my_fitted_pipeline.pipeline')
my_fitted_pipeline.transform(ddf).show()

+---+
|foo|
+---+
|2.5|
|3.0|
|3.5|
+---+
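
Note that the mean (4.0) was computed when the Pipeline was fitted on the original DataFrame and is stored inside the PipelineModel, so it is reused here: [8, 10, 12] becomes [10, 12, 14] after the addition stage, and dividing by 4.0 gives [2.5, 3.0, 3.5].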

