High-performance single-row inference compiler for scikit-learn pipelines with 2-10x speedup

Stripje - Make sklearn pipelines lean

Python 3.9+ | License: MIT

Speed up your scikit-learn pipelines for single-row predictions by 2-10x!

Stripje is a high-performance compiler that converts trained scikit-learn pipelines into optimized Python functions, eliminating numpy overhead for single-row inference.

🚀 Why Stripje?

  • ⚡ 2-10x faster single-row predictions, depending on the pipeline complexity
  • 🔧 Drop-in replacement - works with your existing pipelines
  • 🎯 Zero configuration - just compile and use
  • 🛠️ Production ready - optimized for real-time inference

📦 Installation

pip install stripje

Or with uv (recommended):

uv add stripje

⚡ Quick Start

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from stripje import compile_pipeline

# 1. Create and fit your pipeline as usual
# (X_train and y_train are your own training data; four features in this example)
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', LogisticRegression())
])
pipeline.fit(X_train, y_train)

# 2. Compile for fast single-row inference
fast_predict = compile_pipeline(pipeline)

# 3. Get predictions up to 10x faster!
test_row = [1.2, -0.5, 0.8, -1.1]
prediction = fast_predict(test_row)  # Much faster than pipeline.predict([test_row])

🎯 The Problem We Solve

Standard scikit-learn pipelines are slow for single predictions because they are optimized for batch processing. When you predict one row at a time (for example, in a web API), each call still pays for input validation and NumPy array creation at every step, and that fixed overhead outweighs the arithmetic on a single row.

Stripje compiles your trained pipeline into a specialized function that:

  • ✅ Extracts fitted parameters once
  • ✅ Eliminates array creation overhead
  • ✅ Uses native Python operations
  • ✅ Maintains identical results for supported components (see Limitations for the transformers that use approximations)
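To make the idea concrete, here is a minimal hand-written sketch of what such a specialized function could look like for a StandardScaler followed by a binary LogisticRegression. It is not Stripje's generated code. The parameter values are made up for illustration; in practice they would be read once from the fitted attributes (`scaler.mean_`, `scaler.scale_`, `clf.coef_[0]`, `clf.intercept_[0]`, `clf.classes_`).

```python
import math

# Illustrative values only; a real compiler would read these once from
# scaler.mean_, scaler.scale_, clf.coef_[0], clf.intercept_[0], clf.classes_.
MEAN = [0.5, -1.0, 2.0]
SCALE = [2.0, 0.5, 4.0]
COEF = [1.5, -2.0, 0.25]
INTERCEPT = -0.1
CLASSES = [0, 1]


def make_predict_one(mean, scale, coef, intercept, classes):
    """Return a closure that scales one row and applies a binary logistic model."""
    params = list(zip(mean, scale, coef))  # extracted once, reused on every call

    def predict_one(row):
        # Standardize and accumulate the decision value in one pass over plain floats.
        z = intercept
        for value, (m, s, w) in zip(row, params):
            z += w * (value - m) / s
        # A positive decision value maps to the second class label.
        return classes[1] if z > 0 else classes[0]

    return predict_one


def predict_proba_one(row, mean, scale, coef, intercept):
    """Probability of the positive class for one row (sigmoid of the decision value)."""
    z = intercept + sum(w * (v - m) / s for v, m, s, w in zip(row, mean, scale, coef))
    return 1.0 / (1.0 + math.exp(-z))


predict_one = make_predict_one(MEAN, SCALE, COEF, INTERCEPT, CLASSES)
```

No array is allocated per call: the row stays a plain Python list and the fitted parameters live in the closure.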

📊 Performance Comparison

import time
from sklearn.datasets import make_classification
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from stripje import compile_pipeline

# Setup
X, y = make_classification(n_samples=1000, n_features=20)
pipeline = Pipeline([('scaler', StandardScaler()), ('clf', LogisticRegression())])
pipeline.fit(X, y)
fast_predict = compile_pipeline(pipeline)

test_row = X[0].tolist()

# Benchmark single-row predictions
# (time.perf_counter() is a monotonic, high-resolution clock suited to timing)
def benchmark_standard():
    start = time.perf_counter()
    for _ in range(1000):
        pipeline.predict([test_row])
    return time.perf_counter() - start

def benchmark_compiled():
    start = time.perf_counter()
    for _ in range(1000):
        fast_predict(test_row)
    return time.perf_counter() - start

standard_time = benchmark_standard()
compiled_time = benchmark_compiled()
speedup = standard_time / compiled_time

print(f"Standard pipeline: {standard_time:.3f}s")
print(f"Compiled pipeline: {compiled_time:.3f}s")
print(f"Speedup: {speedup:.1f}x faster!")
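A single timed loop can be noisy. As an alternative sketch, the standard library's `timeit` module can repeat the measurement and report a per-call figure. The helper name `seconds_per_call` is hypothetical and not part of Stripje.

```python
import statistics
import timeit


def seconds_per_call(fn, *args, number=1000, repeat=5):
    """Median per-call time in seconds over `repeat` runs of `number` calls each."""
    timer = timeit.Timer(lambda: fn(*args))
    totals = timer.repeat(repeat=repeat, number=number)
    return statistics.median(totals) / number
```

With the objects from the benchmark above, `seconds_per_call(pipeline.predict, [test_row])` and `seconds_per_call(fast_predict, test_row)` give the two numbers to compare.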

🔧 Supported Components

Stripje supports the most commonly used scikit-learn components:

🔄 Transformers

  • Scalers: StandardScaler, MinMaxScaler, RobustScaler, MaxAbsScaler
  • Encoders: OneHotEncoder, OrdinalEncoder, LabelEncoder
  • Other: Normalizer, QuantileTransformer, SelectKBest

🎯 Estimators

  • Classification: LogisticRegression, RandomForestClassifier, DecisionTreeClassifier, GaussianNB
  • Regression: LinearRegression

🏗️ Composite

  • ColumnTransformer - Full support with nested compilation

More components coming soon! See Contributing to request or add support.

📖 More Examples

Complex Pipeline with ColumnTransformer

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.ensemble import RandomForestClassifier
from stripje import compile_pipeline

# Create a complex pipeline
preprocessor = ColumnTransformer([
    ('num', StandardScaler(), ['age', 'income']),
    ('cat', OneHotEncoder(), ['category', 'region'])
])

pipeline = Pipeline([
    ('preprocessor', preprocessor),
    ('classifier', RandomForestClassifier(n_estimators=10))
])

# Fit and compile (X_train must be a pandas DataFrame here,
# because the ColumnTransformer selects columns by name)
pipeline.fit(X_train, y_train)
fast_predict = compile_pipeline(pipeline)

# Single-row prediction
row = [25, 50000, 'A', 'North']  # [age, income, category, region]
prediction = fast_predict(row)

Real-World API Usage

from flask import Flask, request, jsonify
import joblib
from stripje import compile_pipeline

app = Flask(__name__)

# Load and compile your model once at startup
model = joblib.load('trained_pipeline.pkl')
fast_predict = compile_pipeline(model)

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json['features']
    prediction = fast_predict(data)  # Super fast!
    return jsonify({'prediction': prediction.tolist()})

🚫 Limitations

  • Input must be lists/arrays (no pandas DataFrames directly)
  • No sparse matrix support
  • Some transformers use approximations (e.g., QuantileTransformer)
  • Only listed components are supported
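Because the compiled function expects a plain list, records that arrive as dictionaries (a JSON payload, for instance) have to be flattened in the column order used at fit time. A minimal sketch follows; `row_from_record` and `FEATURE_ORDER` are hypothetical names, not part of Stripje.

```python
def row_from_record(record, feature_order):
    """Build a plain list in the column order the pipeline was fitted with."""
    missing = [name for name in feature_order if name not in record]
    if missing:
        raise KeyError(f"missing features: {missing}")
    return [record[name] for name in feature_order]


# Hypothetical column order, matching the ColumnTransformer example above.
FEATURE_ORDER = ["age", "income", "category", "region"]

row = row_from_record(
    {"region": "North", "age": 25, "income": 50000, "category": "A"},
    FEATURE_ORDER,
)
```

For a pandas DataFrame `df`, `df.iloc[0].tolist()` produces the same kind of list, provided the columns are already in fit order.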

📚 API Reference

compile_pipeline(pipeline)

Compiles a fitted scikit-learn pipeline into a fast single-row prediction function.

Args:

  • pipeline: A fitted scikit-learn Pipeline

Returns:

  • Function that takes a single row (list/array) and returns predictions

Raises:

  • ValueError: If pipeline contains unsupported components
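Since unsupported components raise ValueError, a service can fall back to the regular pipeline when compilation fails. The sketch below takes the compile function as an argument so it stays self-contained. `compile_or_fallback` is a hypothetical helper, and it assumes only the behavior documented above (a callable is returned, or ValueError is raised).

```python
def compile_or_fallback(pipeline, compile_fn):
    """Return (predict_one, compiled_ok) for a fitted pipeline.

    `compile_fn` is expected to behave like compile_pipeline as documented:
    return a callable, or raise ValueError for unsupported components.
    """
    try:
        return compile_fn(pipeline), True
    except ValueError:
        # Fall back to the batch API, wrapping the row in a one-row batch.
        def predict_one(row):
            return pipeline.predict([row])[0]

        return predict_one, False
```

Typical usage would be `predict_one, compiled = compile_or_fallback(pipeline, compile_pipeline)`. Whether the compiled function returns a scalar or a one-element array is not specified here, so check that both paths return the same shape before relying on the fallback.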

get_supported_transformers()

Returns a list of all supported transformer and estimator classes.
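One possible use of that list is a pre-flight check before compiling. The sketch below is an assumption-laden illustration: `unsupported_steps` is a hypothetical helper, it inspects only the top-level `pipeline.steps` (not steps nested inside a ColumnTransformer), and it assumes support is keyed on the exact class.

```python
def unsupported_steps(pipeline, supported_classes):
    """Return (step name, class name) pairs for top-level steps that are not supported."""
    supported = set(supported_classes)
    return [
        (name, type(step).__name__)
        for name, step in pipeline.steps
        # Exact type match: a subclass is treated as unsupported (an assumption).
        if type(step) not in supported
    ]
```

It would be called as `unsupported_steps(pipeline, get_supported_transformers())`, assuming `get_supported_transformers` is importable from `stripje` alongside `compile_pipeline`.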

📁 Examples & Benchmarks

Check out the examples/ directory for:

  • simple_example.py - Basic usage
  • benchmark.py - Performance comparisons
  • comprehensive_benchmark.py - Detailed benchmarks
  • profiler_demo.py - Profiling tools

🔌 Extending Support

Want to add support for a new transformer? It's easy:

from stripje import register_step_handler

@register_step_handler(YourTransformer)
def handle_your_transformer(step):
    # Extract parameters from the fitted step
    param1 = step.param1_
    param2 = step.param2_

    def transform_one(x):
        # Implement single-row transformation logic
        result = []
        for val in x:
            # Your transformation logic here
            transformed_val = val * param1 + param2
            result.append(transformed_val)
        return result

    return transform_one
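As a concrete instance of the template above, here is a sketch of a handler for a made-up transformer that clips each feature to learned bounds. `ClipTransformer` and its `lower_` and `upper_` attributes are hypothetical stand-ins. The handler is written as a plain function so the example runs without Stripje installed.

```python
class ClipTransformer:
    """Hypothetical fitted transformer that clips each feature to learned bounds."""

    def __init__(self, lower, upper):
        # Stand-ins for attributes a real fit() would learn.
        self.lower_ = list(lower)
        self.upper_ = list(upper)


def handle_clip_transformer(step):
    """Handler in the shape shown above: fitted step in, single-row function out."""
    # Pair up the bounds once so each call only iterates over plain Python values.
    bounds = list(zip(step.lower_, step.upper_))

    def transform_one(x):
        return [min(max(val, lo), hi) for val, (lo, hi) in zip(x, bounds)]

    return transform_one


# To register it, the decorator from the template above would be applied:
#   register_step_handler(ClipTransformer)(handle_clip_transformer)
transform_one = handle_clip_transformer(ClipTransformer([0.0, -1.0], [1.0, 1.0]))
```

All parameter extraction happens once inside the handler, so the returned `transform_one` does no attribute lookups on the fitted step at prediction time.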

🤝 Contributing

Contributions are welcome! Please feel free to submit a pull request or open an issue.

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add some amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

🛠️ Development

Setup Development Environment

  1. Clone the repository:
git clone https://github.com/hadi-gharibi/stripje.git
cd stripje
  2. Install all dependencies (including optional ones for full testing):
uv sync --all-extras
  3. Install pre-commit hooks:
uv run pre-commit install

Code Quality Tools

This project uses modern Python development tools:

  • Ruff: Fast linting, formatting, and import sorting
  • MyPy: Static type checking
  • pre-commit: Automated code quality checks

Run code quality checks:

# Lint and auto-fix issues
uv run ruff check src/ tests/ --fix

# Format code
uv run ruff format src/ tests/

# Type checking
uv run mypy src/

# Run all pre-commit hooks
uv run pre-commit run --all-files

Testing

Run tests:

uv run pytest

Run tests with coverage:

uv run pytest --cov=stripje

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.
