LLM-powered estimators for scikit-learn pipelines
promptlearn
promptlearn brings large language models into your scikit-learn workflow. It looks at your data, reasons about what the inputs and outputs mean, and identifies relevant knowledge about the world. From this it automatically builds standalone, executable Python code that combines the relationships in the original data with that world knowledge, materialized as explicit data about the categorical variables.
📊 Outperforming Traditional Models with Built-In Knowledge
Consider a simple binary classification task: predicting whether an animal is a mammal from features such as its name, weight, and lifespan (python examples/quickstart.py --demo compare --dataset mammal).
Traditional models depend solely on the input features. promptlearn models, by contrast, can draw on the LLM's internal understanding of zoology to form highly accurate rules, pulling in data about known mammals and making that knowledge available in explicit reference tables for subsequent predictions.
| model | accuracy (higher is better) | fit_time_sec | predict_time_sec |
|---|---|---|---|
| promptlearn_o3-mini | 0.94 | 49.11 | 0.0028 |
| promptlearn_o4-mini | 0.86 | 60.96 | 0.0024 |
| promptlearn_gpt-3.5-turbo | 0.66 | 20.25 | 0.0027 |
| promptlearn_gpt-4o | 0.66 | 43.93 | 0.0023 |
| logistic_regression | 0.60 | 0.02 | 0.0010 |
| decision_tree | 0.53 | 0.0014 | 0.0005 |
| gradient_boosting | 0.53 | 0.02 | 0.0011 |
| promptlearn_gpt-4 | 0.40 | 12.49 | 0.0022 |
| dummy | 0.34 | 0.0006 | 0.0001 |
| random_forest | 0.28 | 0.01 | 0.0017 |
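This kind of semantic generalization is where LLM-backed models have a real advantage: with o3-mini and o4-mini, promptlearn beats every classical baseline by a wide margin. Results depend heavily on the underlying LLM, though; gpt-4 scored below logistic regression on this task.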
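To make the idea of "explicit reference tables" concrete, here is a hand-written sketch of the kind of standalone prediction code such a fit could produce. It is not actual promptlearn output: the function name, the feature names (`name`, `weight_kg`, `lifespan_years`), the table contents, and the fallback thresholds are all hypothetical.

```python
# Hypothetical illustration only: hand-written, NOT actual promptlearn output.
# World knowledge is "materialized" as lookup tables baked into plain Python.

KNOWN_MAMMALS = {"dog", "cat", "horse", "whale", "dolphin", "bat", "elephant", "human"}
KNOWN_NON_MAMMALS = {"eagle", "salmon", "crocodile", "frog", "shark", "penguin"}


def predict_is_mammal(name: str, weight_kg: float, lifespan_years: float) -> int:
    """Return 1 if the animal is predicted to be a mammal, else 0."""
    key = name.strip().lower()
    # 1. Reference tables first: exact lookups of known animals.
    if key in KNOWN_MAMMALS:
        return 1
    if key in KNOWN_NON_MAMMALS:
        return 0
    # 2. Fallback heuristic for unseen names (illustrative thresholds only).
    return 1 if weight_kg > 1.0 and lifespan_years > 10 else 0
```

Because a function like this is plain Python, prediction needs no LLM call, which is consistent with the millisecond predict_time_sec values in the table.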
Now compare performance on a regression task where the data contains samples of objects falling from different heights, under different gravity (python examples/quickstart.py --demo compare --dataset fall). This is a classic physics problem, with a well-known equation:
fall_time_s = sqrt((2 * height_m) / gravity_mps2)
promptlearn estimators backed by gpt-4o, o3-mini, and o4-mini recover this exact formula from the dataframe alone and use it to make predictions with zero error (to three decimal places):
| model | mse (lower is better) | fit_time_sec | predict_time_sec |
|---|---|---|---|
| promptlearn_gpt-4o | 0.000 | 2.92 | 0.001 |
| promptlearn_o3-mini | 0.000 | 10.80 | 0.001 |
| promptlearn_o4-mini | 0.000 | 7.96 | 0.001 |
| random_forest | 0.028 | 0.01 | 0.002 |
| gradient_boosting | 0.035 | 0.01 | 0.001 |
| decision_tree | 0.067 | 0.001 | 0.000 |
| linear_regression | 0.498 | 0.001 | 0.000 |
| dummy | 5.273 | 0.001 | 0.000 |
| promptlearn_gpt-3.5-turbo | 18.193 | 3.01 | 0.002 |
| promptlearn_gpt-4 | 855.445 | 2.43 | 0.001 |
No feature engineering was performed. No physics constants were added. The model discovered the rule and applied it directly. Classical regressors, by contrast, approximated a curve but missed the exact structure.
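For reference, here is the target relationship written out as a plain Python function. This is a hand-written sketch of the rule itself, not code generated by promptlearn; the function name is illustrative.

```python
import math


def fall_time_s(height_m: float, gravity_mps2: float) -> float:
    """Time for an object dropped from rest to fall height_m, ignoring air drag."""
    return math.sqrt((2 * height_m) / gravity_mps2)


# On Earth (g = 9.81 m/s^2), a 4.905 m drop takes 1 second.
print(fall_time_s(4.905, 9.81))
```

A tree ensemble has to approximate this square-root curve piecewise from the samples it saw, whereas a model that emits the formula extrapolates correctly to any height and gravity.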
These results highlight the practical benefit of capable LLM-backed models: they learn compact, expressive heuristics and can outperform traditional systems when symbolic insight or background knowledge is essential. Older models such as gpt-3.5-turbo and gpt-4 did not show the same benefit on these tasks.
🤖 Estimators Powered by Language
promptlearn provides scikit-learn-compatible estimators that use LLMs as the modeling engine:
- PromptClassifier – for predicting classes through generalized reasoning
- PromptRegressor – for modeling numeric relationships in data
These estimators follow the same API as other scikit-learn models (fit, predict, score) but operate via dynamic prompt construction and few-shot abstraction.
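The flow described here (fit, build a prompt, generate code, compile it, predict) can be sketched with the standard library alone. This is a toy under stated assumptions, not promptlearn's implementation: `ToyPromptClassifier` and `stub_llm` are hypothetical, the "LLM" is a stub that always returns the same rule, and rows are plain dicts rather than DataFrames.

```python
# Toy sketch only: not promptlearn's implementation.
# The "LLM" is a stub; a real one would be called through a provider client.
from typing import Callable, Dict, List, Sequence

Row = Dict[str, float]


def stub_llm(prompt: str) -> str:
    """Stand-in for an LLM call: returns Python source defining predict(row)."""
    # A real model would read the prompt; this stub always answers with one rule.
    return "def predict(row):\n    return 1 if row['age'] >= 18 else 0\n"


class ToyPromptClassifier:
    """Minimal fit / predict / score estimator following scikit-learn conventions."""

    def __init__(self, llm: Callable[[str], str] = stub_llm):
        self.llm = llm

    def fit(self, X: Sequence[Row], y: Sequence[int]) -> "ToyPromptClassifier":
        # Dynamic prompt construction: column names plus the labelled examples.
        columns = sorted(X[0]) if X else []
        examples = "\n".join(f"{row} -> {label}" for row, label in zip(X, y))
        prompt = (
            f"Columns: {columns}\nExamples:\n{examples}\n"
            "Write a Python function predict(row) that returns the label."
        )
        self.code_ = self.llm(prompt)  # generated source, kept as readable text
        namespace: dict = {}
        exec(self.code_, namespace)  # compile the source into a callable
        self.predict_fn_ = namespace["predict"]
        return self  # scikit-learn convention: fit() returns self

    def predict(self, X: Sequence[Row]) -> List[int]:
        return [self.predict_fn_(row) for row in X]

    def score(self, X: Sequence[Row], y: Sequence[int]) -> float:
        # Mean accuracy, like scikit-learn's default classifier score().
        preds = self.predict(X)
        return sum(p == t for p, t in zip(preds, y)) / len(y)
```

Keeping the generated source as text (`code_`) is what makes the learned rule inspectable; the compiled function is only a cache of it.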
🚀 Try It
Everything runnable lives in a single guided tour, examples/quickstart.py — a menu of self-contained demos. Each makes live LLM calls, so run them one at a time:
python examples/quickstart.py --list # see all the demos
python examples/quickstart.py --demo zero_row # fit on column names only
python examples/quickstart.py --demo titanic --dump artifacts/ # deep tour: generated code, explain(), joblib
python examples/quickstart.py --demo compare --dataset mammal # promptlearn vs sklearn/XGBoost
The demos cover zero-row fitting, .sample(), joblib round-tripping, world-knowledge reasoning, linear/nonlinear/multi-output regression, XOR, GridSearchCV tuning, a large real OpenML dataset, the side-by-side model compare, and the deep titanic walkthrough (generated predict() code, explain(), and artifact dumping).
The compare demo is powered by the reusable promptlearn.compare_models(models, X_train, y_train, X_test, y_test) helper, which works with any mix of promptlearn and sklearn/XGBoost estimators.
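The idea behind a comparison helper is simple: fit each estimator, time fit() and predict() separately, and score the predictions. The sketch below is a hypothetical stand-in, not the actual compare_models implementation or signature; `compare`, `accuracy`, and the `MajorityClass` baseline are illustrative names, and any object with fit() and predict() methods will work.

```python
# Hypothetical sketch, not the actual promptlearn.compare_models implementation.
import time
from typing import Any, Callable, Dict, List, Sequence


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


class MajorityClass:
    """Baseline in the spirit of scikit-learn's DummyClassifier(strategy="most_frequent")."""

    def fit(self, X, y):
        labels = list(y)
        self.label_ = max(set(labels), key=labels.count)
        return self

    def predict(self, X):
        return [self.label_] * len(X)


def compare(models: Dict[str, Any], X_train, y_train, X_test, y_test,
            metric: Callable = accuracy, higher_is_better: bool = True) -> List[dict]:
    """Fit and evaluate every model, timing fit() and predict() separately."""
    rows = []
    for name, model in models.items():
        start = time.perf_counter()
        model.fit(X_train, y_train)
        fit_time = time.perf_counter() - start

        start = time.perf_counter()
        y_pred = model.predict(X_test)
        predict_time = time.perf_counter() - start

        rows.append({
            "model": name,
            "score": metric(y_test, y_pred),
            "fit_time_sec": fit_time,
            "predict_time_sec": predict_time,
        })
    # Best model first; pass higher_is_better=False for error metrics such as MSE.
    return sorted(rows, key=lambda r: r["score"], reverse=higher_is_better)
```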
🔌 Choose Your Provider
The LLM provider is selected by the model string and resolved via LiteLLM, so you are not locked into OpenAI:
PromptClassifier(model="gpt-5.5") # OpenAI (the default)
PromptClassifier(model="claude-sonnet-4-6") # Anthropic
PromptClassifier(model="ollama:llama3.1") # local Ollama
API keys are read from the usual per-provider environment variables (OPENAI_API_KEY, ANTHROPIC_API_KEY, …); local providers like Ollama need none.
To change the default model without touching code, set PROMPTLEARN_MODEL (e.g. export PROMPTLEARN_MODEL=gpt-5.4-mini for faster, cheaper runs). An explicit model= argument always takes precedence.
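The precedence rule (explicit model= first, then PROMPTLEARN_MODEL, then the built-in default) amounts to a few lines. `resolve_model` is a hypothetical helper written for illustration, not part of promptlearn's API.

```python
import os
from typing import Optional

DEFAULT_MODEL = "gpt-5.5"  # the default named in this README


def resolve_model(model: Optional[str] = None) -> str:
    """Hypothetical helper: explicit argument > PROMPTLEARN_MODEL > built-in default."""
    if model:  # an explicit model= always wins
        return model
    # Fall back to the environment variable, treating an empty value as unset.
    return os.environ.get("PROMPTLEARN_MODEL") or DEFAULT_MODEL
```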
🕳 Zero-Example Learning
If you call .fit() with no rows — just column names — promptlearn will still return a working model.
This is possible because the LLM can hallucinate a plausible mapping based on:
- Column names
- Prior knowledge
- Type hints or value patterns
This makes rapid prototyping and conceptual modeling trivial.
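A rough way to see why zero rows can still work: the column names alone already form a meaningful task description. The hypothetical `describe_task` helper below only sketches what information remains available; promptlearn's actual prompts are not shown in this README and will differ.

```python
# Hypothetical sketch: promptlearn's real prompts are not documented here.
from typing import Sequence


def describe_task(feature_names: Sequence[str], target_name: str,
                  rows: Sequence[Sequence] = ()) -> str:
    """Build a plain-text task description; rows may be empty."""
    lines = [f"Features: {', '.join(feature_names)}", f"Target: {target_name}"]
    if rows:
        lines.append("Examples (features..., target):")
        lines.extend(", ".join(map(str, row)) for row in rows)
    else:
        # Zero rows: only the names remain, so the LLM must lean on its
        # prior knowledge to propose a rule.
        lines.append("No examples: propose a plausible rule from the names alone.")
    return "\n".join(lines)
```

With `describe_task(["fruit"], "is_citrus")`, the text still tells a model everything it needs to guess that limes and oranges are citrus.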
🧪 Native .sample() Support
You can generate synthetic rows directly from any trained model using .sample(n):
>>> model.sample(3)
fruit is_citrus
Lime 1
Banana 0
Orange 1
This is useful for:
- Understanding what the model believes
- Creating test sets or bootstrapped data
- Building readable examples from internal logic
🔎 Explain the Learned Rule
Call .explain() to get a plain-English description of the heuristic the model
learned — useful for interpretability reporting:
>>> explanation = model.explain()
>>> print(explanation)
Predicts 1 (adult) when `age` is at least 18, otherwise 0.
>>> explanation.features_used
['age']
explain() returns an Explanation object with meta and data dicts (keys
also reachable as attributes) that is JSON round-trippable via to_json() /
Explanation.from_json(...). A bare explain() describes the whole model
(global, and cached so it's deterministic); passing a single row, explain(X),
describes that one prediction (local).
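The described Explanation contract (meta and data dicts, keys reachable as attributes, JSON round-tripping) can be mirrored with a small stand-in class. `ExplanationSketch` and its `summary` and `scope` keys are hypothetical; the real Explanation class may use different fields.

```python
import json
from typing import Any, Dict, Optional


class ExplanationSketch:
    """Hypothetical stand-in mirroring the described contract, not promptlearn's class."""

    def __init__(self, meta: Optional[Dict[str, Any]] = None,
                 data: Optional[Dict[str, Any]] = None):
        self.meta = dict(meta or {})
        self.data = dict(data or {})

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails: fall back to dict keys.
        for d in (self.__dict__.get("data", {}), self.__dict__.get("meta", {})):
            if name in d:
                return d[name]
        raise AttributeError(name)

    def __str__(self) -> str:
        return str(self.data.get("summary", ""))

    def to_json(self) -> str:
        return json.dumps({"meta": self.meta, "data": self.data})

    @classmethod
    def from_json(cls, text: str) -> "ExplanationSketch":
        payload = json.loads(text)
        return cls(meta=payload.get("meta"), data=payload.get("data"))
```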
💾 Save and Reload with joblib
Like any scikit-learn model, promptlearn estimators can be serialized:
import joblib
joblib.dump(model, "model.joblib")
model = joblib.load("model.joblib")
The compiled prediction function is excluded from the saved file and recompiled on load. The heuristic remains intact, interpretable, and ready to use.
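One common way to exclude a compiled function from a saved file and rebuild it on load is the `__getstate__`/`__setstate__` pair, which joblib honours because it builds on pickle. This is a hypothetical sketch of that pattern (`CompiledRule` is an illustrative name), not promptlearn's code.

```python
import pickle
from typing import Any, Dict, List


class CompiledRule:
    """Hypothetical sketch: keep the generated source, drop the compiled function when pickling."""

    def __init__(self, code: str):
        self.code_ = code
        self._compile()

    def _compile(self) -> None:
        namespace: Dict[str, Any] = {}
        exec(self.code_, namespace)
        self.predict_fn_ = namespace["predict"]

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        # Functions created by exec() cannot be pickled by reference, so drop it.
        state.pop("predict_fn_", None)
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._compile()  # rebuild the callable from the saved source

    def predict(self, rows: List[Any]) -> List[Any]:
        return [self.predict_fn_(r) for r in rows]


rule = CompiledRule("def predict(row):\n    return row * 2\n")
restored = pickle.loads(pickle.dumps(rule))  # joblib.dump/load use the same hooks
print(restored.predict([1, 2]))  # [2, 4]
```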
📚 Related Work
Scikit-LLM
Scikit-LLM provides zero- and few-shot classification through template-based prompting.
It is lightweight and NLP-focused.
promptlearn offers a broader modeling philosophy:
| Capability | Scikit-LLM | promptlearn |
|---|---|---|
| Produces runnable Python code | ❌ No | ✅ Yes |
| Regression support | ❌ No | ✅ Yes |
🛠 Development
Install the dev dependencies and enable the git hooks:
pip install -r requirements-dev.txt
pre-commit install
The pre-commit hooks run black and the full
test suite, and both must pass before a commit is allowed. Note the test suite
makes live LLM calls, so it needs a provider API key (e.g. OPENAI_API_KEY).
Bypass the hooks in an emergency with git commit --no-verify.
📁 License
MIT © 2025 Fredrik Linaker
Download files
File details
Details for the file promptlearn-0.4.1.tar.gz.
File metadata
- Download URL: promptlearn-0.4.1.tar.gz
- Upload date:
- Size: 36.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 14ae4e7922022d617112a0df7a9db172626790c0e897b09a2facc4a89697bfc1 |
| MD5 | a3811ee1b5bd540d84a52391f5de849b |
| BLAKE2b-256 | 34f3d41c300aa54f39f5a5954a7250854b3b16ea5784ff4359aebff1d8b5082c |
File details
Details for the file promptlearn-0.4.1-py3-none-any.whl.
File metadata
- Download URL: promptlearn-0.4.1-py3-none-any.whl
- Upload date:
- Size: 26.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 021fc19e4523206b4512320153c245e5f71858a0b93c96a8cd152ec2bc5c49bc |
| MD5 | 1eb692f6fb75d6969d68deeef2c49901 |
| BLAKE2b-256 | 6180f6f11398f5207ed29b70e5449d8c6861ac3b625a3a32b5b939c6645c7bef |
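To check a downloaded file against the SHA256 digests listed above, hash it locally and compare. A minimal sketch using only hashlib; the function names are illustrative.

```python
import hashlib


def file_sha256(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """Compare against a published digest (case-insensitive, whitespace trimmed)."""
    return file_sha256(path) == expected_hex.strip().lower()
```

Pass the path of the downloaded sdist or wheel together with the SHA256 value from that file's table. pip's hash-checking mode (`--require-hashes`) automates the same check at install time.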