
A tree visualization and analysis package for XGBoost models


Arborium


Interactive visualization for tree-based models in Python, with a focus on XGBoost models. Designed for use in Jupyter notebooks and similar interactive environments.


Introduction

Arborium is a Python package that makes tree-based models more interpretable through visualization. Tree ensembles such as XGBoost are powerful predictive tools, but their size and complexity make it hard to see how they reach a decision. Arborium addresses this with interactive visualizations of individual tree structures, helping data scientists and machine learning practitioners understand model behavior.

The package currently focuses on XGBoost models; support for other tree-based algorithms is planned for future releases.

Note: Arborium is specifically designed for use in Jupyter notebooks or similar interactive environments (JupyterLab, Google Colab, etc.) where HTML visualizations can be rendered inline.
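Inline rendering works through IPython's display protocol: when the last expression in a cell evaluates to an object with a `_repr_html_()` method, the notebook frontend shows the returned HTML instead of the plain `repr`. The sketch below illustrates that mechanism with a hypothetical `HtmlTreeView` class; it is not Arborium's implementation, which may call `IPython.display` directly instead.

```python
import html


class HtmlTreeView:
    """Hypothetical object that Jupyter would render as HTML inline."""

    def __init__(self, title, nodes):
        self.title = title
        self.nodes = nodes

    def _repr_html_(self):
        # Jupyter calls this automatically when the object is the last
        # expression in a cell; escaping keeps the markup well-formed.
        items = "".join(f"<li>{html.escape(n)}</li>" for n in self.nodes)
        return f"<h4>{html.escape(self.title)}</h4><ul>{items}</ul>"

    def __repr__(self):
        # Fallback shown by plain terminals that cannot render HTML.
        return f"HtmlTreeView({self.title!r}, {len(self.nodes)} nodes)"


view = HtmlTreeView("Tree 0", ["f2 < 2.45", "leaf = 0.43"])
print(repr(view))
```

In a plain Python shell only the `__repr__` text appears, which is why an HTML-capable frontend is needed for the full experience.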

Installation

Basic Installation

pip install arborium

With XGBoost Support

# Quotes stop shells such as zsh from expanding the brackets
pip install "arborium[xgboost]"
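To confirm which optional dependencies are available in the active environment (a notebook kernel may use a different interpreter than your shell), you can query Python's import system without importing anything. This is a generic check, not part of Arborium's API; the module names in the loop are assumptions about what a typical setup needs.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the top-level module `name` can be imported."""
    # find_spec locates the module without executing it.
    return importlib.util.find_spec(name) is not None


for dep in ("xgboost", "numpy", "IPython"):
    print(f"{dep}: {'installed' if has_module(dep) else 'missing'}")
```

Run this inside the notebook itself to check the kernel's environment rather than the shell's.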

Jupyter Notebook Support

Arborium requires an environment that can render HTML and JavaScript. To get the full interactive experience:

# If you don't already have Jupyter installed
pip install jupyter

# Then launch Jupyter Notebook
jupyter notebook

Development Installation

git clone https://github.com/yourusername/arborium.git
cd arborium
pip install -e ".[dev]"

Quick Start

import xgboost as xgb
from arborium import XGBTreeVisualizer
import numpy as np
from sklearn.datasets import load_breast_cancer

# Load a dataset
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names

# Train a simple XGBoost model
model = xgb.XGBClassifier(n_estimators=10, max_depth=3)
model.fit(X, y)

# Visualize the trees
visualizer = XGBTreeVisualizer(model, X, y, feature_names=feature_names)
visualizer.show_tree()
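A tree visualizer needs each tree as a data structure. XGBoost can export every tree as JSON with `Booster.get_dump(dump_format="json")` (for the sklearn wrapper, call `model.get_booster()` first). The sketch below walks one such dump and summarizes it. The sample string is a hand-written, illustrative tree in that format rather than output from the model above, and this is not necessarily how Arborium parses trees.

```python
import json

# Hand-written example shaped like one entry of
# Booster.get_dump(dump_format="json"); values are illustrative only.
TREE_JSON = """
{"nodeid": 0, "depth": 0, "split": "f2", "split_condition": 2.45,
 "yes": 1, "no": 2, "missing": 1, "children": [
   {"nodeid": 1, "leaf": 0.43},
   {"nodeid": 2, "depth": 1, "split": "f3", "split_condition": 1.75,
    "yes": 3, "no": 4, "missing": 3, "children": [
      {"nodeid": 3, "leaf": -0.21},
      {"nodeid": 4, "leaf": -0.22}]}]}
"""


def summarize(node: dict, depth: int = 0) -> dict:
    """Return the leaf count and maximum depth of the subtree at `node`."""
    if "leaf" in node:
        return {"leaves": 1, "max_depth": depth}
    parts = [summarize(child, depth + 1) for child in node["children"]]
    return {
        "leaves": sum(p["leaves"] for p in parts),
        "max_depth": max(p["max_depth"] for p in parts),
    }


tree = json.loads(TREE_JSON)
print(summarize(tree))  # {'leaves': 3, 'max_depth': 2}
```

With a real model, `[json.loads(s) for s in booster.get_dump(dump_format="json")]` yields one such dict per tree.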

Features

Arborium offers the following key features:

  • Interactive Tree Visualization: Explore tree structures with an intuitive, interactive interface
  • Split Point Analysis: Visualize feature distributions at split points with histograms
  • Multi-Tree Navigation: Easily navigate between trees in ensemble models
  • Simplified Tree Creation: Generate simplified decision trees that approximate complex models
  • Classification & Regression Support: Works with both classification and regression models
  • Customizable Visualizations: Control depth, components, and styling of visualizations
  • Jupyter Integration: Seamless display in Jupyter notebooks and lab environments
  • Model Insights: Gain interpretability without sacrificing model performance
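Split point analysis asks how the training data falls on either side of a node's threshold. In XGBoost, a sample goes to the node's `yes` child when its feature value is strictly less than `split_condition`, and missing values follow the `missing` branch. The stdlib-only sketch below bins one feature into equal-width buckets and counts how many samples go left and right; the function name and bucketing scheme are illustrative assumptions, not Arborium's implementation.

```python
import math


def split_histogram(values, threshold, n_bins=5):
    """Bin `values` into equal-width buckets and count samples per side.

    Follows XGBoost's convention: value < threshold goes to the "yes"
    (left) branch. NaN values are reported separately as missing.
    """
    present = [v for v in values if not math.isnan(v)]
    lo, hi = min(present), max(present)
    width = (hi - lo) / n_bins or 1.0  # avoid zero width when all equal
    counts = [0] * n_bins
    for v in present:
        idx = min(int((v - lo) / width), n_bins - 1)  # clamp max into last bin
        counts[idx] += 1
    left = sum(1 for v in present if v < threshold)
    return {
        "bins": counts,
        "left": left,
        "right": len(present) - left,
        "missing": len(values) - len(present),
    }


petal_length = [1.4, 1.3, 4.7, 4.5, 6.0, 5.1, float("nan")]
print(split_histogram(petal_length, threshold=2.45))
```

Plotting `bins` with the threshold marked gives the kind of per-split histogram described above.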

Example Notebooks

For interactive examples, see the Jupyter notebooks in the repository's notebooks/ directory. You can run them locally after installing arborium:

git clone https://github.com/rishabhmandayam/xgboost.git
cd xgboost/arborium
pip install -e .
jupyter notebook notebooks/

Or open directly in Google Colab:

Open In Colab

Usage Examples

Multiclass Classification

from arborium import XGBTreeVisualizer
from sklearn.datasets import load_iris
import xgboost as xgb

# Load the Iris dataset (multiclass classification)
iris = load_iris()
X, y = iris.data, iris.target

# Create DMatrix for XGBoost
dtrain = xgb.DMatrix(X, label=y)

# Set parameters for XGBoost
params = {
    'objective': 'multi:softmax',  # multiclass classification
    'num_class': 3,  # iris has 3 classes
    'max_depth': None,  # None is skipped, so XGBoost's default depth (6) applies
    'learning_rate': 0.1,
    'eval_metric': 'mlogloss'
}

# Train XGBoost model
num_rounds = 100
model = xgb.train(params, dtrain, num_rounds)

# Create a visualizer
visualizer = XGBTreeVisualizer(model, X, y, feature_names=iris.feature_names, target_names=iris.target_names)

# Show the trees
visualizer.show_tree()
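With `multi:softmax`, XGBoost grows one tree per class in every boosting round, so the 100 rounds above produce 3 × 100 = 300 trees, stored round by round. Assuming the default `num_parallel_tree=1`, a tree's class and round follow directly from its index, which helps when navigating an ensemble tree by tree. The helper below is a hypothetical illustration, not an Arborium function.

```python
def tree_position(tree_index: int, num_class: int) -> tuple:
    """Map a flat tree index to (boosting_round, class_id).

    Assumes num_parallel_tree=1, so each round appends one tree per class.
    """
    return divmod(tree_index, num_class)


num_class, num_rounds = 3, 100
total_trees = num_class * num_rounds
print(total_trees)  # 300
for i in range(5):
    rnd, cls = tree_position(i, num_class)
    print(f"tree {i}: round {rnd}, class {cls}")
```

For example, tree 4 is the class-1 tree from round 1.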

Regression

from arborium import XGBTreeVisualizer
import numpy as np
import xgboost as xgb
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Load a regression dataset (California Housing)
housing = fetch_california_housing()
X, y = housing.data, housing.target

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create DMatrix for XGBoost
dtrain_reg = xgb.DMatrix(X_train, label=y_train)
dtest_reg = xgb.DMatrix(X_test, label=y_test)

# Set parameters for regression
params_reg = {
    'objective': 'reg:squarederror',
    'max_depth': 4,
    'learning_rate': 0.1,
    'eval_metric': 'rmse'
}

# Train the regression model
num_rounds = 50
reg_model = xgb.train(params_reg, dtrain_reg, num_rounds)

# Evaluate the model
y_pred = reg_model.predict(dtest_reg)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print(f"Regression model RMSE: {rmse:.4f}")

# Create a visualizer for the regression model
reg_visualizer = XGBTreeVisualizer(reg_model, X_train, y_train, feature_names=housing.feature_names)

# Show the trees of the regression model
reg_visualizer.show_tree()
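For `reg:squarederror`, a prediction is the booster's `base_score` plus the sum of the leaf values reached in every tree; the learning rate is already folded into the leaf values. The sketch below traverses trees in XGBoost's JSON dump layout to reproduce that sum, which is what a per-tree view lets you inspect step by step. The two trees and the `base_score` of 2.0 are made-up values; for a real model, read `base_score` from the booster's configuration rather than assuming a default, since recent XGBoost versions may estimate it from the data.

```python
import math


def leaf_value(node: dict, x: list) -> float:
    """Follow splits to a leaf; XGBoost sends x < split_condition to 'yes'."""
    while "leaf" not in node:
        # Assumes unnamed features, so splits are labeled "f0", "f1", ...
        value = x[int(node["split"][1:])]
        if math.isnan(value):
            target = node["missing"]
        elif value < node["split_condition"]:
            target = node["yes"]
        else:
            target = node["no"]
        node = next(c for c in node["children"] if c["nodeid"] == target)
    return node["leaf"]


def predict_regression(trees: list, x: list, base_score: float) -> float:
    """reg:squarederror prediction: base_score plus the reached leaf values."""
    return base_score + sum(leaf_value(t, x) for t in trees)


# Hand-written trees in the get_dump(dump_format="json") shape.
trees = [
    {"nodeid": 0, "split": "f0", "split_condition": 3.0, "yes": 1, "no": 2,
     "missing": 1, "children": [{"nodeid": 1, "leaf": -0.2}, {"nodeid": 2, "leaf": 0.3}]},
    {"nodeid": 0, "split": "f1", "split_condition": 20.0, "yes": 1, "no": 2,
     "missing": 2, "children": [{"nodeid": 1, "leaf": 0.05}, {"nodeid": 2, "leaf": -0.1}]},
]
print(predict_regression(trees, x=[4.1, 15.0], base_score=2.0))
```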

Simplified Tree Representations

from arborium import XGBTreeVisualizer
from sklearn.datasets import load_iris
import xgboost as xgb

# Load the Iris dataset (multiclass classification)
iris = load_iris()
X, y = iris.data, iris.target

# Create DMatrix for XGBoost
dtrain = xgb.DMatrix(X, label=y)

# Set parameters for XGBoost
params = {
    'objective': 'multi:softmax',  # multiclass classification
    'num_class': 3,  # iris has 3 classes
    'max_depth': None,  # None is skipped, so XGBoost's default depth (6) applies
    'learning_rate': 0.1,
    'eval_metric': 'mlogloss'
}

# Train XGBoost model
num_rounds = 100
model = xgb.train(params, dtrain, num_rounds)

visualizer = XGBTreeVisualizer(model, X, y, feature_names=iris.feature_names, target_names=iris.target_names)

simple_model = visualizer.show_simplified_tree(max_depth=3)

simple_predictions = simple_model.predict(X)
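A common way to build a simplified tree is a global surrogate: fit a shallow tree to the full model's predictions (not the original labels) and report its fidelity, the fraction of samples on which the two agree. Arborium's exact procedure is not documented here, so the stdlib sketch below only illustrates the technique, reduced to a depth-1 stump; `fit_stump`, `fidelity`, and the toy data are hypothetical. In practice you would fit a deeper tree, for example scikit-learn's `DecisionTreeClassifier`, on the full model's predictions.

```python
from collections import Counter


def fit_stump(X, teacher_labels):
    """Find the single split that best reproduces the teacher's labels.

    Tries every observed value of every feature as a threshold and keeps
    the most accurate split. Assumes some feature has two distinct values.
    """
    best = None
    for j in range(len(X[0])):
        for t in sorted({row[j] for row in X}):
            left = [y for row, y in zip(X, teacher_labels) if row[j] < t]
            right = [y for row, y in zip(X, teacher_labels) if row[j] >= t]
            if not left or not right:
                continue
            l_lab = Counter(left).most_common(1)[0][0]
            r_lab = Counter(right).most_common(1)[0][0]
            correct = left.count(l_lab) + right.count(r_lab)
            if best is None or correct > best[0]:
                best = (correct, j, t, l_lab, r_lab)
    _, j, t, l_lab, r_lab = best
    return lambda row: l_lab if row[j] < t else r_lab


def fidelity(surrogate, X, teacher_labels):
    """Fraction of samples where the surrogate agrees with the full model."""
    agree = sum(surrogate(row) == y for row, y in zip(X, teacher_labels))
    return agree / len(X)


# Toy data: teacher_labels stand in for the full model's predictions.
X = [[1.4, 0.2], [1.3, 0.2], [4.7, 1.4], [4.5, 1.5], [6.0, 2.5], [5.1, 1.9]]
teacher_labels = [0, 0, 1, 1, 2, 2]
stump = fit_stump(X, teacher_labels)
print(fidelity(stump, X, teacher_labels))
```

A stump with two leaves cannot separate three classes, so fidelity stays below 1 here; deeper surrogates trade readability for agreement.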

Feature Importance Visualization

Coming in a future release.

API Reference

XGBTreeVisualizer

The main class for visualizing XGBoost models.

XGBTreeVisualizer(model, X, y, feature_names=None, target_names=None)

Parameters:

  • model: A trained XGBoost model (a native Booster or an sklearn-API estimator such as XGBClassifier)
  • X: Input features used during training (array-like or DataFrame)
  • y: Target values (array-like or Series)
  • feature_names: List of feature names (optional)
  • target_names: List of target class names (optional)
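When a model is trained on a plain NumPy array, as in the examples above, XGBoost refers to columns as `f0`, `f1`, and so on in its tree dumps, which is presumably why `feature_names` is accepted here. The hypothetical helper below shows the kind of translation that makes split labels readable; it is not part of Arborium's documented API.

```python
import re


def readable_split(split: str, feature_names=None) -> str:
    """Translate an XGBoost split name like "f3" into a feature name.

    Returns the raw name when no names are given or when the split
    already uses a custom feature name.
    """
    match = re.fullmatch(r"f(\d+)", split)
    if feature_names is None or match is None:
        return split
    return feature_names[int(match.group(1))]


names = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
print(readable_split("f2", names))  # petal length (cm)
print(readable_split("f2"))         # f2
```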

Methods:

  • show_tree(): Display an interactive visualization of the model's trees
  • show_simplified_tree(max_depth=3, n_components=None, n_samples=10000): Create and display a simplified decision tree that approximates the full model
  • get_simplified_model(): Get the simplified decision tree model object
  • predict_with_simplified_tree(X): Make predictions using the simplified model

Contributing

We welcome contributions to Arborium! If you'd like to contribute, please:

  1. Fork the repository
  2. Create a feature branch
  3. Add your changes
  4. Run the tests
  5. Submit a pull request

For major changes, please open an issue first to discuss the proposed changes.

License

Arborium is released under the MIT License. See LICENSE for details.


Download files


Source Distribution

arborium-0.1.5.tar.gz (128.0 kB)

Uploaded Source

Built Distribution


arborium-0.1.5-py3-none-any.whl (18.8 kB)

Uploaded Python 3

File details

Details for the file arborium-0.1.5.tar.gz.

File metadata

  • Download URL: arborium-0.1.5.tar.gz
  • Upload date:
  • Size: 128.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for arborium-0.1.5.tar.gz
Algorithm Hash digest
SHA256 0054f2af24939cafb43e2f4cc26e11a9961ec6c019463be7a7a8433f4218201c
MD5 36533ed8c04a4eaec4cbfa5361858a49
BLAKE2b-256 ff20d876c90be43cec952d2917da1f58a3a00fcc7d8453cc2df20f5a9485929b


File details

Details for the file arborium-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: arborium-0.1.5-py3-none-any.whl
  • Upload date:
  • Size: 18.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for arborium-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 48446cebc57cb7e9b4822020fee90f6c5e995bc0f41023f55bc7a92c0561c7f4
MD5 1a3d70acf5a3a782cc97e0bf6e470a28
BLAKE2b-256 bab60ba79f523ae130f2be4522c35deccdd04e35984981fcbea60beb52039d1f
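Before installing a downloaded file, you can check it against the SHA256 digests listed above. The stdlib sketch below streams the file through `hashlib` so large archives are not read into memory at once; the file name in the final comment assumes the wheel was saved to the current directory. pip can also enforce digests itself through hash-checking mode (`--hash` entries in a requirements file together with `--require-hashes`).

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Compute a file's SHA256 hex digest, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected: str) -> bool:
    """Compare a file against a published digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()


# Published SHA256 for the 0.1.5 wheel.
EXPECTED = "48446cebc57cb7e9b4822020fee90f6c5e995bc0f41023f55bc7a92c0561c7f4"
# verify("arborium-0.1.5-py3-none-any.whl", EXPECTED)
```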

