A Python library to ease MLOps for Dataproc customers

Dataproc ML

Public Preview Disclaimer

Interfaces and functionality are subject to change. This library is not recommended for production-critical applications without thorough testing and a clear understanding of the potential risks.

Dataproc ML is a Python library that simplifies distributed ML inference on Google Cloud Dataproc. It provides high-level handlers to run PyTorch and Vertex AI Gemini models at scale using Apache Spark, without the complexity of manual model distribution and batch processing.

Installation

Install the library with pip:

pip install dataproc-ml
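
To confirm the install from Python, the standard library's importlib.metadata can report the installed version. The sketch below is only an illustration: the helper name installed_version is hypothetical and not part of dataproc-ml, and the distribution name is taken from the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# "dataproc-ml" is the distribution name used in the pip command above.
# Prints a version string if the package is installed, otherwise None.
print(installed_version("dataproc-ml"))
```

Returning None instead of raising keeps the check usable in setup scripts that need to branch on whether the package is present.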

Usage Examples

The two examples below show how to use the handlers for distributed inference on a Spark DataFrame.

Generative AI (Gemini) Model Inference

Note: GenAiModelHandler makes API calls to Vertex AI, which incur costs. Review the Vertex AI Generative AI pricing before running it.

Use Google's Gemini models to perform generative tasks on your data. This example uses a prompt template to ask for the capital of countries listed in a Spark DataFrame.

from pyspark.sql import SparkSession
from google.cloud.dataproc_ml.inference import GenAiModelHandler

spark = SparkSession.builder.getOrCreate()

# Create a sample DataFrame
data = [("USA",), ("France",), ("Japan",)]
input_df = spark.createDataFrame(data, ["country"])

# The {country} placeholder in the prompt is filled from the 'country' column
result_df = (
    GenAiModelHandler()
    .prompt("What is the capital of {country}?")
    .output_col("capital_city")
    .transform(input_df)
)

result_df.show()
# +-------+----------------+
# |country|capital_city    |
# +-------+----------------+
# |USA    |Washington, D.C.|
# |France |Paris           |
# |Japan  |Tokyo           |
# +-------+----------------+
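
To make the prompt template concrete, the sketch below expands the same template once per row in plain Python. It is not the library's implementation, which this README does not show; it assumes the placeholder syntax behaves like Python's str.format, and the helper names template_fields and render_prompts are hypothetical.

```python
from string import Formatter


def template_fields(template: str) -> list[str]:
    """Return the placeholder names that appear in a format-style template."""
    return [name for _, name, _, _ in Formatter().parse(template) if name]


def render_prompts(template: str, rows: list[dict]) -> list[str]:
    """Fill the template once per row, looking up values by column name."""
    return [template.format(**row) for row in rows]


template = "What is the capital of {country}?"
# Plain dicts stand in for DataFrame rows in this illustration.
rows = [{"country": "USA"}, {"country": "France"}, {"country": "Japan"}]

print(template_fields(template))  # the column names the template depends on
for prompt in render_prompts(template, rows):
    print(prompt)
```

Listing the placeholder names first is a cheap way to check that every column the template references actually exists in the DataFrame before any paid API call is made.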

PyTorch Model Inference

Run distributed inference using a pre-trained PyTorch model stored in Google Cloud Storage. This example builds a small Spark DataFrame, input_df, with a features column of numeric vectors; in practice the column could hold image tensors or other numerical data.

from pyspark.sql import SparkSession
from google.cloud.dataproc_ml.inference import PyTorchModelHandler

spark = SparkSession.builder.getOrCreate()

data = [([0.1, 0.2, 0.3],), ([0.4, 0.5, 0.6],), ([0.7, 0.8, 0.9],)]
input_df = spark.createDataFrame(data, ["features"])

# Path to your saved PyTorch model in GCS
model_gcs_path = "gs://your-bucket/path/to/model.pt"

# Apply the model for inference
result_df = (
    PyTorchModelHandler()
    .model_path(model_gcs_path)
    .input_cols("features")
    .transform(input_df)
)

result_df.show()
# +------------------+--------------------+
# |          features|         predictions|
# +------------------+--------------------+
# |[0.1, 0.2, 0.3]   |[0.543, 0.457]      |
# |[0.4, 0.5, 0.6]   |[0.621, 0.379]      |
# |[0.7, 0.8, 0.9]   |[0.789, 0.211]      |
# +------------------+--------------------+
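
A common follow-up step is turning each predictions array into a single label. The sketch below assumes each array holds one score per class, which this README does not state, and the helper name predicted_class is hypothetical.

```python
def predicted_class(scores: list[float]) -> int:
    """Return the index of the largest score (ties resolve to the first index)."""
    if not scores:
        raise ValueError("scores must not be empty")
    return max(range(len(scores)), key=scores.__getitem__)


# Values copied from the example output above.
predictions = [[0.543, 0.457], [0.621, 0.379], [0.789, 0.211]]
print([predicted_class(p) for p in predictions])
```

On a DataFrame, the same function could be wrapped with pyspark.sql.functions.udf and applied to the predictions column.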

Documentation

For more detailed information on the available handlers and their configurations, please refer to our official documentation.

Contributing

Contributions are welcome! Please see contributing.md for details on how to set up your development environment, run linters/tests, etc.

License

This project is licensed under the Apache 2.0 License. See the LICENSE file for more details.

Download files

Source Distribution

dataproc_ml-1.0.0rc1.tar.gz (20.7 kB)

Built Distribution

dataproc_ml-1.0.0rc1-py3-none-any.whl (27.8 kB)
