A PySpark transform registry with MLflow integration.

Project description

PySpark Transform Registry

A simplified library for registering and loading PySpark transform functions using MLflow's model registry.

Installation

pip install pyspark-transform-registry

Quick Start

Register a Function

from pyspark_transform_registry import register_function
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lit

def clean_data(df: DataFrame) -> DataFrame:
    """Remove invalid records and standardize data."""
    return df.filter(col("amount") > 0).withColumn("status", lit("clean"))

# Register the function
model_uri = register_function(
    func=clean_data,
    name="analytics.etl.clean_data",
    description="Data cleaning transformation"
)

Load and Use a Function

from pyspark_transform_registry import load_function

# Load the registered function
clean_data_func = load_function("analytics.etl.clean_data", version=1)

# Use it on your data
result = clean_data_func(your_dataframe)

Features

  • Simple API: Just two main functions, register_function() and load_function()
  • Direct Registration: Register functions directly from Python code
  • File-based Registration: Load and register functions from Python files
  • Automatic Versioning: Integer-based versioning with automatic incrementing
  • MLflow Integration: Built on MLflow's model registry with automatic dependency inference
  • 3-Part Naming: Supports hierarchical naming (catalog.schema.table)
  • Runtime Validation: Automatic schema inference and DataFrame validation before execution
  • Type Safety: Validate input DataFrames against inferred schema constraints
  • Flexible Validation: Support for both strict and permissive validation modes
  • Source Code Inspection: Access original function source code and metadata for debugging

Usage Examples

Direct Function Registration

from pyspark_transform_registry import register_function
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, when

def risk_scorer(df: DataFrame, threshold: float = 100.0) -> DataFrame:
    """Calculate risk scores based on amount."""
    return df.withColumn(
        "risk_score",
        when(col("amount") > threshold, "high").otherwise("low")
    )

# Register with metadata
register_function(
    func=risk_scorer,
    name="finance.scoring.risk_scorer",
    description="Risk scoring transformation",
    extra_pip_requirements=["numpy>=1.20.0"],
    tags={"team": "finance", "category": "scoring"}
)

File-based Registration

# transforms/data_processors.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def feature_engineer(df: DataFrame) -> DataFrame:
    """Create engineered features."""
    return df.withColumn("feature_1", col("amount") * 2)

def data_validator(df: DataFrame) -> DataFrame:
    """Validate data quality."""
    return df.filter(col("amount").isNotNull())

# In a separate script: register the function from the file
from pyspark_transform_registry import register_function

register_function(
    file_path="transforms/data_processors.py",
    function_name="feature_engineer",
    name="ml.features.feature_engineer",
    description="Feature engineering pipeline"
)

Loading and Versioning

from pyspark_transform_registry import load_function

# Load version 1 (a specific version is required)
transform = load_function("finance.scoring.risk_scorer", version=1)

# Load a different version
transform_v2 = load_function("finance.scoring.risk_scorer", version=2)

# Use MLflow's native model registry APIs to discover models
# See MLflow documentation for model discovery patterns
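
Versions increment automatically each time a function is re-registered under the same name. A minimal sketch (illustrative version numbers; assumes register_function is imported and risk_scorer is defined as in the earlier example):

# First registration creates version 1; re-registering bumps to version 2
register_function(func=risk_scorer, name="finance.scoring.risk_scorer")  # -> version 1
register_function(func=risk_scorer, name="finance.scoring.risk_scorer")  # -> version 2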

Runtime Validation

The registry automatically infers schema constraints from your functions and validates input DataFrames before execution.

from pyspark_transform_registry import register_function, load_function
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lit

def process_orders(df: DataFrame) -> DataFrame:
    """Process order data with specific column requirements."""
    return (df
        .filter(col("amount") > 0)
        .withColumn("processed", lit(True))
        .select("order_id", "customer_id", "amount", "processed")
    )

# Register with automatic schema inference
register_function(
    func=process_orders,
    name="retail.processing.process_orders",
    infer_schema=True  # Default: True
)

# Load with validation enabled (default)
transform = load_function("retail.processing.process_orders", version=1)

# This will validate the DataFrame structure before processing
result = transform(orders_df)  # Validates: order_id, customer_id, amount columns exist

# Load with validation disabled
transform_no_validation = load_function(
    "retail.processing.process_orders",
    version=1,
    validate_input=False
)

# Load with strict validation (warnings become errors)
transform_strict = load_function(
    "retail.processing.process_orders",
    version=1,
    strict_validation=True
)
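
If an input DataFrame is missing a required column, validation rejects it before the transform executes. A hypothetical sketch (the exact exception type is not documented here):

# Dropping "amount" violates the constraint inferred from process_orders
bad_df = orders_df.drop("amount")

try:
    transform(bad_df)  # should fail validation before process_orders runs
except Exception as exc:  # exact exception type is library-specific
    print(f"Validation error: {exc}")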

Multi-Parameter Functions with Validation

from pyspark_transform_registry import register_function, load_function
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def filter_by_category(df: DataFrame, category: str, min_amount: float = 0.0) -> DataFrame:
    """Filter data by category and minimum amount."""
    return df.filter(
        (col("category") == category) &
        (col("amount") >= min_amount)
    )

# Assumes an active SparkSession bound to the name `spark`
sample_df = spark.createDataFrame([
    ("electronics", 100.0, "order_1"),
    ("books", 25.0, "order_2")
], ["category", "amount", "order_id"])

register_function(
    func=filter_by_category,
    name="retail.filtering.filter_by_category",
)

# Load and use with parameters
filter_func = load_function("retail.filtering.filter_by_category", version=1)

# Use with validation - validates DataFrame structure before filtering
electronics = filter_func(sample_df, params={"category": "electronics", "min_amount": 100.0})
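
Since min_amount has a default of 0.0, it should be possible to omit it from params (an assumption based on the function signature):

# min_amount falls back to its default of 0.0
books = filter_func(sample_df, params={"category": "books"})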

Source Code Inspection

The loaded functions provide access to the original transform source code for debugging and understanding:

# Load a function
transform = load_function("retail.processing.process_orders", version=1)

# Get the original source code
source_code = transform.get_source()
print(source_code)  # Shows the original function definition

# Get the original function for advanced inspection
original_func = transform.get_original_function()
print(f"Function name: {original_func.__name__}")
print(f"Docstring: {original_func.__doc__}")

# Use inspect on the original function
import inspect
signature = inspect.signature(original_func)
print(f"Signature: {signature}")

# Note: inspect.getsource(transform) shows wrapper code
# transform.get_source() shows the original function code

API Reference

register_function()

Register a PySpark transform function in MLflow's model registry.

Parameters:

  • func (Callable, optional): The function to register (for direct registration)
  • name (str): Model name for registry (supports 3-part naming)
  • file_path (str, optional): Path to Python file containing the function
  • function_name (str, optional): Name of function to extract from file
  • description (str, optional): Model description
  • extra_pip_requirements (list, optional): Additional pip requirements
  • tags (dict, optional): Tags to attach to the registered model
  • infer_schema (bool, optional): Whether to automatically infer schema constraints (default: True)
  • schema_constraint (PartialSchemaConstraint, optional): Pre-computed schema constraint

Returns:

  • str: Model URI of the registered model
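
A minimal sketch of capturing the return value (reusing clean_data from the Quick Start; the printed URI format is illustrative):

model_uri = register_function(
    func=clean_data,
    name="analytics.etl.clean_data",
    tags={"team": "analytics"},
)
print(model_uri)  # e.g. "models:/analytics.etl.clean_data/1" (illustrative)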

load_function()

Load a previously registered PySpark transform function with optional validation.

Parameters:

  • name (str): Model name in registry
  • version (int or str): Model version to load (required)
  • validate_input (bool, optional): Whether to validate input DataFrames against stored schema constraints (default: True)
  • strict_validation (bool, optional): If True, treat validation warnings as errors (default: False)

Returns:

  • Callable: The loaded transform function that supports both single and multi-parameter usage:
    • Single param: transform(df)
    • Multi param: transform(df, params={'param1': value1, 'param2': value2})
    • Source inspection: transform.get_source() - Returns the original function source code
    • Function access: transform.get_original_function() - Returns the unwrapped original function

Model Discovery

To discover registered models, use MLflow's native model registry APIs:

import mlflow
client = mlflow.tracking.MlflowClient()
models = client.search_registered_models()
for model in models:
    print(f"Model: {model.name}")
    for version in model.latest_versions:
        print(f"  Version: {version.version}")

Requirements

  • Python 3.11+
  • PySpark 3.0+
  • MLflow 2.22+

Development

# Install development dependencies
pip install -e ".[dev]"

# Run tests
pytest

# Run linting
ruff check --fix
ruff format

License

MIT License

Download files

Source Distribution

pyspark_transform_registry-0.1.0.tar.gz (59.1 kB)

Built Distribution

pyspark_transform_registry-0.1.0-py3-none-any.whl (36.4 kB)

File details

Details for the file pyspark_transform_registry-0.1.0.tar.gz.

File hashes

  • SHA256: 7d237ac90f5ab9e91d69d6687708bfae078c9ca52e0bfbce7e6c0d79f6d80ff0
  • MD5: 4e42739dee29a0455c4f3d46a15eb0b7
  • BLAKE2b-256: 324344d8093c77b746cae6a2488b327fcd22f3dc63e276ed650dc4ac2a6546e8

File details

Details for the file pyspark_transform_registry-0.1.0-py3-none-any.whl.

File hashes

  • SHA256: 39c06f7a2f42b2d99f33df11cc89fda608756e965ffe02ecdf8d91d43814deeb
  • MD5: 4b8147e43561ebbad2492f11492e6f7a
  • BLAKE2b-256: 56a30efa56a76de8b5d615b09c7a0d28114d97b761e9a71c3d689bbd32f8a9b9
