Skip to main content

creating and calculating groves of tree models

Project description

xGrove Python Package

Das xgrove-Paket bietet eine Klasse zur Berechnung von "Surrogate Groves", um Entscheidungsbäume zu interpretieren. Es ist inspiriert von Methoden aus dem Bereich der Interpretable Machine Learning (IML) und bietet eine Reihe von Funktionen zur Analyse und Visualisierung von Entscheidungsbaumstrukturen.

Installation

Stelle sicher, dass die erforderlichen Abhängigkeiten installiert sind:

pip install -r requirements.txt

Klassen und Methoden

Klasse: xgrove

Die Hauptklasse xgrove wird verwendet, um "Surrogate Groves" zu erstellen und statistische Analysen durchzuführen.

Konstruktor

xgrove(
    model, 
    data: pd.DataFrame, 
    ntrees: np.array = np.array([4, 8, 16, 32, 64, 128]), 
    pfun = None, 
    shrink: int = 1, 
    b_frac: int = 1, 
    seed: int = 42, 
    grove_rate: float = 1
)
Parameter:
  • model: Das zu analysierende Modell, typischerweise ein beliebiges ML-Modell.
  • data: Ein pandas.DataFrame, das die Eingabedaten enthält.
  • ntrees: Ein np.array, das die Anzahl der Bäume im Grove angibt.
  • pfun: Eine Funktion zur Erstellung des Surrogate-Ziels. Falls None, wird das Modell zur Vorhersage genutzt.
  • shrink: Der Shrinkage-Faktor für das Gradient Boosting.
  • b_frac: Die Fraktion der Stichprobe, die verwendet wird.
  • seed: Der Seed für die Reproduzierbarkeit.
  • grove_rate: Die Lernrate für das Grove.

Methode: getSurrogateTarget()

Erzeugt das Surrogate-Ziel basierend auf den Eingabedaten und dem Modell oder der benutzerdefinierten pfun.

def getSurrogateTarget(self, pfun):
    if self.pfun is None:
        target = self.model.predict(self.data)
    else:
        target = pfun(model=self.model, data=self.data)
    return target

Methode: getGBM()

Erzeugt ein Gradient Boosting Modell (GBM) mit den angegebenen Parametern.

def getGBM(self):
    grove = GradientBoostingRegressor(
        n_estimators=self.ntrees,
        learning_rate=self.shrink,
        subsample=self.b_frac
    )
    return grove

Methode: encodeCategorical()

Codiert kategoriale Variablen mithilfe von One-Hot-Encoding (OHE).

def encodeCategorical(self):
    categorical_columns = self.data.select_dtypes(include=['object', 'category']).columns
    data_encoded = pd.get_dummies(data, columns=categorical_columns)
    return data_encoded

Methode: upsilon()

Berechnet die Upsilon-Statistik, die das Verhältnis zwischen erklärtem und unerklärtem Fehler angibt, sowie die Korrelation zwischen den Vorhersagen des Modells und den echten Werten.

def upsilon(self, pexp):
    ASE = statistics.mean((self.surrTar - pexp) ** 2)
    ASE0 = statistics.mean((self.surrTar - statistics.mean(self.surrTar)) ** 2)
    ups = 1 - ASE / ASE0
    rho = statistics.correlation(self.surrTar, pexp)
    return ups, rho

Methode: get_result()

Gibt eine Liste der zentralen Ergebnisse zurück: Erklärung, Regeln, Groves und Modell.

def get_result(self):
    res = [self.explanation, self.rules, self.groves, self.model]
    return res

Methode: plot()

Eine Methode zur Erstellung eines Upsilon-Rules-Plots für den Surrogate Grove. Diese Methode funktioniert ähnlich wie die Plotfunktion in R.

def plot(self, abs="rules", ord="upsilon"):
    i = self.explanation.columns.get_loc(abs)
    j = self.explanation.columns.get_loc(ord)
    plt.plot(self.explanation.iloc[:, i], self.explanation.iloc[:, j], label=f"{abs} vs {ord}", marker="o")
    plt.xlabel(abs)
    plt.ylabel(ord)
    plt.title("Upsilon-Rules Curve")
    plt.show()

Methode: calculateGrove()

Berechnet die Performance des Modells und extrahiert Groves sowie die dazugehörigen Regeln. Diese Methode füllt die Erklärungs- und Interpretationsdaten und ruft am Ende die upsilon-Methode auf, um den Upsilon-Wert zu berechnen.

def calculateGrove(self):
    explanation = []
    groves = []
    interpretation = []

    # Für jede Anzahl an Bäumen
    for nt in self.ntrees:
        predictions = self.surrGrove.staged_predict(self.data)
        predictions = [next(predictions) for _ in range(nt)][-1]
        rules = []
        
        # Extrahiere Regeln aus den Entscheidungsbäumen
        for tid in range(nt):
            tree = self.surrGrove.estimators_[tid, 0].tree_
            for node_id in range(tree.node_count):
                if tree.children_left[node_id] != tree.children_right[node_id]:  # Splitsnode
                    rule = {
                        'feature': tree.feature[node_id],
                        'threshold': tree.threshold[node_id],
                        'pleft': tree.value[tree.children_left[node_id]][0][0],
                        'pright': tree.value[tree.children_right[node_id]][0][0]
                    }
                    rules.append(rule)
            rules_df = pd.DataFrame(rules)
            groves.append(rules_df)

        # Berechne Upsilon und Korrelation
        upsilon, rho = self.upsilon(predictions)

        # Ergebnisse speichern
        explanation.append([nt, len(rules_df), upsilon, rho])

    # Ergebnisdaten aufbereiten
    groves = pd.DataFrame(groves)
    explanation = pd.DataFrame(explanation, columns=["trees", "rules", "upsilon", "cor"])
    
    self.explanation = explanation
    self.rules = groves
    self.model = self.surrGrove

    self.result = self.get_result()
    return self.result

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

xgrove-0.3.27.tar.gz (14.5 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

xgrove-0.3.27-py3-none-any.whl (16.9 kB view details)

Uploaded Python 3

File details

Details for the file xgrove-0.3.27.tar.gz.

File metadata

  • Download URL: xgrove-0.3.27.tar.gz
  • Upload date:
  • Size: 14.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.13.0

File hashes

Hashes for xgrove-0.3.27.tar.gz
Algorithm Hash digest
SHA256 b21471ce75b39a45ee64085a499b7927bd85fa1fea654be567511942382ec11f
MD5 19bdb98a9b6818c676f82bb02cb70ba2
BLAKE2b-256 ad627c7cd7d0adca68abe281e0ab2099afc22916815d01e5939a23743205c32e

See more details on using hashes here.

File details

Details for the file xgrove-0.3.27-py3-none-any.whl.

File metadata

  • Download URL: xgrove-0.3.27-py3-none-any.whl
  • Upload date:
  • Size: 16.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.13.0

File hashes

Hashes for xgrove-0.3.27-py3-none-any.whl
Algorithm Hash digest
SHA256 2f8faa8c6f3175d9be3241a1975edcabd2ef201bf4f38f988da9b42f255a2b26
MD5 818b34609412d24876f316a7cc1211b2
BLAKE2b-256 3a42fbea44a9412371398f6489da25654a877ca55f1aba19bdc212bad7dfbd0b

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page