Skip to main content

creating and calculating groves of tree models

Project description

xGrove Python Package

Das xgrove-Paket bietet eine Klasse zur Berechnung von "Surrogate Groves", um Entscheidungsbäume zu interpretieren. Es ist inspiriert von Methoden aus dem Bereich der Interpretable Machine Learning (IML) und bietet eine Reihe von Funktionen zur Analyse und Visualisierung von Entscheidungsbaumstrukturen.

Installation

Stelle sicher, dass die erforderlichen Abhängigkeiten installiert sind:

pip install -r requirements.txt

Klassen und Methoden

Klasse: xgrove

Die Hauptklasse xgrove wird verwendet, um "Surrogate Groves" zu erstellen und statistische Analysen durchzuführen.

Konstruktor

xgrove(
    model, 
    data: pd.DataFrame, 
    ntrees: np.array = np.array([4, 8, 16, 32, 64, 128]), 
    pfun = None, 
    shrink: int = 1, 
    b_frac: int = 1, 
    seed: int = 42, 
    grove_rate: float = 1
)
Parameter:
  • model: Das zu analysierende Modell, typischerweise ein beliebiges ML-Modell.
  • data: Ein pandas.DataFrame, das die Eingabedaten enthält.
  • ntrees: Ein np.array, das die Anzahl der Bäume im Grove angibt.
  • pfun: Eine Funktion zur Erstellung des Surrogate-Ziels. Falls None, wird das Modell zur Vorhersage genutzt.
  • shrink: Der Shrinkage-Faktor für das Gradient Boosting.
  • b_frac: Die Fraktion der Stichprobe, die verwendet wird.
  • seed: Der Seed für die Reproduzierbarkeit.
  • grove_rate: Die Lernrate für das Grove.

Methode: getSurrogateTarget()

Erzeugt das Surrogate-Ziel basierend auf den Eingabedaten und dem Modell oder der benutzerdefinierten pfun.

def getSurrogateTarget(self, pfun):
    if self.pfun is None:
        target = self.model.predict(self.data)
    else:
        target = pfun(model=self.model, data=self.data)
    return target

Methode: getGBM()

Erzeugt ein Gradient Boosting Modell (GBM) mit den angegebenen Parametern.

def getGBM(self):
    grove = GradientBoostingRegressor(
        n_estimators=self.ntrees,
        learning_rate=self.shrink,
        subsample=self.b_frac
    )
    return grove

Methode: encodeCategorical()

Codiert kategoriale Variablen mithilfe von One-Hot-Encoding (OHE).

def encodeCategorical(self):
    categorical_columns = self.data.select_dtypes(include=['object', 'category']).columns
    data_encoded = pd.get_dummies(data, columns=categorical_columns)
    return data_encoded

Methode: upsilon()

Berechnet die Upsilon-Statistik, die das Verhältnis zwischen erklärtem und unerklärtem Fehler angibt, sowie die Korrelation zwischen den Vorhersagen des Modells und den echten Werten.

def upsilon(self, pexp):
    ASE = statistics.mean((self.surrTar - pexp) ** 2)
    ASE0 = statistics.mean((self.surrTar - statistics.mean(self.surrTar)) ** 2)
    ups = 1 - ASE / ASE0
    rho = statistics.correlation(self.surrTar, pexp)
    return ups, rho

Methode: get_result()

Gibt eine Liste der zentralen Ergebnisse zurück: Erklärung, Regeln, Groves und Modell.

def get_result(self):
    res = [self.explanation, self.rules, self.groves, self.model]
    return res

Methode: plot()

Eine Methode zur Erstellung eines Upsilon-Rules-Plots für den Surrogate Grove. Diese Methode funktioniert ähnlich wie die Plotfunktion in R.

def plot(self, abs="rules", ord="upsilon"):
    i = self.explanation.columns.get_loc(abs)
    j = self.explanation.columns.get_loc(ord)
    plt.plot(self.explanation.iloc[:, i], self.explanation.iloc[:, j], label=f"{abs} vs {ord}", marker="o")
    plt.xlabel(abs)
    plt.ylabel(ord)
    plt.title("Upsilon-Rules Curve")
    plt.show()

Methode: calculateGrove()

Berechnet die Performance des Modells und extrahiert Groves sowie die dazugehörigen Regeln. Diese Methode füllt die Erklärungs- und Interpretationsdaten und ruft am Ende die upsilon-Methode auf, um den Upsilon-Wert zu berechnen.

def calculateGrove(self):
    explanation = []
    groves = []
    interpretation = []

    # Für jede Anzahl an Bäumen
    for nt in self.ntrees:
        predictions = self.surrGrove.staged_predict(self.data)
        predictions = [next(predictions) for _ in range(nt)][-1]
        rules = []
        
        # Extrahiere Regeln aus den Entscheidungsbäumen
        for tid in range(nt):
            tree = self.surrGrove.estimators_[tid, 0].tree_
            for node_id in range(tree.node_count):
                if tree.children_left[node_id] != tree.children_right[node_id]:  # Splitsnode
                    rule = {
                        'feature': tree.feature[node_id],
                        'threshold': tree.threshold[node_id],
                        'pleft': tree.value[tree.children_left[node_id]][0][0],
                        'pright': tree.value[tree.children_right[node_id]][0][0]
                    }
                    rules.append(rule)
            rules_df = pd.DataFrame(rules)
            groves.append(rules_df)

        # Berechne Upsilon und Korrelation
        upsilon, rho = self.upsilon(predictions)

        # Ergebnisse speichern
        explanation.append([nt, len(rules_df), upsilon, rho])

    # Ergebnisdaten aufbereiten
    groves = pd.DataFrame(groves)
    explanation = pd.DataFrame(explanation, columns=["trees", "rules", "upsilon", "cor"])
    
    self.explanation = explanation
    self.rules = groves
    self.model = self.surrGrove

    self.result = self.get_result()
    return self.result

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

xgrove-0.2.18.tar.gz (12.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

xgrove-0.2.18-py3-none-any.whl (14.8 kB view details)

Uploaded Python 3

File details

Details for the file xgrove-0.2.18.tar.gz.

File metadata

  • Download URL: xgrove-0.2.18.tar.gz
  • Upload date:
  • Size: 12.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.13.0

File hashes

Hashes for xgrove-0.2.18.tar.gz
Algorithm Hash digest
SHA256 a81ecde70e4d563a8faee3d0861bf8f3fab9a240a97992ec27fb38e54accb9bf
MD5 a2518504acbb9a0774f15b68c45ea0a0
BLAKE2b-256 b585e7c2710de8dd4c760686c2d9dd56fe87913f06d28dc475f287949eb07f3a

See more details on using hashes here.

File details

Details for the file xgrove-0.2.18-py3-none-any.whl.

File metadata

  • Download URL: xgrove-0.2.18-py3-none-any.whl
  • Upload date:
  • Size: 14.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.13.0

File hashes

Hashes for xgrove-0.2.18-py3-none-any.whl
Algorithm Hash digest
SHA256 ac1712d15071e9e7653490a9baaf296337213a47f8a01439f5b2f58e3ccdb045
MD5 83f9106836fe7084454194ebe7944c6a
BLAKE2b-256 63b4708ec7eb2521b8257f70b3fb335c52ba344ae67b5883a39a4f8cab654eb2

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page