
A package for building decision trees and random forests with custom splitting criteria.

This project has been archived.

The maintainers of this project have marked this project as archived. No new releases are expected.



Custom Tree Classifier


Custom Tree Classifier is a Python package that allows building decision trees and random forests with custom splitting criteria, enabling optimization for specific problems. Users can define metrics like Gini, economic profit, or any custom cost function.

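This flexibility is particularly useful in "cost-dependent" scenarios, where the quality of a split is better measured by a business cost or gain than by a generic impurity measure.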


Examples of use

Here are some examples of how custom splitting criteria can be beneficial:

  • Trading Movements Classification: When the goal is to maximize economic profit, the metric can be set to economic profit, optimizing tree splitting accordingly.
  • Churn Prediction: To minimize false negatives, metrics like F1 score or recall can guide the splitting process.
  • Fraud Detection: Splitting can be optimized based on the proportion of fraudulent transactions identified relative to the total, rather than overall classification accuracy.
  • Marketing Campaigns: The splitting can focus on maximizing expected revenue from customer segments identified by the tree.

Usage

See ./notebooks/Example.ipynb for a complete example.

Installation

pip install custom_tree_classifier

Define your metric

To use a custom metric, define a class that implements the compute_metric and compute_delta methods, then pass this class to the classifier through its metric argument.

Example of a metric class implementing the Gini index:

import numpy as np

from custom_tree_classifier.metrics import MetricBase

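class Gini(MetricBase):

    @staticmethod
    def compute_metric(metric_data: np.ndarray) -> np.float64:
        # The first column of metric_data holds the binary labels (0 or 1).
        y = metric_data[:, 0]

        # Proportion of observations in each class.
        prop0 = np.sum(y == 0) / len(y)
        prop1 = np.sum(y == 1) / len(y)

        # Gini impurity of the group of observations.
        metric = 1 - (prop0**2 + prop1**2)

        return metric

    @staticmethod
    def compute_delta(
            split: np.ndarray,
            metric_data: np.ndarray
        ) -> np.float64:
        # split is a boolean mask over the rows of metric_data, so
        # np.mean(split) is the fraction of rows where the mask is True.
        # delta is the parent impurity minus the size-weighted child impurities.
        delta = (
            Gini.compute_metric(metric_data) -
            Gini.compute_metric(metric_data[split]) * np.mean(split) -
            Gini.compute_metric(metric_data[np.invert(split)]) * (1 - np.mean(split))
        )

        return delta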
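
Other impurity measures can follow the same two-method pattern. The sketch below shows binary entropy as one possibility. It is an illustration, not part of the package: the class name Entropy is hypothetical, and it is written as a plain class so that it runs with NumPy alone. In real use it would presumably subclass MetricBase like the Gini example.

```python
import numpy as np


class Entropy:
    """Binary entropy with the same two static methods as the Gini example.

    Assumption: with the package installed, this class would subclass
    MetricBase from custom_tree_classifier.metrics.
    """

    @staticmethod
    def compute_metric(metric_data: np.ndarray) -> np.float64:
        # As in the Gini example, column 0 is assumed to hold 0/1 labels.
        y = metric_data[:, 0]
        p1 = np.mean(y == 1)
        probs = np.array([1.0 - p1, p1])
        probs = probs[probs > 0]  # drop empty classes: 0 * log2(0) is taken as 0
        return np.float64(-np.sum(probs * np.log2(probs)))

    @staticmethod
    def compute_delta(split: np.ndarray, metric_data: np.ndarray) -> np.float64:
        # Information gain: parent entropy minus size-weighted child entropies.
        weight = np.mean(split)
        return np.float64(
            Entropy.compute_metric(metric_data)
            - Entropy.compute_metric(metric_data[split]) * weight
            - Entropy.compute_metric(metric_data[~split]) * (1 - weight)
        )


# Toy data: four observations with the labels in column 0.
toy_data = np.array([[0.0], [0.0], [1.0], [1.0]])
perfect_split = np.array([True, True, False, False])

parent_entropy = Entropy.compute_metric(toy_data)
gain = Entropy.compute_delta(perfect_split, toy_data)
print(parent_entropy, gain)
```

On this toy data the parent node is perfectly balanced (entropy of 1 bit) and the split separates the classes completely, so the gain equals the parent entropy.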

Train and predict

Once you have instantiated the model with your custom metric, all you have to do is use the .fit and .predict_proba methods:

from custom_tree_classifier import CustomDecisionTreeClassifier

model = CustomDecisionTreeClassifier(
    max_depth=3,
    metric=Gini
)

model.fit(
    X=X_train, 
    y=y_train, 
    metric_data=metric_data
)

probas = model.predict_proba(
    X=X_test
)

probas[:5]
>>> array([[0.75308642, 0.24691358],
           [0.36206897, 0.63793103],
           [0.75308642, 0.24691358],
           [0.36206897, 0.63793103],
           [0.90243902, 0.09756098]])
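
The snippet above uses X_train, y_train and metric_data without defining them. The sketch below shows one way such inputs might be built. It assumes that metric_data is a 2-D array with one row per training observation and the labels in column 0, which is how the Gini example reads it. The synthetic data and the profit column are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical synthetic data: 200 observations, 2 numeric features.
n_samples = 200
X_train = rng.normal(size=(n_samples, 2))
y_train = (X_train[:, 0] + rng.normal(scale=0.5, size=n_samples) > 0).astype(int)

# Assumption: one row of metric_data per training row, labels in column 0,
# because the Gini example reads them with metric_data[:, 0].
metric_data = y_train.reshape(-1, 1).astype(float)

# A cost-aware metric could carry extra columns, for example a hypothetical
# per-observation profit stacked next to the labels.
profit = rng.uniform(low=-1.0, high=1.0, size=n_samples)
metric_data_with_profit = np.column_stack([y_train, profit])

print(metric_data.shape, metric_data_with_profit.shape)
# prints: (200, 1) (200, 2)
```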

Print the tree

You can also display the decision tree, with the values of your metrics, using the print_tree method:

features_names = {
    0: "Pclass", 
    1: "Age"
}

model.print_tree(
    features_names=features_names,
    digits=2,
    metric_name="MyMetric"
)
>>> [1] -> MyMetric = 0.48 | repartition = [424, 290]
    |    Δ MyMetric = +0.05
    |   [2] Pclass <= 2.0 -> MyMetric = 0.49 | repartition = [154, 205]
    |   |    Δ MyMetric = +0.03
    |   |   |    Δ MyMetric = +0.01
    |   |   |    Δ MyMetric = +0.03
    |   [3] Pclass > 2.0 -> MyMetric = 0.36 | repartition = [270, 85]
    |   |    Δ MyMetric = +0.02
    |   |   |    Δ MyMetric = +0.04
    |   |   |    Δ MyMetric = +0.01
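
The first values of this output can be checked by hand. The sketch below assumes that repartition lists the observation counts for class 0 and class 1, and that the metric is the Gini class defined earlier. The helper names are hypothetical. It recomputes the metric of nodes [1], [2] and [3] and the first Δ as the parent impurity minus the size-weighted child impurities, as in compute_delta.

```python
import numpy as np


def gini_from_counts(counts) -> float:
    """Gini impurity 1 - sum(p_k^2) computed from per-class counts."""
    counts = np.asarray(counts, dtype=float)
    proportions = counts / counts.sum()
    return float(1.0 - np.sum(proportions**2))


def delta_from_counts(parent, left, right) -> float:
    """Parent impurity minus the size-weighted impurities of the two children."""
    n_parent = float(np.sum(parent))
    weight_left = float(np.sum(left)) / n_parent
    weight_right = float(np.sum(right)) / n_parent
    return (
        gini_from_counts(parent)
        - weight_left * gini_from_counts(left)
        - weight_right * gini_from_counts(right)
    )


# Counts [class 0, class 1] copied from nodes [1], [2] and [3] above.
root, left, right = [424, 290], [154, 205], [270, 85]

root_gini = gini_from_counts(root)
left_gini = gini_from_counts(left)
right_gini = gini_from_counts(right)
delta = delta_from_counts(root, left, right)

print(f"{root_gini:.2f} {left_gini:.2f} {right_gini:.2f} {delta:+.2f}")
# prints: 0.48 0.49 0.36 +0.05
```

The four numbers agree with the printed tree to two decimals.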

Random Forest Classifier

The random forest classifier is used in the same way:

from custom_tree_classifier import CustomRandomForestClassifier

random_forest = CustomRandomForestClassifier(
    n_estimators=100,
    max_depth=5,
    metric=Gini
)

random_forest.fit(
    X=X_train, 
    y=y_train, 
    metric_data=metric_data
)

probas = random_forest.predict_proba(
    X=X_test
)

Reminder on splitting criteria

Typically, classification trees are constructed using a splitting criterion that is based on a measure of impurity or information gain.

Consider a two-class classification problem with the Gini index as the metric. The Gini index measures the impurity of a group of observations from the proportions $p_0$ and $p_1$ of observations in classes 0 and 1:

$$ I_{G} = 1 - p_0^2 - p_1^2 $$

Since the Gini index is an indicator of impurity, partitioning is done by minimising the weighted average of the index in the child nodes $L$ and $R$. This is equivalent to maximising the impurity decrease $\Delta$:

$$ \Delta = \frac{N_t}{N} \times \left( I_G - \frac{N_{t_L}}{N_t} I_{G_L} - \frac{N_{t_R}}{N_t} I_{G_R} \right) $$

where $N$ is the total number of observations, $N_t$ the number of observations in the current node, and $N_{t_L}$ and $N_{t_R}$ the numbers of observations in the left and right child nodes.

At each node, the tree algorithm finds the split that maximises $\Delta$ over all possible splits and over all features. Once the optimal split is selected, the tree is grown by recursively applying this splitting process to the resulting child nodes.
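
The search at a single node can be sketched as an exhaustive loop over features and candidate thresholds that keeps the split with the lowest size-weighted Gini index in the child nodes. This is a simplified illustration with hypothetical function names, not the package's implementation. It assumes numeric features and splits of the form x <= threshold.

```python
import numpy as np


def gini(y: np.ndarray) -> float:
    """Gini impurity of a vector of 0/1 labels."""
    p1 = float(np.mean(y == 1))
    return 1.0 - (1.0 - p1) ** 2 - p1**2


def best_split(X: np.ndarray, y: np.ndarray):
    """Exhaustive search over features and thresholds of the form x <= t.

    Returns (feature index, threshold, weighted child impurity) for the
    split with the lowest size-weighted Gini impurity in the children.
    """
    best = (None, None, np.inf)
    for feature in range(X.shape[1]):
        # Every observed value except the largest is a candidate threshold,
        # so both children are always non-empty.
        for threshold in np.unique(X[:, feature])[:-1]:
            left = X[:, feature] <= threshold
            weight = float(left.mean())
            impurity = weight * gini(y[left]) + (1 - weight) * gini(y[~left])
            if impurity < best[2]:
                best = (feature, float(threshold), float(impurity))
    return best


# Toy data: feature 0 separates the classes perfectly, feature 1 does not.
X = np.array([[1.0, 5.0], [2.0, 3.0], [3.0, 4.0], [4.0, 2.0]])
y = np.array([0, 0, 1, 1])

feature, threshold, impurity = best_split(X, y)
print(feature, threshold, impurity)
# prints: 0 2.0 0.0
```

A full tree would apply best_split again to the rows on each side of the chosen threshold until a stopping rule such as max_depth is reached.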
