
Allow scikit-learn predictions to run on database systems in pure SQL.


orbital

Convert scikit-learn pipelines into SQL queries for execution in a database without the need for a Python environment.

See the examples directory for example pipelines, and the Documentation for more details.

Warning:

This is a work in progress.
You might encounter bugs or missing features.

Note:

Not all transformations and models can be represented as SQL queries,
so orbital might not be able to implement the specific pipeline you are using.

Getting Started

Install orbital:

$ pip install orbital
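
The example below also runs the generated query with DuckDB. DuckDB is assumed here to be a separate, optional install rather than a dependency of orbital, so add it if it is not already in your environment:

$ pip install duckdb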

Prepare some data:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

COLUMNS = ["sepal.length", "sepal.width", "petal.length", "petal.width"]

iris = load_iris(as_frame=True)
iris_x = iris.data.set_axis(COLUMNS, axis=1)

# SQL and orbital don't like dots in column names, replace them with underscores
iris_x.columns = COLUMNS = [cname.replace(".", "_") for cname in COLUMNS]

X_train, X_test, y_train, y_test = train_test_split(
    iris_x, iris.target, test_size=0.2, random_state=42
)
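
As a quick sanity check (not part of the original example), you can confirm that the renamed columns made it into the training data:

# The dotted names should now be underscore-separated
print(X_train.columns.tolist())
# ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']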

Define a Scikit-Learn pipeline and train it:

from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipeline = Pipeline(
    [
        ("preprocess", ColumnTransformer([("scaler", StandardScaler(with_std=False), COLUMNS)],
                                        remainder="passthrough")),
        ("linear_regression", LinearRegression()),
    ]
)
pipeline.fit(X_train, y_train)
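
If you want to see the fitted parameters that orbital will later translate into SQL, you can inspect the regression step directly (an optional check, not part of the original example):

# These values reappear in the parsed pipeline and in the generated query below
regression = pipeline.named_steps["linear_regression"]
print(regression.coef_)       # one weight per (centered) feature
print(regression.intercept_)  # the constant term added in the SQL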

Convert the pipeline to orbital:

import orbital
import orbital.types

orbital_pipeline = orbital.parse_pipeline(pipeline, features={
    "sepal_length": orbital.types.DoubleColumnType(),
    "sepal_width": orbital.types.DoubleColumnType(),
    "petal_length": orbital.types.DoubleColumnType(),
    "petal_width": orbital.types.DoubleColumnType(),
})
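
Since every feature in this example is a double, the same mapping can also be built from the column list instead of being spelled out by hand (a minimal sketch using the DoubleColumnType shown above):

# Equivalent to the explicit dictionary above
features = {cname: orbital.types.DoubleColumnType() for cname in COLUMNS}
orbital_pipeline = orbital.parse_pipeline(pipeline, features=features)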

You can print the pipeline to see the result:

>>> print(orbital_pipeline)

ParsedPipeline(
    features={
        sepal_length: DoubleColumnType()
        sepal_width: DoubleColumnType()
        petal_length: DoubleColumnType()
        petal_width: DoubleColumnType()
    },
    steps=[
        merged_columns=Concat(
            inputs: sepal_length, sepal_width, petal_length, petal_width,
            attributes: 
             axis=1
        )
        variable1=Sub(
            inputs: merged_columns, Su_Subcst=[5.809166666666666, 3.0616666666666665, 3.7266666666666666, 1.18333333...,
            attributes: 
        )
        multiplied=MatMul(
            inputs: variable1, coef=[-0.11633479416518255, -0.05977785171980231, 0.25491374699772246, 0.5475959...,
            attributes: 
        )
        resh=Add(
            inputs: multiplied, intercept=[0.9916666666666668],
            attributes: 
        )
        variable=Reshape(
            inputs: resh, shape_tensor=[-1, 1],
            attributes: 
        )
    ],
)

Now we can generate the SQL from the pipeline:

sql = orbital.export_sql("DATA_TABLE", orbital_pipeline, dialect="duckdb")
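
The dialect argument controls the flavour of SQL that is emitted. As a sketch, targeting PostgreSQL instead of DuckDB would only change the dialect name (this assumes "postgres" is among the accepted dialect identifiers):

# Assumption: dialect names other than "duckdb" (e.g. "postgres") are accepted
postgres_sql = orbital.export_sql("DATA_TABLE", orbital_pipeline, dialect="postgres")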

And check the resulting query:

>>> print(sql)

SELECT ("t0"."sepal_length" - 5.809166666666666) * -0.11633479416518255 + 0.9916666666666668 +  
       ("t0"."sepal_width" - 3.0616666666666665) * -0.05977785171980231 + 
       ("t0"."petal_length" - 3.7266666666666666) * 0.25491374699772246 + 
       ("t0"."petal_width" - 1.1833333333333333) * 0.5475959809777828 
AS "variable" FROM "DATA_TABLE" AS "t0"

Once the SQL is generated, you can use it to run the pipeline on a database. From here on, the SQL can be exported and reused in other places:

>>> print("\nPrediction with SQL")
>>> duckdb.register("DATA_TABLE", X_test)
>>> print(duckdb.sql(sql).df()["variable"][:5].to_numpy())

Prediction with SQL
[ 1.23071715 -0.04010441  2.21970287  1.34966889  1.28429336]
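
Because the result is plain SQL, it can also be saved and reused entirely outside of Python, for example in a scheduled database job (a trivial sketch; the file name is arbitrary):

# Write the generated query to a file for use by other tooling
with open("iris_prediction.sql", "w") as f:
    f.write(sql)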

We can verify that the prediction matches the one produced by scikit-learn by running the scikit-learn pipeline on the same data:

>>> print("\nPrediction with SciKit-Learn")
>>> print(pipeline.predict(X_test)[:5])

Prediction with SciKit-Learn
[ 1.23071715 -0.04010441  2.21970287  1.34966889  1.28429336]
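
Instead of eyeballing the two arrays, the comparison can be made explicit (a small sketch using numpy; it assumes DuckDB returns the rows in the same order as X_test, as the printed output above suggests):

import numpy as np

sql_predictions = duckdb.sql(sql).df()["variable"].to_numpy()
sklearn_predictions = pipeline.predict(X_test)
# Both paths should agree up to floating-point noise
assert np.allclose(sql_predictions, sklearn_predictions)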

Supported Models

orbital currently supports the following models (a classifier sketch follows the list):

  • Linear Regression
  • Logistic Regression
  • Lasso Regression
  • Elastic Net
  • Decision Tree Regressor
  • Decision Tree Classifier
  • Random Forest Classifier
  • Gradient Boosting Regressor
  • Gradient Boosting Classifier
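
The same parse/export flow applies to the other estimators above. As a hedged sketch reusing the iris data from the example (the exact output columns of a classifier query may differ from the regression case), a DecisionTreeClassifier could be converted like this:

from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

import orbital
import orbital.types

# Assumption: classifiers go through the same parse_pipeline/export_sql API
clf_pipeline = Pipeline([("tree", DecisionTreeClassifier(max_depth=3))])
clf_pipeline.fit(X_train, y_train)

clf_orbital = orbital.parse_pipeline(clf_pipeline, features={
    cname: orbital.types.DoubleColumnType() for cname in COLUMNS
})
clf_sql = orbital.export_sql("DATA_TABLE", clf_orbital, dialect="duckdb")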

Testing

Set up the testing environment:

$ uv sync --no-dev --extra test

Run Tests:

$ uv run pytest -v

Try Examples:

$ uv run examples/pipeline_lineareg.py

Development

Set up a development environment:

$ uv sync
