
An open-source Python library for simplifying local testing of Databricks workflows that use PySpark and Delta tables.

Project description

pysparkdt (PySpark Delta Testing)



This library enables seamless testing of PySpark processing logic outside Databricks by emulating Unity Catalog behavior. It dynamically generates a local metastore to mimic Unity Catalog and supports simplified handling of Delta tables for both batch and streaming workloads.

Guideline


Overview

Scope

This guideline helps you test Databricks Python pipelines, with a focus on PySpark code. Basic knowledge of unit testing with pytest is helpful, but it is not the central focus.

Key Points

  • Standalone Testing: The setup allows you to test code without Databricks access, enabling easy CI integration.

  • Local Metastore: Mimic the Databricks Unity Catalog using a dynamically generated local metastore with local Delta tables.

  • Code Testability: Move core processing logic from notebooks to Python modules. Notebooks then serve as entrypoints.

Setup

The following section assumes you are writing tests for a job that reads one Delta table as input, processes it with PySpark, and writes one Delta table as output.

1. Installation

Install pysparkdt

  • Install the package from PyPI. It is only needed in your test environment.
pip install pysparkdt

2. Testable code

  • Modularization: Move processing logic from notebooks to modules.

  • Notebook Role: Notebooks primarily handle initialization and trigger the processing. They should contain all the code specific to Databricks (e.g. dbutils usage).

entrypoint.py (Databricks Notebook)
# Databricks notebook source
import sys
from pathlib import Path

MODULE_DIR = Path.cwd().parent
sys.path.append(MODULE_DIR.as_posix())

# COMMAND ----------

from myjobpackage.processing import process_data

# COMMAND ----------

input_table = dbutils.widgets.get('input_table')
output_table = dbutils.widgets.get('output_table')

# COMMAND ----------

process_data(
    spark=spark,
    input_table=input_table,
    output_table=output_table,
)

myjobpackage.processing

  • Contains the core logic to test
  • Our test focuses on the core function myjobpackage.processing.process_data

3. File structure

myjobpackage
├── __init__.py
├── entrypoint.py  # Databricks Notebook
└── processing.py
tests
├── __init__.py
├── test_processing.py
└── data
    └── tables
        ├── example_input.ndjson
        ├── expected_output.ndjson
        └── schema
            ├── example_input.json
            └── expected_output.json
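The table names used later in the tests (example_input, expected_output) match the .ndjson file stems, so the test data layout can be sanity-checked before Spark is involved. A minimal stdlib sketch, assuming that reinit_local_metastore names each table after its .ndjson file stem and looks for an optional schema in schema/<name>.json, as the tree above suggests; describe_test_tables and orphan_schemas are hypothetical helpers, not part of pysparkdt.

```python
from __future__ import annotations

from pathlib import Path


def describe_test_tables(json_tables_dir: str | Path) -> dict[str, bool]:
    """Map each .ndjson table name to whether a schema file exists for it.

    Hypothetical helper. Assumes data in <dir>/<name>.ndjson and an
    optional schema in <dir>/schema/<name>.json.
    """
    tables_dir = Path(json_tables_dir)
    schema_dir = tables_dir / 'schema'
    return {
        data_file.stem: (schema_dir / f'{data_file.stem}.json').is_file()
        for data_file in sorted(tables_dir.glob('*.ndjson'))
    }


def orphan_schemas(json_tables_dir: str | Path) -> list[str]:
    """Return names of schema files that have no matching .ndjson data file."""
    tables_dir = Path(json_tables_dir)
    schema_dir = tables_dir / 'schema'
    if not schema_dir.is_dir():
        return []
    data_names = {p.stem for p in tables_dir.glob('*.ndjson')}
    return sorted(
        p.stem for p in schema_dir.glob('*.json') if p.stem not in data_names
    )
```

Calling describe_test_tables(JSON_TABLES_DIR) from a test can flag a table whose schema file was forgotten (and which would therefore be loaded without a schema), while orphan_schemas catches a schema left behind after its data file was renamed.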

Data Format

  • Test Data: Newline-delimited JSON (.ndjson)
  • Optional Schema: JSON
example_input.ndjson
{"id": 0, "time_utc": "2024-01-08T11:00:00", "name": "Jorge", "feature": 0.5876}
{"id": 1, "time_utc": "2024-01-11T14:28:00", "name": "Ricardo", "feature": 0.42}
example_input.json
{
    "type": "struct",
    "fields": 
    [
        {
            "name": "id",
            "type": "long",
            "nullable": false,
            "metadata": {}
        },
        {
            "name": "time_utc",
            "type": "timestamp",
            "nullable": false,
            "metadata": {}
        },
        {
            "name": "name",
            "type": "string",
            "nullable": true,
            "metadata": {}
        },
        {
            "name": "feature",
            "type": "double",
            "nullable": true,
            "metadata": {}
        }
    ]
}
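Because each line of an .ndjson file is a standalone JSON object, test data can be generated or inspected with the standard library alone. In the example above, time_utc is written as a plain ISO-8601 string, and it is the schema file that declares it as a timestamp. A minimal sketch; write_ndjson and read_ndjson are hypothetical helpers, not part of pysparkdt.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Iterable


def write_ndjson(path: Path, records: Iterable[dict[str, Any]]) -> None:
    """Write one JSON object per line (newline-delimited JSON)."""
    with path.open('w', encoding='utf-8') as file:
        for record in records:
            file.write(json.dumps(record) + '\n')


def read_ndjson(path: Path) -> list[dict[str, Any]]:
    """Read newline-delimited JSON, skipping blank lines."""
    with path.open(encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]
```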

Tip: A schema file for a loaded PySpark DataFrame df can be created using:

import json

with open('example_input.json', 'w') as file:
    file.write(json.dumps(df.schema.jsonValue(), indent=4))

You can therefore first load a table without a schema, generate a schema file from it, and then adjust the types as needed.
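After editing a schema by hand, it can help to check that the .ndjson records still agree with it, since Spark's JSON reader typically does not complain about keys missing from the schema or about fields missing from a record. A minimal sketch, assuming a flat top-level struct schema like example_input.json above; check_records_against_schema is a hypothetical helper that checks only field names and top-level nullability, not types.

```python
from __future__ import annotations

from typing import Any


def check_records_against_schema(
    schema: dict[str, Any], records: list[dict[str, Any]]
) -> list[str]:
    """Return human-readable problems found in records for a top-level struct schema.

    Only field names and top-level nullability are checked; types are not.
    """
    fields = {field['name']: field for field in schema['fields']}
    problems: list[str] = []
    for line_no, record in enumerate(records, start=1):
        # Keys present in the data but absent from the schema
        unknown = sorted(set(record) - set(fields))
        if unknown:
            problems.append(f'line {line_no}: fields not in schema: {unknown}')
        # Non-nullable fields that are missing or explicitly null
        for name, field in fields.items():
            if not field['nullable'] and record.get(name) is None:
                problems.append(
                    f'line {line_no}: non-nullable field {name!r} is missing or null'
                )
    return problems
```

The schema can be loaded with json.loads on the schema file and the records with json.loads on each .ndjson line.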

4. Tests

Constants: Define paths for test data and the temporary metastore.

import os

DATA_DIR = f'{os.path.dirname(__file__)}/data'
JSON_TABLES_DIR = f'{DATA_DIR}/tables'
TMP_DIR = f'{DATA_DIR}/tmp'
METASTORE_DIR = f'{TMP_DIR}/metastore'

Spark Fixture: Define a fixture for the local Spark session using the spark_base function from pysparkdt, and pass it the temporary metastore location.

from pytest import fixture
from pysparkdt import spark_base

@fixture(scope='module')
def spark():
    yield from spark_base(METASTORE_DIR)

Metastore Initialization: Use reinit_local_metastore

At the beginning of each test function, call reinit_local_metastore from pysparkdt to initialize the metastore with the tables from your JSON folder (JSON_TABLES_DIR). You can also enable or disable deletion vectors for the Delta tables (default: enabled). If the function is called while the metastore already exists, it first deletes all existing tables and then creates the new ones.

Alternatively, you can call it only once per test module, but then individual tests may affect each other by modifying the shared metastore tables.

from myjobpackage.processing import process_data
from pyspark.sql import SparkSession
from pyspark.testing import assertDataFrameEqual
from pysparkdt import reinit_local_metastore

def test_process_data(
    spark: SparkSession,
):
    reinit_local_metastore(spark, JSON_TABLES_DIR, deletion_vectors=True)
    
    process_data(
        spark=spark,
        input_table='example_input',
        output_table='output',
    )
    
    output = spark.read.format('delta').table('output')
    expected = spark.read.format('delta').table('expected_output')
    
    assertDataFrameEqual(
        actual=output.select(sorted(output.columns)),
        expected=expected.select(sorted(expected.columns)),
    )

In the example above, we use assertDataFrameEqual to compare PySpark DataFrames. Selecting the columns in sorted order makes the comparison independent of the column order of the result. By default, assertDataFrameEqual also ignores row order (this can be changed with the checkRowOrder parameter).
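For rows already collected on the driver (for example [row.asDict() for row in df.collect()]), the same order-insensitive comparison can be written in plain Python. This is a sketch of the semantics only, not how PySpark implements the assertion; rows_equal_ignoring_order is a hypothetical helper that requires hashable values and compares them exactly, without any floating-point tolerance.

```python
from __future__ import annotations

from collections import Counter
from typing import Any


def rows_equal_ignoring_order(
    actual: list[dict[str, Any]], expected: list[dict[str, Any]]
) -> bool:
    """Compare rows as multisets, ignoring both row order and column order."""

    def normalise(rows: list[dict[str, Any]]) -> Counter:
        # Sorting each row's items removes column order;
        # Counter keeps duplicate rows, unlike a set.
        return Counter(tuple(sorted(row.items())) for row in rows)

    return normalise(actual) == normalise(expected)
```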

ℹ️ For a complete example, see the example in the project repository.

⚠️ Manual deletion of local metastore

Deleting the local metastore manually invalidates any Spark session configured for that location. You would need to start a new Spark session because the original session’s state is no longer valid. Avoid manual deletion — use reinit_local_metastore for reinitialization instead.

⚠️ Note on running tests in parallel

With the setup above, the metastore is shared at module scope. Therefore, if tests defined in the same module run in parallel, race conditions can occur when multiple test functions use the same tables.

To mitigate this, make sure each test in the module uses its own set of tables.

Advanced

Testing Stream Processing

Let's now focus on a job that reads an input Delta table using PySpark streaming, performs some computation on the data, and saves the result to an output Delta table.

To test the processing, we need to explicitly wait for it to complete. The most reliable way to do so is to await the streaming query that performs the processing.

To await the query, the test function needs access to it. We therefore need to make the streaming query (a StreamingQuery object) accessible, for example by returning it from the processing function.

myjobpackage/processing.py
import pyspark.sql
from pyspark.sql import SparkSession
from pyspark.sql.streaming import StreamingQuery


def process_data(
    spark: SparkSession,
    input_table: str,
    output_table: str,
    checkpoint_location: str,
) -> StreamingQuery:
    load_query = spark.readStream.format('delta').table(input_table)

    def process_batch(df: pyspark.sql.DataFrame, _batch_id: int) -> None:
        # ... process df ...
        df.write.mode('append').format('delta').saveAsTable(output_table)

    # availableNow processes all data available at start, then stops the query
    return (
        load_query.writeStream.format('delta')
        .foreachBatch(process_batch)
        .trigger(availableNow=True)
        .option('checkpointLocation', checkpoint_location)
        .start()
    )
tests/test_processing.py
def test_process_data(spark: SparkSession):
    ...
    spark_processing = process_data(
        spark=spark,
        input_table='example_input',
        output_table='output',
        checkpoint_location=f'{TMP_DIR}/_checkpoint/output',
    )
    # Block until the query finishes; returns False if the 60 s timeout expires
    spark_processing.awaitTermination(60)

    output = spark.read.format('delta').table('output')
    expected = spark.read.format('delta').table('expected_output')
    ...
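The checkpoint_location above is a fixed path under TMP_DIR. A Structured Streaming query resumes from the progress recorded in an existing checkpoint, so a checkpoint left over from a previous run can tie test runs together. A minimal sketch of one way to start each test from an empty checkpoint, assuming nothing else clears TMP_DIR/_checkpoint between runs; fresh_checkpoint_dir is a hypothetical helper, not part of pysparkdt.

```python
from __future__ import annotations

import shutil
from pathlib import Path


def fresh_checkpoint_dir(tmp_dir: str | Path, name: str) -> str:
    """Delete any checkpoint at <tmp_dir>/_checkpoint/<name> and return that path.

    The directory is left absent; the streaming query creates it on start.
    """
    path = Path(tmp_dir) / '_checkpoint' / name
    shutil.rmtree(path, ignore_errors=True)  # no error if it does not exist yet
    return path.as_posix()
```

In the test, pass checkpoint_location=fresh_checkpoint_dir(TMP_DIR, 'output') instead of the fixed f-string.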

Mocking Inside RDD and UDF Operations

If we are testing a whole job's processing code and it contains functions executed through rdd.mapPartitions, rdd.map, or UDFs, mocking needs special handling, because regular patching does not propagate to the Spark workers that execute those functions.

myjobpackage/processing.py

from typing import Any, Iterable

import pyspark.sql
from pyspark.sql import Row

# ApiSessionClient is the job's own API client; its import is omitted here


def call_api(
    data_df: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    # Call the API in parallel (one client session per partition)
    result = data_df.rdd.mapPartitions(_partition_run).toDF()
    return result


def _partition_run(
    iterable: Iterable[Row],
) -> Iterable[dict[str, Any]]:
    with ApiSessionClient() as client:
        for row in iterable:
            ...
            output = client.post(prepared_data)
            ...
            yield output


def process_data(
    data_df: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    ...
    ... = call_api(...)
    ...

In this example, _partition_run calls an external API. We do not want to call the real API in our tests, so we mock ApiSessionClient.

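from pytest import fixture

import myjobpackage.processing


def _mocked_session_post(json_data: dict):
    ...
    return output


# mocker is provided by the pytest-mock plugin
@fixture
def api_session_client(mocker):
    # Replace the client class where processing.py looks it up
    api_session_client_mock = mocker.patch.object(
        myjobpackage.processing,
        'ApiSessionClient',
    )
    api_session_client_mock.return_value = session_client = mocker.Mock()
    # Make the mocked client usable in a `with` statement
    session_client.__enter__ = mocker.Mock()
    session_client.__enter__.return_value = session_client_ctx = mocker.Mock()
    # Return False so exceptions raised inside the `with` block are not swallowed
    session_client.__exit__ = mocker.Mock(return_value=False)
    session_client_ctx.post = mocker.Mock(side_effect=_mocked_session_post)
    return session_client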

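Because ApiSessionClient is created inside rdd.mapPartitions, it runs on the workers, where the patch is not applied. We therefore also mock call_api so that _partition_run runs in the test process, where the patched ApiSessionClient is used.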

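import pandas as pd
import pyspark.sql
from pyspark.sql import SparkSession

import myjobpackage.processing
from myjobpackage.processing import _partition_run


def _mocked_call_api(
    data_df: pyspark.sql.DataFrame,
) -> pyspark.sql.DataFrame:
    # Run the partition logic on the driver, where the patch is active
    results = list(_partition_run(data_df.collect()))
    spark = SparkSession.builder.getOrCreate()
    pandas_df = pd.DataFrame(results)
    return spark.createDataFrame(pandas_df)


@fixture
def call_api_mock(mocker, api_session_client):
    # Requesting api_session_client ensures ApiSessionClient is patched too
    mocker.patch.object(
        myjobpackage.processing, 'call_api', _mocked_call_api
    )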

Then we can run the test with the mocked API.

def test_process_data(
    spark: SparkSession,
    call_api_mock,
):
  ...

Limitations

Map Key Type Must Be String

Although Spark supports non-string key types in map fields, the JSON format itself does not support non-string keys. In JSON, all keys are inherently interpreted as strings, regardless of their declared type in the schema. This discrepancy becomes problematic when testing with .ndjson files.

Specifically, if the schema defines a map key type as anything other than string (such as long or integer), the reinitialization of the metastore will result in None values for all fields in the Delta table when the data is loaded. This happens because the keys in the JSON data are read as strings, but the schema expects another type, leading to a silent failure where no exception or warning is raised. This makes the issue difficult to detect and debug.
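The root cause is easy to reproduce with Python's json module: json.dumps({1: 'a'}) produces '{"1": "a"}', so an integer key comes back as the string '1'. Because the failure is silent, one option is to scan the schema files before initializing the metastore. A minimal sketch, assuming the schema files use Spark's JSON schema format shown earlier, in which a map appears as {"type": "map", "keyType": ..., "valueType": ..., "valueContainsNull": ...}; non_string_map_keys is a hypothetical helper, not part of pysparkdt.

```python
from __future__ import annotations

from typing import Any


def non_string_map_keys(data_type: Any, path: str = '') -> list[str]:
    """List paths of map types whose keyType is not 'string' in a Spark JSON schema.

    Walks struct fields, array elements and map values recursively.
    """
    if not isinstance(data_type, dict):
        return []  # primitive types such as 'long' are plain strings
    kind = data_type.get('type')
    found: list[str] = []
    if kind == 'struct':
        for field in data_type['fields']:
            child = f"{path}.{field['name']}" if path else field['name']
            found += non_string_map_keys(field['type'], child)
    elif kind == 'array':
        found += non_string_map_keys(data_type['elementType'], f'{path}[]')
    elif kind == 'map':
        if data_type['keyType'] != 'string':
            found.append(path)
        found += non_string_map_keys(data_type['valueType'], f'{path}{{}}')
    return found
```

Running it on each file in the schema folder and failing the test when the result is non-empty turns the silent None values into an explicit error.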

License

pysparkdt is licensed under the MIT license. See the LICENSE file for more details.

How to Contribute

See CONTRIBUTING.md.



Download files


Source Distribution

pysparkdt-1.1.0.tar.gz (9.1 kB)

Uploaded Source

Built Distribution


pysparkdt-1.1.0-py3-none-any.whl (9.9 kB)

Uploaded Python 3

File details

Details for the file pysparkdt-1.1.0.tar.gz.

File metadata

  • Download URL: pysparkdt-1.1.0.tar.gz
  • Size: 9.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for pysparkdt-1.1.0.tar.gz
Algorithm Hash digest
SHA256 f9b15e4756ba425bcc978c440c148604bd9e2bfea45cec67db129258b7d258ce
MD5 7ea0952cc736a3e102f1d6e2201cb9b7
BLAKE2b-256 3f970695aade8425c125bffdea0abc20cf54daf3c1dce369d3dc448357b94c4f


Provenance

The following attestation bundles were made for pysparkdt-1.1.0.tar.gz:

Publisher: release.yaml on datamole-ai/pysparkdt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file pysparkdt-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: pysparkdt-1.1.0-py3-none-any.whl
  • Size: 9.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for pysparkdt-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 47ef9847b8443a2e594a09c5253c53ed8b3947a401f0e18865d2dfc19081f57e
MD5 6588f498781a57aef2d50e7e035626a2
BLAKE2b-256 6b74918c96cfd2fa4eb453b86d4d9ec32fc6c5c466469392ee6f40dda6d1dca7


Provenance

The following attestation bundles were made for pysparkdt-1.1.0-py3-none-any.whl:

Publisher: release.yaml on datamole-ai/pysparkdt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
