RFProximity

## Introduction

RFProximity is a Python package that computes the proximity matrix for any random forest model. It includes three implementations of proximity, shown under Usage: the standard proximity, the out-of-bag (OOB) proximity, and the geometry and accuracy preserving (GAP) proximity. Using the proximity matrix, the package also provides methods for missing value imputation, outlier detection, and prototype identification.

Background

Reference Material:

  • Leo Breiman and Adele Cutler's blog on Random Forests
  • Geometry and Accuracy Preserving (GAP) Proximities (arXiv link)
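
As background, the classic random forest proximity described by Breiman and Cutler is the fraction of trees in which two samples fall into the same leaf. The sketch below illustrates that definition with plain NumPy. The helper name `leaf_cooccurrence_proximity` and the toy leaf matrix are hypothetical, introduced only for illustration; this is not the package's API and may differ from how `TreeProximity` computes its matrices.

```python
import numpy as np


def leaf_cooccurrence_proximity(leaf_nodes):
    """Fraction of trees in which each pair of samples shares a leaf.

    Hypothetical helper for illustration only, not part of rfproximity.
    `leaf_nodes` has shape (n_samples, n_trees), one leaf index per tree.
    """
    leaf_nodes = np.asarray(leaf_nodes)
    n_samples, n_trees = leaf_nodes.shape
    proximity = np.zeros((n_samples, n_samples))
    for t in range(n_trees):
        column = leaf_nodes[:, t]
        # Pairwise comparison: True where two samples share a leaf in tree t
        proximity += column[:, None] == column[None, :]
    return proximity / n_trees


# Toy leaf assignments (made up): 4 samples, 3 trees
toy_leaves = np.array([
    [1, 5, 2],
    [1, 5, 3],
    [2, 5, 3],
    [2, 6, 4],
])
toy_proximity = leaf_cooccurrence_proximity(toy_leaves)
print(toy_proximity)
```

With scikit-learn, a leaf matrix of this shape is what `model.apply(X)` returns for a fitted `RandomForestClassifier`. The loop over trees keeps memory use at one `n_samples × n_samples` array at a time, at the cost of being slower than a fully vectorized version.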

Usage

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from rfproximity import TreeProximity, SimilarityExplainer
import matplotlib.pyplot as plt
import pandas as pd

# Load the Iris dataset with `sklearn.datasets.load_iris`
# from scikit-learn. The dataset contains:
# - 150 samples
# - 4 features
# - 3 classes

data = load_iris()
X,y = data['data'], data['target']
print(X.shape)


# Train a random forest classifier on the dataset
model = RandomForestClassifier().fit(X, y)
# Leaf index reached by each sample in each tree, shape (n_samples, n_trees)
leaf_nodes = model.apply(X)


# After training the model, initialize TreeProximity with the leaf indices
prox = TreeProximity(leaf_nodes)

# Compute the proximity matrix
proximity_matrix = prox.proximity_matrix()
print(proximity_matrix)

# Compute the out-of-bag (OOB) proximity matrix
proximity_matrix_oob = prox.proximity_matrix_oob(model, X.shape[0])
print(proximity_matrix_oob)


# Compute the geometry and accuracy preserving (GAP)
# proximity matrix
proximity_matrix_gap = prox.proximity_matrix_gap(model, X.shape[0])
print(proximity_matrix_gap)


# Use SimilarityExplainer with the proximity matrix and the class labels
# to identify prototype samples
SimEx = SimilarityExplainer(proximity_matrix, y)
prototypes = SimEx.get_prototype(top_k=20, total_prototypes=3, return_neighbors=True)
print(prototypes)


# Use SimilarityExplainer to compute class-wise outlier measures
raw_outlier_measure = SimEx.raw_outlier_measure()
outlier_measure = SimEx.get_classwise_outlier_measure()
# Collect both measures alongside the class labels for inspection
df_outlier_measure = pd.DataFrame({
    'raw_outlier_measure': raw_outlier_measure,
    'outlier_measure': outlier_measure,
    'class_label': y,
})
plt.figure()
plt.scatter(x=y,y=outlier_measure)
plt.xlabel('Class Label')
plt.ylabel('Outlier Measure')  
plt.title('Class-wise Outlier Measure')
plt.show()
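
To make the outlier step above more concrete, here is a minimal sketch of a class-wise outlier measure in the spirit of Breiman and Cutler's description: a sample's raw score is the number of samples divided by the sum of its squared proximities to the other samples of its class, and the raw scores are then centered and scaled within each class. The function name `classwise_outlier_sketch`, the toy proximity matrix, the exclusion of the self-proximity, and the use of the mean absolute deviation from the median as the scale are all assumptions made for illustration; `SimilarityExplainer` may compute its measures differently.

```python
import numpy as np


def classwise_outlier_sketch(proximity, labels):
    """Raw and class-normalized outlier scores from a proximity matrix.

    Hypothetical helper for illustration only, not part of rfproximity.
    """
    proximity = np.asarray(proximity, dtype=float)
    labels = np.asarray(labels)
    n_samples = proximity.shape[0]
    raw = np.empty(n_samples)
    normalized = np.zeros(n_samples)
    for cls in np.unique(labels):
        in_class = labels == cls
        # Squared proximities between members of this class
        squared = proximity[np.ix_(in_class, in_class)] ** 2
        np.fill_diagonal(squared, 0.0)  # assumption: ignore self-proximity
        with np.errstate(divide="ignore"):
            # Infinite score if a sample has no proximity to its class
            raw[in_class] = n_samples / squared.sum(axis=1)
        # Center on the class median, scale by mean absolute deviation
        centered = raw[in_class] - np.median(raw[in_class])
        scale = np.mean(np.abs(centered))
        if scale > 0:
            normalized[in_class] = centered / scale
    return raw, normalized


# Toy data (made up): sample 3 is only weakly tied to the rest of class 0
toy_labels = np.array([0, 0, 0, 0, 1, 1, 1])
toy_prox = np.zeros((7, 7))
toy_prox[:3, :3] = 0.9
toy_prox[3, :3] = toy_prox[:3, 3] = 0.1
toy_prox[4:, 4:] = 0.8
np.fill_diagonal(toy_prox, 1.0)

raw, normalized = classwise_outlier_sketch(toy_prox, toy_labels)
print(int(np.argmax(normalized)))  # index of the most outlying sample
```

Samples with low proximity to the rest of their class receive large scores, which is why the scatter plot above groups the measure by class label.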

Contributing

Guidelines for contributing to the project:

License

The license for the project:

Credits

Initial code contributions:

  • Dhruv Desai
  • Dhagash Mehta
  • Julio Urquidi
  • Lucas Ou
