Cross-Validated Decision Tree — scikit-learn-compatible estimators backed by a Rust core

CVDT for Python (scikit-learn compatible)

cvdt ships two estimators that follow the scikit-learn API and are backed by the zero-dependency Rust core:

  • cvdt.CVDTClassifier: ClassifierMixin; predict / predict_proba / predict_log_proba
  • cvdt.CVDTRegressor: RegressorMixin; predict

Both are full BaseEstimators: they work with Pipeline, cross_val_score, GridSearchCV, clone, and get_params / set_params.

How the binding is layered

scikit-learn API  ──  python/cvdt/_estimator.py   (validation, tags, classes_, label encoding)
                          │  numpy arrays
native extension  ──  cvdt._cvdt  (src/python.rs, PyO3)   (numpy → flat Columns → tree)
                          │
Rust core         ──  DecisionTree<Task>  (zero-dependency crate)

All sklearn semantics live in Python; the Rust side stays a thin, fast core. The core crate remains zero-dependency: PyO3 and the numpy crate are pulled in only under the python cargo feature, so a plain cargo build / cargo test still depends on nothing but std.
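
To make the layering concrete, here is a minimal sketch of the kind of preparation a Python layer like this typically does before calling into a native extension: encoding labels against a sorted classes_ list and turning row-major samples into one flat column per feature. The helper names encode_labels and to_flat_columns are hypothetical and are not part of cvdt; the actual code in _estimator.py and src/python.rs may be organised quite differently.

```python
from typing import Hashable, List, Sequence, Tuple


def encode_labels(y: Sequence[Hashable]) -> Tuple[List[Hashable], List[int]]:
    """Return (classes, codes): sorted unique labels and each label's index."""
    classes = sorted(set(y))
    index = {label: i for i, label in enumerate(classes)}
    return classes, [index[label] for label in y]


def to_flat_columns(X: Sequence[Sequence[float]]) -> List[List[float]]:
    """Transpose row-major samples into one list of floats per feature."""
    if not X:
        return []
    n_features = len(X[0])
    if any(len(row) != n_features for row in X):
        raise ValueError("all rows must have the same number of features")
    return [[float(row[j]) for row in X] for j in range(n_features)]


# Hypothetical inputs, only to show the shapes involved.
classes, codes = encode_labels(["spam", "ham", "spam", "eggs"])
columns = to_flat_columns([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
print(classes, codes)
print(columns)
```

Keeping this bookkeeping in Python means the native side only ever sees integer class codes and numeric columns, which is consistent with the "thin core" split described above.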

Building

Requires a Rust toolchain and maturin.

pip install maturin numpy scikit-learn
# dev build into the current environment:
maturin develop --release --features python
# or build a wheel:
maturin build --release --features python
pip install target/wheels/cvdt-*.whl

Wheels are built with PyO3's abi3-py38 feature, so a single wheel per platform works across CPython 3.8 and later.

Quickstart

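import numpy as np
from sklearn.model_selection import train_test_split
from cvdt import CVDTClassifier, CVDTRegressor

# Toy data so the snippet runs end to end; substitute your own arrays.
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # binary class labels
y_reg = X[:, 0] - 2.0 * X[:, 2]          # continuous regression target
X_train, X_test, y_train, y_test, yr_train, yr_test = train_test_split(
    X, y, y_reg, random_state=0
)

clf = CVDTClassifier(criterion="gini", cv_folds=5, max_depth=6)
clf.fit(X_train, y_train)
clf.predict(X_test)          # class labels
clf.predict_proba(X_test)    # per-class probabilities
clf.score(X_test, y_test)    # mean accuracy

reg = CVDTRegressor(criterion="mse", mode="fast", n_bins=32)
reg.fit(X_train, yr_train)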

Everything composes with sklearn:

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(
    CVDTClassifier(),
    {"max_depth": [4, 8], "aggregator": ["mean", "mean_minus_lambda_std"], "mode": ["strict", "fast"]},
    cv=5,
).fit(X, y)

See examples/sklearn_usage.py.

Parameters

Common to both estimators:

Parameter               Default    Meaning
max_depth               8          Max split levels; None for unlimited.
min_samples_split       2          Min samples for a node to be splittable.
min_samples_leaf        1          Min samples each child must receive.
min_impurity_decrease   0.0        Min aggregated gain to accept a split.
n_bins                  8          Quantile bins for continuous features.
cv_folds                5          K in the K-fold split evaluation.
cv_seed                 42         Fold-shuffle seed (determinism).
cv_shuffle              True       Shuffle before folding.
mode                    "strict"   "strict" (per-fold edges) or "fast" (histogram path).
aggregator              "mean"     mean, median, trimmed_mean, signal_to_noise, mean_minus_lambda_std.
agg_frac                0.1        Trim fraction for trimmed_mean.
agg_eps                 1e-12      Stabiliser for signal_to_noise.
agg_lambda              1.0        Std penalty for mean_minus_lambda_std.
parallel                False      Evaluate features in parallel.
n_threads               1          Worker threads when parallel.
parallel_min_samples    512        Only parallelise nodes at least this large.
categorical_features    None       Column indices to treat as categorical.
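
The aggregator options in the table collapse the K per-fold gains of a candidate split into a single score. The sketch below shows one plausible reading of each option. The exact formulas are assumptions, not taken from the cvdt source: trimmed_mean is assumed to drop agg_frac of the values from each end, signal_to_noise is assumed to be mean / (std + agg_eps), and the population standard deviation is used throughout.

```python
import statistics
from typing import List


def aggregate(gains: List[float], aggregator: str = "mean", agg_frac: float = 0.1,
              agg_eps: float = 1e-12, agg_lambda: float = 1.0) -> float:
    """Collapse per-fold gains into one score (formulas are assumptions)."""
    mean = statistics.fmean(gains)
    std = statistics.pstdev(gains)  # population std across folds (assumed)
    if aggregator == "mean":
        return mean
    if aggregator == "median":
        return statistics.median(gains)
    if aggregator == "trimmed_mean":
        k = int(len(gains) * agg_frac)  # values dropped from each end (assumed)
        kept = sorted(gains)[k:len(gains) - k] or gains
        return statistics.fmean(kept)
    if aggregator == "signal_to_noise":
        return mean / (std + agg_eps)
    if aggregator == "mean_minus_lambda_std":
        return mean - agg_lambda * std
    raise ValueError(f"unknown aggregator: {aggregator!r}")


# Hypothetical per-fold gains with one optimistic fold.
fold_gains = [0.10, 0.12, 0.11, 0.09, 0.40]
for name in ("mean", "median", "trimmed_mean", "mean_minus_lambda_std"):
    print(name, round(aggregate(fold_gains, name, agg_frac=0.2), 4))
```

Under these assumptions, the robust aggregators (median, trimmed_mean) and the penalised one (mean_minus_lambda_std) all score this split lower than the plain mean, because a single outlying fold contributes most of the average gain.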

Estimator-specific:

  • CVDTClassifier.criterion: "gini" (default) or "entropy".
  • CVDTRegressor.criterion: "mse" (default), "variance", or "mae".
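
The n_bins and mode rows of the common-parameter table interact: "strict" fits bin edges per training fold, while "fast" fits them once per node. The sketch below illustrates only that difference, using the standard library's quantile cut points on a hypothetical feature column; the fold indices are made up, and cvdt's actual binning routine may compute its edges differently.

```python
import statistics
from typing import List


def quantile_edges(values: List[float], n_bins: int) -> List[float]:
    """Interior cut points splitting `values` into n_bins quantile bins."""
    return statistics.quantiles(values, n=n_bins, method="inclusive")


# Hypothetical feature values at one node; the last one is an outlier.
node_values = [0.1, 0.4, 0.5, 0.9, 1.3, 1.8, 2.2, 2.9, 3.5, 50.0]
train_idx = list(range(0, 8))  # hypothetical training fold
valid_idx = [8, 9]             # hypothetical held-out fold

# "strict"-style: edges see only the training fold's feature values.
strict_edges = quantile_edges([node_values[i] for i in train_idx], n_bins=4)
# "fast"-style: edges are fit once on every sample at the node,
# so the held-out feature values influence where the cuts fall.
fast_edges = quantile_edges(node_values, n_bins=4)

print("strict:", strict_edges)
print("fast:  ", fast_edges)
```

In both cases the labels of the held-out fold are untouched; only the feature values leak into the edge positions in the fast-style variant.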

Objective-driven classification

Instead of an impurity proxy, CVDTClassifier can select splits to directly optimise a target metric on the held-out folds:

Parameter   Default    Meaning
objective   None       None (use criterion), or precision/recall/f1/fbeta/accuracy.
average     "binary"   binary, micro, macro, or weighted.
pos_label   1          Positive class (index into sorted classes_) when average="binary".
beta        1.0        β for objective="fbeta".

# Optimise recall of the positive class on an imbalanced problem:
clf = CVDTClassifier(objective="recall", average="binary", pos_label=1).fit(X, y)

# Macro-F1 for multiclass:
clf = CVDTClassifier(objective="f1", average="macro").fit(X, y)

When objective is set, criterion is ignored. Splits are accepted only when they improve the metric over making the node a leaf, so objective-mode trees are usually shallower and tuned to the metric; set min_impurity_decrease below 0 to allow non-improving splits.
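
The acceptance rule described above can be sketched for a single held-out fold as follows. This is a simplified illustration, not cvdt's implementation: the function names are hypothetical, only the binary F-beta metric is shown, the strict ">" comparison is an assumption, and the real estimator aggregates over several folds.

```python
from typing import Sequence


def fbeta(y_true: Sequence[int], y_pred: Sequence[int], pos_label: int = 1,
          beta: float = 1.0) -> float:
    """Binary F-beta; beta=1 gives F1. Returns 0.0 when undefined."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == pos_label and p == pos_label for t, p in pairs)
    fp = sum(t != pos_label and p == pos_label for t, p in pairs)
    fn = sum(t == pos_label and p != pos_label for t, p in pairs)
    denom = (1 + beta**2) * tp + beta**2 * fn + fp
    return (1 + beta**2) * tp / denom if denom else 0.0


def accept_split(y_valid: Sequence[int], leaf_pred: Sequence[int],
                 split_pred: Sequence[int], min_gain: float = 0.0) -> bool:
    """Accept only if the split beats the leaf on held-out labels (assumed rule)."""
    gain = fbeta(y_valid, split_pred) - fbeta(y_valid, leaf_pred)
    return gain > min_gain


# Hypothetical held-out labels and predictions.
y_valid = [1, 0, 0, 1, 0, 0]
leaf_pred = [0, 0, 0, 0, 0, 0]   # leaf predicts the majority class
split_pred = [1, 0, 0, 1, 1, 0]  # children recover both positives

print(accept_split(y_valid, leaf_pred, split_pred))
print(accept_split(y_valid, leaf_pred, leaf_pred))
print(accept_split(y_valid, leaf_pred, leaf_pred, min_gain=-0.1))
```

The last call mirrors the effect of a negative min_impurity_decrease: a split that does not improve the metric is still allowed through.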

Notes and caveats

  • Interpretability: a fitted CVDTClassifier/CVDTRegressor is a single tree, so est.get_depth() and est.get_n_leaves() (sklearn-style accessors) report its size; the total node count of the binary tree is 2 * get_n_leaves() - 1. The tree can be inspected and visualised:

    • est.export_text(feature_names=..., class_names=...) — readable rule dump.
    • est.export_graphviz(...) — Graphviz DOT string (render with graphviz, pydot, or dtreeviz).
    • est.get_tree() — the tree as a dict of parallel arrays for custom plots.

    Note that CVDT splits are membership tests, so continuous branches read as intervals (lo <= x[f] < hi) rather than single thresholds. The "true" branch is the one that matches; missing values route to "false".

  • categorical_features columns are read as integer category ids. Feed already-integer-encoded values (e.g. from OrdinalEncoder); values are rounded to the nearest non-negative integer, and negatives / non-finite are treated as an unknown category (routed to the "not-in-state" child).

  • Missing values: NaN / non-finite continuous values are accepted and routed to the non-matching ("false") branch, so allow_nan is set in the estimator tags.

  • mode="fast" + criterion="mae" is unsupported (the median has no additive sufficient statistic for the histogram path) and raises on fit. Use mode="strict" for MAE.

  • mode="fast" relaxes the leakage guard slightly: bins are fit once per node instead of per training fold, so bin edges see validation feature values — but impurity is still measured on held-out labels. It is much faster; use "strict" when the stronger guard matters.
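
The routing rules from the notes above (interval membership for continuous features, integer ids for categorical ones, and missing or unknown values falling through to the non-matching child) can be summarised in a small sketch. The function names are hypothetical, and the tie-breaking of the rounding step and the handling of small negative values are assumptions; cvdt's core may differ in those details.

```python
import math
from typing import Optional, Set


def continuous_matches(x: float, lo: float, hi: float) -> bool:
    """True branch iff lo <= x < hi; NaN and +/-inf take the false branch."""
    if not math.isfinite(x):
        return False
    return lo <= x < hi


def category_id(x: float) -> Optional[int]:
    """Round to an integer id; None marks an unknown category."""
    if not math.isfinite(x) or x < 0:
        return None  # negatives / non-finite are treated as unknown
    return int(round(x))  # tie-breaking here is Python's, an assumption


def categorical_matches(x: float, state: Set[int]) -> bool:
    """True branch iff the value's category id is in the split's state set."""
    cid = category_id(x)
    return cid is not None and cid in state


print(continuous_matches(1.5, 0.0, 2.0))          # inside the interval
print(continuous_matches(2.0, 0.0, 2.0))          # upper bound is exclusive
print(continuous_matches(float("nan"), 0.0, 2.0))  # missing value
print(categorical_matches(3.2, {1, 3}))           # rounds to id 3
print(categorical_matches(-1.0, {1, 3}))          # unknown category
```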

Tests

tests/test_sklearn.py covers learning, predict_proba, clone / get_params, fast vs. strict, NaN and categorical handling, pipelines, and a full sklearn.utils.estimator_checks.check_estimator sweep. Build the extension first, then run pytest.
