
Convert scikit-learn models and pipelines into executable SQL.

Project description

sqlearn

sqlearn is a Python library that converts scikit-learn models and pipelines into native SQL queries. This lets you run machine learning inference directly inside your database server (e.g., DuckDB, Postgres) without extracting data or managing a separate inference service.

Project Purpose

The goal of this project is to implement a generic converter that translates trained sklearn objects into SQL statements. This approach enables:

  • In-database inference: Run predictions where your data lives, with no data movement and no round-trip to a separate service.
  • Simplified architecture: Remove the need for pickle files and separate ML microservices.
  • Portability: Generate SQL that can be executed on various SQL dialects (powered by sqlglot).

Currently, we support:

  • Linear Models (LinearRegression, Ridge, Lasso, ElasticNet, SGDRegressor) into SQL arithmetic expressions or CASE statements.
  • StandardScaler preprocessing.
  • OneHotEncoder preprocessing (handle_unknown="ignore", drop=None).
  • ColumnTransformer combining numeric and categorical branches.
  • Pipeline with ColumnTransformer + linear regressor as final estimator.
  • DecisionTreeClassifier and RandomForestClassifier.
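To give a flavor of what these conversions look like, here is a rough, dependency-free sketch of how a fitted StandardScaler and OneHotEncoder could be folded into SQL expressions. The helper names are illustrative only, not part of sqlearn's API:

```python
def scaler_to_sql(col: str, mean: float, scale: float) -> str:
    """Fold StandardScaler's (x - mean) / scale into a SQL expression."""
    return f"(({col} - {mean}) / {scale})"

def onehot_to_sql(col: str, categories: list[str]) -> list[str]:
    """One CASE expression per category; unseen values match no branch and
    fall through to 0, mirroring handle_unknown='ignore'."""
    return [
        f"CASE WHEN {col} = '{cat}' THEN 1 ELSE 0 END AS {col}_{cat}"
        for cat in categories
    ]

print(scaler_to_sql("age", mean=1.0, scale=1.0))
for expr in onehot_to_sql("city", ["la", "ny"]):
    print(expr)
```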

Getting Started

Prerequisites

  • Python >= 3.12
  • uv for dependency management

Installation

pip install scikit-sqlearn

For local development:

uv pip install -e .

Running Tests

We use pytest for testing. You can run the test suite using the configured script in pyproject.toml:

uv run test

Or directly via pytest:

uv run pytest tests/ -v

Development Rules

  1. Test First: Always add tests for new features or bug fixes.
  2. Integrity: Do not modify existing tests just to make them pass. If a test fails, fix the implementation, not the test (unless the test itself is incorrect).
  3. Verification: Ensure all tests pass before committing or submitting changes.
    uv run test
    
  4. Usage Examples: When adding new modules, include an if __name__ == "__main__": block with a runnable example so functionality can be verified quickly, and confirm it actually runs.
  5. Full Suite: Make sure all tests pass, not just the newly added ones.

Examples

Linear Model to SQL

import numpy as np
from sklearn.linear_model import LinearRegression
from sqlearn.linear_model import LinearModelConverter

# Use fixed coefficients so the output SQL is deterministic
model = LinearRegression()
model.coef_ = np.array([2.0, -3.0])
model.intercept_ = 5.0

converter = LinearModelConverter(model)
sql = converter.to_sql(feature_names=["col1", "col2"], table_name="my_table")

print(sql)
# SELECT (2 * col1) + (-3 * col2) + 5 AS prediction FROM my_table
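Because the output is plain SQL, the prediction can be sanity-checked on any engine. A minimal sketch using the stdlib sqlite3 module, with the query hand-written to mirror the output above:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE my_table (col1 REAL, col2 REAL)")
con.execute("INSERT INTO my_table VALUES (1.0, 1.0), (0.0, 2.0)")

sql = "SELECT (2 * col1) + (-3 * col2) + 5 AS prediction FROM my_table"
# Each row evaluates 2*col1 - 3*col2 + 5 inside the database
print(con.execute(sql).fetchall())
```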

Pipeline to SQL (StandardScaler + OneHotEncoder + LinearRegression)

import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

from sqlearn.pipeline import PipelineConverter

X = pd.DataFrame({"age": [0, 2], "city": ["la", "ny"]})

preprocessor = ColumnTransformer(
    [
        ("num", StandardScaler(), ["age"]),
        ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), ["city"]),
    ]
).fit(X)

model = LinearRegression()
model.coef_ = np.array([2.0, 10.0, -5.0])  # num__age, cat__city_la, cat__city_ny
model.intercept_ = 1.0

pipe = Pipeline([("preprocessor", preprocessor), ("model", model)])

sql = PipelineConverter(pipe).to_sql(feature_names=["age", "city"], table_name="people")

print(sql)
# WITH transformed AS (
#   SELECT age - 1 AS num__age,
#          CASE WHEN city = 'la' THEN 1 ELSE 0 END AS cat__city_la,
#          CASE WHEN city = 'ny' THEN 1 ELSE 0 END AS cat__city_ny
#   FROM people
# )
# SELECT (2 * num__age) + (10 * cat__city_la) + (-5 * cat__city_ny) + 1 AS prediction
# FROM transformed
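The CTE form above is likewise standard SQL. A sketch running a hand-written copy of it on stdlib sqlite3 (the scaler folds to age - 1 because the fitted mean and scale are both 1):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE people (age REAL, city TEXT)")
con.execute("INSERT INTO people VALUES (0.0, 'la'), (2.0, 'ny')")

sql = """
WITH transformed AS (
  SELECT age - 1 AS num__age,
         CASE WHEN city = 'la' THEN 1 ELSE 0 END AS cat__city_la,
         CASE WHEN city = 'ny' THEN 1 ELSE 0 END AS cat__city_ny
  FROM people
)
SELECT (2 * num__age) + (10 * cat__city_la) + (-5 * cat__city_ny) + 1 AS prediction
FROM transformed
"""
print(con.execute(sql).fetchall())
```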

DecisionTreeClassifier to SQL

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sqlearn.tree_model import DecisionTreeClassifierConverter

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
clf = DecisionTreeClassifier(max_depth=1, random_state=42).fit(X, y)

sql = DecisionTreeClassifierConverter(clf).to_sql(
    feature_names=["x0"],
    table_name="input_data",
)

print(sql)
# SELECT CASE WHEN (CASE WHEN x0 <= 1.5 THEN 1 ELSE 0 END) >= ...
# ... THEN 0 ELSE 1 END AS prediction FROM input_data
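The nested CASE emitted for trees also runs anywhere. Here is a hand-simplified equivalent of the depth-1 stump above (split at x0 <= 1.5), checked on stdlib sqlite3:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE input_data (x0 REAL)")
con.execute("INSERT INTO input_data VALUES (0.0), (1.0), (2.0), (3.0)")

# Simplified equivalent of the converter's nested CASE for a depth-1 tree
sql = "SELECT CASE WHEN x0 <= 1.5 THEN 0 ELSE 1 END AS prediction FROM input_data"
print([row[0] for row in con.execute(sql)])  # → [0, 0, 1, 1]
```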

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

scikit_sqlearn-0.1.0.tar.gz (18.4 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

scikit_sqlearn-0.1.0-py3-none-any.whl (17.3 kB)

Uploaded Python 3

File details

Details for the file scikit_sqlearn-0.1.0.tar.gz.

File metadata

  • Download URL: scikit_sqlearn-0.1.0.tar.gz
  • Size: 18.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for scikit_sqlearn-0.1.0.tar.gz
  • SHA256: 04d6750937cd9103f74ada0b855bfa9775dd117254d198acc81ddcfab940093e
  • MD5: 18f5ddf745572824afafe601caf7feab
  • BLAKE2b-256: 9d1959c0e37fde918b26c3720f924254c12397ce4de9da3be854fc96d31c81a7


Provenance

The following attestation bundles were made for scikit_sqlearn-0.1.0.tar.gz:

Publisher: publish-pypi.yml on sofeikov/sqlearn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file scikit_sqlearn-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: scikit_sqlearn-0.1.0-py3-none-any.whl
  • Size: 17.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for scikit_sqlearn-0.1.0-py3-none-any.whl
  • SHA256: 8d00e8df4f3c380904896260afb49bcd1ee065088e226cba8eedb4c2f54e222c
  • MD5: 564c3208af47259e08703c9e36a5288f
  • BLAKE2b-256: bc1b2ef7b6ada3ef96741370b486fcb949baea22647af4d697e2b60ceb0c5a03


Provenance

The following attestation bundles were made for scikit_sqlearn-0.1.0-py3-none-any.whl:

Publisher: publish-pypi.yml on sofeikov/sqlearn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
