
Cross-Validated Decision Tree — scikit-learn-compatible estimators backed by a Rust core

CVDT for Python (scikit-learn compatible)

cvdt ships two estimators that follow the scikit-learn API and are backed by the zero-dependency Rust core:

  • cvdt.CVDTClassifier: ClassifierMixin; predict / predict_proba / predict_log_proba
  • cvdt.CVDTRegressor: RegressorMixin; predict

Both are full BaseEstimators: they work with Pipeline, cross_val_score, GridSearchCV, clone, and get_params / set_params.

How the binding is layered

scikit-learn API  ──  python/cvdt/_estimator.py   (validation, tags, classes_, label encoding)
                          │  numpy arrays
native extension  ──  cvdt._cvdt  (src/python.rs, PyO3)   (numpy → flat Columns → tree)
                          │
Rust core         ──  DecisionTree<Task>  (zero-dependency crate)

All sklearn semantics live in Python; the Rust side stays a thin, fast core. The core crate remains zero-dependency — PyO3 and the numpy crate are pulled in only under the python cargo feature. A plain cargo build / cargo test still compiles nothing but std.

Building

Requires a Rust toolchain and maturin.

pip install maturin numpy scikit-learn
# dev build into the current environment:
maturin develop --release --features python
# or build a wheel:
maturin build --release --features python
pip install target/wheels/cvdt-*.whl

Wheels are built with PyO3's abi3-py38 feature, so a single wheel per platform works across CPython 3.8+.

Quickstart

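import numpy as np
from cvdt import CVDTClassifier, CVDTRegressor

# Toy data so the snippet runs end to end; substitute your own arrays.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
X_train, X_test, y_train, y_test = X[:150], X[150:], y[:150], y[150:]

clf = CVDTClassifier(criterion="gini", cv_folds=5, max_depth=6)
clf.fit(X_train, y_train)
clf.predict(X_test)          # predicted class labels
clf.predict_proba(X_test)    # per-class probabilities
clf.score(X_test, y_test)    # mean accuracy, from ClassifierMixin

# The regressor takes the same X; pass a continuous target in practice.
reg = CVDTRegressor(criterion="mse", mode="fast", n_bins=32)
reg.fit(X_train, y_train)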

Everything composes with sklearn:

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(
    CVDTClassifier(),
    {"max_depth": [4, 8], "aggregator": ["mean", "mean_minus_lambda_std"], "mode": ["strict", "fast"]},
    cv=5,
).fit(X, y)

See examples/sklearn_usage.py.

Parameters

Common to both estimators:

| Parameter | Default | Meaning |
|---|---|---|
| max_depth | 8 | Max split levels; None for unlimited. |
| min_samples_split | 2 | Min samples for a node to be splittable. |
| min_samples_leaf | 1 | Min samples each child must receive. |
| min_impurity_decrease | 0.0 | Min aggregated gain to accept a split. |
| n_bins | 8 | Quantile bins for continuous features. With split_style="threshold" these boundaries are the candidate cut points, so a larger value gives finer thresholds. |
| cv_folds | 5 | K in the K-fold split evaluation. |
| cv_seed | 42 | Fold-shuffle seed (determinism). |
| cv_shuffle | True | Shuffle before folding. |
| mode | "strict" | "strict" (per-fold edges) or "fast" (histogram path). |
| split_style | "threshold" | Continuous split geometry: "threshold" (CART-style ordered cut, x < edge) or "bin" (single quantile bin vs rest). Categorical features are always tested by equality. |
| aggregator | "mean" | mean, median, trimmed_mean, signal_to_noise, mean_minus_lambda_std. |
| agg_frac | 0.1 | Trim fraction for trimmed_mean. |
| agg_eps | 1e-12 | Stabiliser for signal_to_noise. |
| agg_lambda | 1.0 | Std penalty for mean_minus_lambda_std. |
| parallel | False | Evaluate features in parallel. |
| n_threads | 1 | Worker threads when parallel. |
| parallel_min_samples | 512 | Only parallelise nodes at least this large. |
| categorical_features | None | Column indices to treat as categorical. |
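
The aggregator options collapse the K per-fold gains of a candidate split into a single score. The README names the aggregators but does not give their formulas, so the sketch below uses assumed definitions chosen to match the parameter descriptions: agg_frac as the fraction trimmed from each end, agg_eps added to the standard deviation, and agg_lambda scaling it. The function name aggregate and the use of a population standard deviation are illustrative choices, not part of cvdt.

```python
import statistics


def aggregate(gains, aggregator="mean", agg_frac=0.1, agg_eps=1e-12, agg_lambda=1.0):
    """Collapse per-fold gains into one score (assumed formulas, not cvdt's code)."""
    mean = statistics.fmean(gains)
    std = statistics.pstdev(gains)  # population std; the library's choice is not documented
    if aggregator == "mean":
        return mean
    if aggregator == "median":
        return statistics.median(gains)
    if aggregator == "trimmed_mean":
        k = int(len(gains) * agg_frac)  # values dropped from each end (assumes agg_frac < 0.5)
        kept = sorted(gains)[k:len(gains) - k]
        return statistics.fmean(kept)
    if aggregator == "signal_to_noise":
        return mean / (std + agg_eps)
    if aggregator == "mean_minus_lambda_std":
        return mean - agg_lambda * std
    raise ValueError(f"unknown aggregator: {aggregator!r}")


# Five fold gains where one fold is an optimistic outlier.
fold_gains = [0.10, 0.12, 0.08, 0.11, 0.40]
names = ("mean", "median", "trimmed_mean", "signal_to_noise", "mean_minus_lambda_std")
scores = {name: aggregate(fold_gains, name) for name in names}
```

Under these assumptions the plain mean is pulled up by the outlier fold, while median and mean_minus_lambda_std score the same split more conservatively. With only five folds, a trim fraction of 0.1 drops nothing, so trimmed_mean only differs from mean once agg_frac * cv_folds reaches 1.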

Estimator-specific:

  • CVDTClassifier.criterion: "gini" (default) or "entropy".
  • CVDTRegressor.criterion: "mse" (default), "variance", or "mae".
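
For orientation, the criteria named above can be written out using their textbook definitions. This is a sketch under assumptions: the README does not state how cvdt normalises these quantities, which logarithm base it uses for entropy, or how "variance" differs from "mse", so that criterion is left out here. The function names are illustrative.

```python
import math
import statistics
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits (base 2 is an assumption)."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def mse(values):
    """Mean squared deviation from the node mean."""
    mu = statistics.fmean(values)
    return statistics.fmean((v - mu) ** 2 for v in values)


def mae(values):
    """Mean absolute deviation from the node median."""
    med = statistics.median(values)
    return statistics.fmean(abs(v - med) for v in values)


balanced = [0, 0, 1, 1]
targets = [1.0, 2.0, 9.0]
```

A perfectly balanced binary node has Gini impurity 0.5 and entropy of 1 bit. On the regression targets, the single large value inflates mse far more than mae, which is the usual reason to prefer "mae" when targets contain outliers.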

Objective-driven classification

Instead of an impurity proxy, CVDTClassifier can select splits to directly optimise a target metric on the held-out folds:

| Parameter | Default | Meaning |
|---|---|---|
| objective | None | None (use criterion), or precision/recall/f1/fbeta/accuracy. |
| average | "binary" | binary, micro, macro, or weighted. |
| pos_label | 1 | Positive class (index into sorted classes_) when average="binary". |
| beta | 1.0 | β for objective="fbeta". |

# Optimise recall of the positive class on an imbalanced problem:
clf = CVDTClassifier(objective="recall", average="binary", pos_label=1).fit(X, y)

# Macro-F1 for multiclass:
clf = CVDTClassifier(objective="f1", average="macro").fit(X, y)

When objective is set, criterion is ignored. Splits are accepted only when they improve the metric over making the node a leaf, so objective-mode trees are usually shallower and tuned to the metric; set min_impurity_decrease below 0 to allow non-improving splits.
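
The acceptance rule described above can be illustrated with a small sketch. It assumes the comparison is made on hard predictions for held-out samples and that a split is accepted when its improvement strictly exceeds min_impurity_decrease; the README does not spell out either detail, and the helper names fbeta and accept_split are hypothetical.

```python
def fbeta(y_true, y_pred, pos_label=1, beta=1.0):
    """Binary F-beta from hard predictions; 0.0 when the score is undefined."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == pos_label and p == pos_label for t, p in pairs)
    fp = sum(t != pos_label and p == pos_label for t, p in pairs)
    fn = sum(t == pos_label and p != pos_label for t, p in pairs)
    denom = (1 + beta**2) * tp + beta**2 * fn + fp
    return (1 + beta**2) * tp / denom if denom else 0.0


def accept_split(leaf_score, split_score, min_impurity_decrease=0.0):
    """Assumed rule: the held-out improvement must exceed the threshold."""
    return (split_score - leaf_score) > min_impurity_decrease


# Held-out labels reaching one node of an imbalanced problem.
y_val = [1, 1, 0, 0, 0, 0]
leaf_pred = [0, 0, 0, 0, 0, 0]    # a leaf predicts the majority class everywhere
split_pred = [1, 1, 0, 1, 0, 0]   # the split finds both positives, with one false alarm

leaf_f1 = fbeta(y_val, leaf_pred)
split_f1 = fbeta(y_val, split_pred)
decision = accept_split(leaf_f1, split_f1)
```

Here the leaf scores an F1 of 0 because it never predicts the positive class, so the split is accepted. A split that merely ties the leaf is rejected at the default threshold of 0.0 and accepted only once min_impurity_decrease is negative.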

Notes and caveats

  • Interpretability: a fitted CVDTClassifier/CVDTRegressor is a single tree, so est.get_depth() and est.get_n_leaves() (sklearn-style accessors) report its size; the binary tree has 2 * get_n_leaves() - 1 nodes in total. The tree can be inspected and visualised:

    • est.export_text(feature_names=..., class_names=...) — readable rule dump.
    • est.export_graphviz(...) — Graphviz DOT string (render with graphviz, pydot, or dtreeviz).
    • est.get_tree() — the tree as a dict of parallel arrays for custom plots.

    How continuous branches read depends on split_style: with the default "threshold" they are single CART-style cuts (x[f] < edge), and with "bin" they are intervals (lo <= x[f] < hi). The "true" branch is the one that matches (missing values route to "false").

  • categorical_features columns are read as integer category ids. Feed already-integer-encoded values (e.g. from OrdinalEncoder); values are rounded to the nearest non-negative integer, and negatives / non-finite are treated as an unknown category (routed to the "not-in-state" child).

  • Missing values: NaN / non-finite continuous values are accepted and routed right (the "false" branch, per the routing note above), so allow_nan is set in the estimator tags.

  • mode="fast" + criterion="mae" is unsupported (the median has no additive sufficient statistic for the histogram path) and raises on fit. Use mode="strict" for MAE.

  • mode="fast" relaxes the leakage guard slightly: bins are fit once per node instead of per training fold, so bin edges see validation feature values — but impurity is still measured on held-out labels. It is much faster; use "strict" when the stronger guard matters.
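
The routing rules in the notes above (threshold cuts, bin intervals, categorical equality, and missing values going to the "false" branch) can be summarised in a short sketch. This is an illustration of the documented behaviour, not the Rust implementation: the function names, the node fields edge / lo / hi / state, the round-half-up tie-break, and checking for negatives before rounding are all assumptions.

```python
import math

UNKNOWN = None  # hypothetical marker for an unknown category


def category_id(value):
    """Map a raw value to a category id; negatives and non-finite values are unknown."""
    if not math.isfinite(value) or value < 0:
        return UNKNOWN
    return math.floor(value + 0.5)  # nearest integer; tie-breaking is assumed


def goes_true(x, kind, **node):
    """Return True when the sample matches the node's test, False otherwise."""
    if kind == "categorical":
        cid = category_id(x)
        return cid is not UNKNOWN and cid == node["state"]
    if not math.isfinite(x):
        return False  # missing / non-finite continuous values route to "false"
    if kind == "threshold":
        return x < node["edge"]               # CART-style cut: x[f] < edge
    if kind == "bin":
        return node["lo"] <= x < node["hi"]   # interval: lo <= x[f] < hi
    raise ValueError(f"unknown split kind: {kind!r}")


examples = [
    goes_true(0.2, "threshold", edge=0.5),
    goes_true(float("nan"), "threshold", edge=0.5),
    goes_true(0.7, "bin", lo=0.5, hi=1.0),
    goes_true(2.4, "categorical", state=2),
    goes_true(-1.0, "categorical", state=0),
]
```

The explicit finiteness check matters for the threshold case: without it, negative infinity would satisfy x < edge and take the "true" branch, which would contradict the note that non-finite values are routed with the missing ones.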

Tests

tests/test_sklearn.py covers learning, predict_proba, clone / get_params, fast vs. strict, NaN and categorical handling, pipelines, and a full sklearn.utils.estimator_checks.check_estimator sweep. Build the extension first, then run pytest.

Download files

Source distribution:

  • cvdt-0.7.0.tar.gz (86.0 kB)

Built distributions:

  • cvdt-0.7.0-cp38-abi3-win_amd64.whl (290.7 kB): CPython 3.8+, Windows x86-64
  • cvdt-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (403.1 kB): CPython 3.8+, manylinux (glibc 2.17+), x86-64
  • cvdt-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (390.6 kB): CPython 3.8+, manylinux (glibc 2.17+), ARM64
  • cvdt-0.7.0-cp38-abi3-macosx_11_0_arm64.whl (362.5 kB): CPython 3.8+, macOS 11.0+, ARM64
  • cvdt-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl (380.5 kB): CPython 3.8+, macOS 10.12+, x86-64

File details

Details for the file cvdt-0.7.0.tar.gz.

File metadata

  • Download URL: cvdt-0.7.0.tar.gz
  • Upload date:
  • Size: 86.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for cvdt-0.7.0.tar.gz
Algorithm Hash digest
SHA256 5dd64787d94939dafdbddf0a605c178642da18c5837772ea8ad6079adc408803
MD5 6af9d6fbe48860c6637db87a3c98992c
BLAKE2b-256 58cf1354adb5c4f1067429d91fcfb99ccb9e303ef216748bfd44c52b3469b659

Provenance

The following attestation bundles were made for cvdt-0.7.0.tar.gz:

Publisher: release.yml on AdventuresInDataScience/CVDT

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

