Cross-Validated Decision Tree — scikit-learn-compatible estimators backed by a Rust core

CVDT for Python (scikit-learn compatible)

cvdt ships two estimators that follow the scikit-learn API and are backed by the zero-dependency Rust core:

  • cvdt.CVDTClassifier: ClassifierMixin; predict / predict_proba / predict_log_proba
  • cvdt.CVDTRegressor: RegressorMixin; predict

Both are full BaseEstimators: they work with Pipeline, cross_val_score, GridSearchCV, clone, and get_params / set_params.
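
To make the get_params / set_params / clone contract concrete, here is a minimal pure-Python sketch of the convention that scikit-learn estimators follow. TinyEstimator and clone_like are hypothetical names used only for illustration; they are not part of cvdt or scikit-learn, and the real BaseEstimator does more (nested parameters, validation).

```python
import inspect


class TinyEstimator:
    """Hypothetical stand-in that shows the parameter protocol only."""

    def __init__(self, max_depth=8, cv_folds=5):
        # The constructor stores its arguments unchanged under the same names.
        self.max_depth = max_depth
        self.cv_folds = cv_folds

    def get_params(self):
        # Read parameter names from the constructor signature.
        sig = inspect.signature(type(self).__init__)
        names = [name for name in sig.parameters if name != "self"]
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self


def clone_like(est):
    """Build a fresh, unfitted copy from constructor parameters only."""
    return type(est)(**est.get_params())


est = TinyEstimator(max_depth=4)
copy = clone_like(est).set_params(cv_folds=3)
print(est.get_params())   # {'max_depth': 4, 'cv_folds': 5}
print(copy.get_params())  # {'max_depth': 4, 'cv_folds': 3}
```

Tools such as GridSearchCV rely on this round trip: they clone the estimator, call set_params with each candidate combination, and fit the copy.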

How the binding is layered

scikit-learn API  ──  python/cvdt/_estimator.py   (validation, tags, classes_, label encoding)
                          │  numpy arrays
native extension  ──  cvdt._cvdt  (src/python.rs, PyO3)   (numpy → flat Columns → tree)
                          │
Rust core         ──  DecisionTree<Task>  (zero-dependency crate)

All sklearn semantics live in Python; the Rust side stays a thin, fast core. The core crate remains zero-dependency — PyO3 and the numpy crate are pulled in only under the python cargo feature. A plain cargo build / cargo test still compiles nothing but std.

Building

Requires a Rust toolchain and maturin.

pip install maturin numpy scikit-learn
# dev build into the current environment:
maturin develop --release --features python
# or build a wheel:
maturin build --release --features python
pip install target/wheels/cvdt-*.whl

Wheels are built with PyO3's abi3-py38 feature, so a single wheel per platform works across CPython 3.8+.

Quickstart

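import numpy as np
from cvdt import CVDTClassifier, CVDTRegressor

# X_train, y_train, X_test, y_test are assumed to be existing NumPy arrays.
clf = CVDTClassifier(criterion="gini", cv_folds=5, max_depth=6)
clf.fit(X_train, y_train)
clf.predict(X_test)          # hard class labels
clf.predict_proba(X_test)    # per-class probabilities
clf.score(X_test, y_test)    # score from ClassifierMixin

# For the regressor, y_train should hold a continuous target.
reg = CVDTRegressor(criterion="mse", mode="fast", n_bins=32)
reg.fit(X_train, y_train)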

Everything composes with sklearn:

from sklearn.model_selection import GridSearchCV
from cvdt import CVDTClassifier

# X, y: feature matrix and labels prepared beforehand.
grid = GridSearchCV(
    CVDTClassifier(),
    {"max_depth": [4, 8], "aggregator": ["mean", "mean_minus_lambda_std"], "mode": ["strict", "fast"]},
    cv=5,
).fit(X, y)

See examples/sklearn_usage.py.

Parameters

Common to both estimators:

Parameter              Default    Meaning
---------------------  ---------  ------------------------------------------------------
max_depth              8          Max split levels; None for unlimited.
min_samples_split      2          Min samples for a node to be splittable.
min_samples_leaf       1          Min samples each child must receive.
min_impurity_decrease  0.0        Min aggregated gain to accept a split.
n_bins                 8          Quantile bins for continuous features.
cv_folds               5          K in the K-fold split evaluation.
cv_seed                42         Fold-shuffle seed (determinism).
cv_shuffle             True       Shuffle before folding.
mode                   "strict"   "strict" (per-fold edges) or "fast" (histogram path).
aggregator             "mean"     One of mean, median, trimmed_mean, signal_to_noise,
                                  mean_minus_lambda_std.
agg_frac               0.1        Trim fraction for trimmed_mean.
agg_eps                1e-12      Stabiliser for signal_to_noise.
agg_lambda             1.0        Std penalty for mean_minus_lambda_std.
parallel               False      Evaluate features in parallel.
n_threads              1          Worker threads when parallel.
parallel_min_samples   512        Only parallelise nodes at least this large.
categorical_features   None       Column indices to treat as categorical.
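
The aggregator options combine the per-fold gains of a candidate split into one score. The sketch below shows one plausible reading of each option. The exact formulas in the Rust core are not stated here, so the definitions (population standard deviation, symmetric trimming, mean / (std + eps)) are assumptions, and the function name aggregate is hypothetical.

```python
import statistics


def aggregate(gains, kind="mean", frac=0.1, eps=1e-12, lam=1.0):
    """Combine per-fold gains into one score (assumed formulas, see lead-in)."""
    ordered = sorted(gains)
    mean = statistics.fmean(ordered)
    std = statistics.pstdev(ordered)  # assumption: population std across folds
    if kind == "mean":
        return mean
    if kind == "median":
        return statistics.median(ordered)
    if kind == "trimmed_mean":
        # Assumption: drop the same number of folds from each end; frac < 0.5.
        k = int(len(ordered) * frac)
        return statistics.fmean(ordered[k:len(ordered) - k])
    if kind == "signal_to_noise":
        # eps keeps the ratio finite when all folds agree exactly.
        return mean / (std + eps)
    if kind == "mean_minus_lambda_std":
        # Penalise splits whose gain varies a lot between folds.
        return mean - lam * std
    raise ValueError(f"unknown aggregator {kind!r}")


# Four folds like the split, one fold strongly dislikes it.
fold_gains = [0.10, 0.12, 0.08, 0.11, -0.30]
for kind in ("mean", "median", "trimmed_mean", "mean_minus_lambda_std"):
    print(kind, aggregate(fold_gains, kind, frac=0.2))
```

In this toy case the median and trimmed mean ignore the single bad fold, while mean_minus_lambda_std pushes the score below the plain mean because the folds disagree.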

Estimator-specific:

  • CVDTClassifier.criterion: "gini" (default) or "entropy".
  • CVDTRegressor.criterion: "mse" (default), "variance", or "mae".

Objective-driven classification

Instead of an impurity proxy, CVDTClassifier can select splits to directly optimise a target metric on the held-out folds:

Parameter   Default    Meaning
----------  ---------  ------------------------------------------------------------------
objective   None       None (use criterion), or precision/recall/f1/fbeta/accuracy.
average     "binary"   binary, micro, macro, or weighted.
pos_label   1          Positive class (index into sorted classes_) when average="binary".
beta        1.0        β for objective="fbeta".

# Optimise recall of the positive class on an imbalanced problem:
clf = CVDTClassifier(objective="recall", average="binary", pos_label=1).fit(X, y)

# Macro-F1 for multiclass:
clf = CVDTClassifier(objective="f1", average="macro").fit(X, y)

When objective is set, criterion is ignored. Splits are accepted only when they improve the metric over making the node a leaf, so objective-mode trees are usually shallower and tuned to the metric; set min_impurity_decrease below 0 to allow non-improving splits.
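
The acceptance rule described above can be sketched for a single fold as follows. This is an illustration only, not the library's implementation: the helper names (recall, majority, split_gain, accept) are hypothetical, a plain threshold stands in for CVDT's membership test, and the real estimator aggregates the gain over all cv_folds folds.

```python
def recall(y_true, y_pred, pos_label=1):
    """Recall of pos_label; defined as 0.0 when there are no positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == pos_label and p == pos_label)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == pos_label and p != pos_label)
    return tp / (tp + fn) if tp + fn else 0.0


def majority(labels):
    """Most frequent label; ties go to the smallest label."""
    return max(sorted(set(labels)), key=labels.count)


def split_gain(train, valid, threshold, metric=recall):
    """Held-out metric of the split minus held-out metric of a single leaf."""
    left = [y for x, y in train if x < threshold]
    right = [y for x, y in train if x >= threshold]
    if not left or not right:
        return 0.0
    leaf_label = majority([y for _, y in train])
    left_label, right_label = majority(left), majority(right)
    y_true = [y for _, y in valid]
    as_leaf = [leaf_label] * len(valid)
    as_split = [left_label if x < threshold else right_label for x, _ in valid]
    return metric(y_true, as_split) - metric(y_true, as_leaf)


def accept(gain, min_impurity_decrease=0.0):
    # A negative min_impurity_decrease lets non-improving splits through.
    return gain > min_impurity_decrease


train = [(0.1, 0), (0.2, 0), (0.3, 0), (0.4, 0), (0.8, 1), (0.9, 1)]
valid = [(0.15, 0), (0.85, 1), (0.95, 1)]

good = split_gain(train, valid, threshold=0.5)
flat = split_gain(train, valid, threshold=0.25)
print(good, accept(good))  # 1.0 True
print(flat, accept(flat))  # 0.0 False
```

The second split leaves held-out recall unchanged, so it is rejected under the default setting and accepted only if min_impurity_decrease is negative.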

Notes and caveats

  • Interpretability: a fitted CVDTClassifier/CVDTRegressor is a single tree, so the sklearn-style accessors est.get_depth() and est.get_n_leaves() report its size. Because the tree is binary, its total node count is 2 * get_n_leaves() - 1. The tree can be inspected and visualised:

    • est.export_text(feature_names=..., class_names=...) — readable rule dump.
    • est.export_graphviz(...) — Graphviz DOT string (render with graphviz, pydot, or dtreeviz).
    • est.get_tree() — the tree as a dict of parallel arrays for custom plots.

    Note that CVDT splits are membership tests, so continuous branches read as intervals (lo <= x[f] < hi) rather than single thresholds. The "true" branch is the one whose test matches; missing values route to "false".

  • categorical_features columns are read as integer category ids. Feed already-integer-encoded values (e.g. from OrdinalEncoder); values are rounded to the nearest non-negative integer, and negative or non-finite values are treated as an unknown category (routed to the "not-in-state" child).

  • Missing values: NaN / non-finite continuous values are accepted and routed to the right ("false") child, so allow_nan is set in the estimator tags.

  • mode="fast" + criterion="mae" is unsupported (the median has no additive sufficient statistic for the histogram path) and raises on fit. Use mode="strict" for MAE.

  • mode="fast" relaxes the leakage guard slightly: bins are fit once per node instead of per training fold, so bin edges see validation feature values — but impurity is still measured on held-out labels. It is much faster; use "strict" when the stronger guard matters.
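
The routing rules in the notes above (interval membership, unknown categories, missing values) can be summarised in a short sketch. The function names are hypothetical and the tie-breaking of the rounding step is an assumption; the actual routing lives in the Rust core.

```python
import math


def in_interval(value, lo, hi):
    """Continuous split: "true" branch when lo <= value < hi.

    NaN and other non-finite values never match, so they go to "false".
    """
    if not math.isfinite(value):
        return False
    return lo <= value < hi


def category_id(value):
    """Integer category id, or None for an unknown category.

    Assumption: negative inputs count as unknown before rounding, and
    Python's round() decides ties.
    """
    if not math.isfinite(value) or value < 0:
        return None
    return int(round(value))


def in_state(value, states):
    """Categorical split: "true" branch when the id belongs to the state set."""
    cid = category_id(value)
    return cid is not None and cid in states


def total_nodes(n_leaves):
    """Node count of a binary tree in which every split has two children."""
    return 2 * n_leaves - 1


print(in_interval(0.5, 0.0, 1.0))           # True
print(in_interval(1.0, 0.0, 1.0))           # False (upper bound is exclusive)
print(in_interval(float("nan"), 0.0, 1.0))  # False
print(in_state(1.9, {0, 2}))                # True (1.9 rounds to id 2)
print(in_state(-1.0, {0, 2}))               # False (unknown category)
print(total_nodes(5))                       # 9
```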

Tests

tests/test_sklearn.py covers learning, predict_proba, clone / get_params, fast vs. strict, NaN and categorical handling, pipelines, and a full sklearn.utils.estimator_checks.check_estimator sweep. Build the extension first, then run pytest.
