RFProximity

## Introduction

RFProximity is a Python package that computes the proximity matrix for any Random Forest model. The package includes three implementations of proximity: the standard proximity, the out-of-bag (OOB) proximity, and the Geometry and Accuracy Preserving (GAP) proximity. Using the proximity matrix, the package also provides methods for missing value imputation, outlier detection, and prototype identification.
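
As a rough illustration of what a proximity matrix contains, the sketch below computes the classic Random Forest proximity (the fraction of trees in which two samples land in the same leaf) with plain NumPy. The function name `classic_proximity` and the toy leaf indices are assumptions made for this example; it is not the package's implementation.

```python
import numpy as np

def classic_proximity(leaves: np.ndarray) -> np.ndarray:
    """Fraction of trees in which each pair of samples shares a leaf.

    `leaves` has shape (n_samples, n_trees), for example the output of
    `RandomForestClassifier.apply(X)` in scikit-learn.
    """
    # same_leaf[i, j, t] is True when samples i and j share a leaf in tree t
    same_leaf = leaves[:, None, :] == leaves[None, :, :]
    # Average over trees to get one value in [0, 1] per pair
    return same_leaf.mean(axis=2)

# Toy leaf indices for 3 samples and 4 trees (made-up values)
leaves = np.array([
    [1, 2, 3, 4],
    [1, 2, 5, 4],
    [6, 2, 5, 7],
])
proximity = classic_proximity(leaves)
print(proximity)
```

The result is symmetric with ones on the diagonal. This broadcasting approach needs memory proportional to n_samples² × n_trees, so it is only suitable for small datasets.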

Background

Reference Material:

  • Leo Breiman and Adele Cutler's blog on Random Forest
  • Geometry and Accuracy Preserving (GAP) Proximities (arXiv link)
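
For orientation, one common way to define an out-of-bag proximity is to consider, for each pair of samples, only the trees in which both samples were out of bag, and to take the fraction of those trees in which the pair shares a leaf. The sketch below illustrates that definition under stated assumptions: `oob_proximity`, the toy leaf indices, and the toy `oob_mask` are hypothetical, and the package's own OOB and GAP proximities may differ in detail.

```python
import numpy as np

def oob_proximity(leaves: np.ndarray, oob_mask: np.ndarray) -> np.ndarray:
    """Share of jointly out-of-bag trees in which two samples share a leaf.

    leaves:   (n_samples, n_trees) leaf index of each sample in each tree
    oob_mask: (n_samples, n_trees) True where the sample was NOT used
              to train that tree
    """
    same_leaf = leaves[:, None, :] == leaves[None, :, :]
    both_oob = oob_mask[:, None, :] & oob_mask[None, :, :]
    shared = (same_leaf & both_oob).sum(axis=2)
    counts = both_oob.sum(axis=2)
    # Pairs that are never jointly out of bag keep a proximity of 0
    result = np.zeros(shared.shape, dtype=float)
    np.divide(shared, counts, out=result, where=counts > 0)
    return result

# Toy inputs for 3 samples and 4 trees (made-up values)
leaves = np.array([
    [1, 2, 3, 4],
    [1, 2, 5, 4],
    [6, 2, 5, 7],
])
oob_mask = np.array([
    [True, False, True, True],
    [True, True, False, True],
    [False, True, True, False],
])
proximity_oob = oob_proximity(leaves, oob_mask)
print(proximity_oob)
```

Restricting the count to out-of-bag trees means a pair is compared only on trees that never saw either sample during training.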

Usage

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from rfproximity import TreeProximity, SimilarityExplainer
import matplotlib.pyplot as plt
import pandas as pd

# Load the Iris dataset with `sklearn.datasets.load_iris`
# from scikit-learn. The dataset contains:
# - 3 classes
# - 150 samples with 4 features each

data = load_iris()
X, y = data['data'], data['target']
print(X.shape)


# Train a random forest classifier on the dataset, then record the
# leaf index that each sample reaches in every tree
model = RandomForestClassifier().fit(X, y)
leaf_nodes = model.apply(X)


# After training the model, initialize TreeProximity with the leaf indices
prox = TreeProximity(leaf_nodes)

# Compute the proximity matrix
proximity_matrix = prox.proximity_matrix()
print(proximity_matrix)

# Compute the out-of-bag (OOB) proximity matrix
proximity_matrix_oob = prox.proximity_matrix_oob(model, X.shape[0])
print(proximity_matrix_oob)


# Compute the Geometry and Accuracy Preserving (GAP)
# proximity matrix
proximity_matrix_gap = prox.proximity_matrix_gap(model, X.shape[0])
print(proximity_matrix_gap)


# Use SimilarityExplainer with the proximity matrix to identify prototype samples

SimEx = SimilarityExplainer(proximity_matrix, y)
prototypes = SimEx.get_prototype(top_k=20, total_prototypes=3, return_neighbors=True)
print(prototypes)


# Use SimilarityExplainer to compute class-wise outlier measures
raw_outlier_measure = SimEx.raw_outlier_measure()
outlier_measure = SimEx.get_classwise_outlier_measure()
df_outlier_measure = pd.DataFrame({'raw_outlier_measure': raw_outlier_measure,
                                   'outlier_measure': outlier_measure,
                                   'class_label': y})
print(df_outlier_measure.head())

# Plot the outlier measure of every sample against its class label
plt.figure()
plt.scatter(x=y, y=outlier_measure)
plt.xlabel('Class Label')
plt.ylabel('Outlier Measure')
plt.title('Class-wise Outlier Measure')
plt.show()
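
To make the class-wise outlier measure less of a black box, here is a minimal NumPy sketch in the spirit of Breiman and Cutler's description: the raw measure of a sample is the number of samples divided by the sum of its squared proximities to other samples of the same class, and the raw values are then centered and scaled within each class. The function `classwise_outlier_sketch`, the exclusion of self-proximity, the use of the median absolute deviation for scaling, and the toy matrix are all assumptions of this example; `SimilarityExplainer` may compute its measures differently.

```python
import numpy as np

def classwise_outlier_sketch(proximity, labels):
    """Return (raw, scaled) outlier measures per sample.

    proximity: (n, n) proximity matrix with values in [0, 1]
    labels:    (n,) class label of each sample
    """
    proximity = np.asarray(proximity, dtype=float)
    labels = np.asarray(labels)
    n = proximity.shape[0]

    squared = proximity ** 2          # new array, input is not modified
    np.fill_diagonal(squared, 0.0)    # assumption: ignore self-proximity

    raw = np.empty(n)
    scaled = np.empty(n)
    for cls in np.unique(labels):
        members = labels == cls
        # Sum of squared proximities to same-class samples only
        within = squared[np.ix_(members, members)].sum(axis=1)
        raw[members] = n / np.maximum(within, 1e-12)
        # Center and scale within the class (median and MAD)
        med = np.median(raw[members])
        mad = np.median(np.abs(raw[members] - med))
        scaled[members] = (raw[members] - med) / (mad if mad > 0 else 1.0)
    return raw, scaled

# Toy example: the last sample is far from the other three (made-up values)
toy_proximity = np.array([
    [1.00, 0.90, 0.80, 0.10],
    [0.90, 1.00, 0.85, 0.10],
    [0.80, 0.85, 1.00, 0.10],
    [0.10, 0.10, 0.10, 1.00],
])
toy_labels = np.array([0, 0, 0, 0])
raw, scaled = classwise_outlier_sketch(toy_proximity, toy_labels)
print(raw)
print(scaled)
```

A sample with low proximity to the rest of its class gets a small denominator and therefore a large raw measure, which is why the last toy sample stands out.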

Contributing

Guidelines for contributing to the project:

License

The license for the project:

Credits

Initial code contributions:

  • Dhruv Desai
  • Dhagash Mehta
  • Julio Urquidi

