Transparent, Robust & Ultra-Sparse Trees (TRUST) - Free Version
trust-free
This package provides the free version of our TRUST algorithm, a state-of-the-art (SOTA) interpretable machine learning model.
It currently supports regression tasks only; future releases will also tackle other tasks, such as classification.
Installation
You can install this package using pip:
pip install trust-free
Note: this package ships as a precompiled binary and is currently compatible only with macOS 11+ on ARM64 (Apple silicon). The published 0.9.2 wheel also targets CPython 3.11 only.
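Because only one prebuilt wheel is published, a quick check of your interpreter and OS before running `pip install` can save a confusing "no matching distribution" error. The sketch below is an assumption-based helper, not part of trust-free: the name `is_supported_platform` is hypothetical, and the criteria are taken from the wheel tags of release 0.9.2 (`cp311`, `macosx_11_0_arm64`), which may change in later releases.

```python
import platform
import sys


def is_supported_platform(system: str, machine: str, macos_release: str, py_version: tuple) -> bool:
    """Hypothetical helper: True if the environment matches the wheel tags
    cp311 / macosx_11_0_arm64 listed for trust-free 0.9.2."""
    if system != "Darwin" or machine != "arm64":
        return False
    try:
        macos_major = int(macos_release.split(".")[0])
    except ValueError:
        return False  # empty or unparsable macOS version string
    return macos_major >= 11 and tuple(py_version[:2]) == (3, 11)


if __name__ == "__main__":
    ok = is_supported_platform(
        platform.system(),
        platform.machine(),
        platform.mac_ver()[0],  # e.g. "14.4.1" on macOS, "" elsewhere
        sys.version_info,
    )
    print("trust-free wheel should install here:", ok)
```

Note that an x86_64 Python running under Rosetta on an Apple silicon Mac reports `x86_64` and is correctly flagged as unsupported.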
Usage
Here are two basic examples of how to use the TRUST algorithm:
```python
from trust import TRUST  # note: the import name is trust, not trust-free
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

# Example 1: generate synthetic sparse regression data
X, y, coefs = make_regression(n_samples=5000, n_features=20, n_informative=10, coef=True, noise=0.1, random_state=123)
print(coefs)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)

# Instantiate and fit the model
model = TRUST()
model.fit(X_train, y_train)

# Predict and print results
y_pred = model.predict(X_test)
print("Predictions:", y_pred[:5])
print("True y values:", y_test[:5])
print("Test R^2:", r2_score(y_test, y_pred))
print("Test MSE:", mean_squared_error(y_test, y_pred))

# Obtain a prediction explanation for the first test observation
model.explain(X_test[0, :], y_pred[0], actual=y_test[0])

# Obtain (conditional) variable importance by the Ghost method (Delicado and Peña, 2023)
model.varImp(X_test, y_test, model, corAnalysis=True)
```
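The Ghost method cited in the example (Delicado and Peña, 2023) measures a variable's relevance by replacing it with its "ghost", an estimate of that variable computed from the remaining explanatory variables, and checking how much the model's predictions change. Because the ghost keeps whatever the other features already say about the variable, the measure is conditional on them. The sketch below is a simplified, library-agnostic illustration of that idea and is not TRUST's `varImp` implementation: it assumes an ordinary least-squares fit for the ghost and reports the unnormalized mean squared change in predictions (the paper's normalization and the `corAnalysis` output are not reproduced). `ghost_relevance` is a hypothetical name, and `model` can be any object with a `predict` method that accepts a NumPy array.

```python
import numpy as np


def ghost_relevance(model, X):
    """Simplified ghost-variable relevance (hypothetical helper).

    For each column j, regress X[:, j] on the other columns by ordinary least
    squares, replace X[:, j] with the fitted values (its "ghost"), and return
    the mean squared difference between the original and ghost predictions.
    """
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    base_pred = np.asarray(model.predict(X), dtype=float)
    relevance = np.zeros(p)
    for j in range(p):
        others = np.delete(X, j, axis=1)
        design = np.column_stack([np.ones(n), others])  # intercept + other features
        beta, *_ = np.linalg.lstsq(design, X[:, j], rcond=None)
        X_ghost = X.copy()
        X_ghost[:, j] = design @ beta  # linear estimate of E[X_j | X_-j]
        ghost_pred = np.asarray(model.predict(X_ghost), dtype=float)
        relevance[j] = np.mean((base_pred - ghost_pred) ** 2)
    return relevance
```

For instance, `ghost_relevance(model, X_test)` could be applied to the fitted model from Example 1. A variable that is an exact copy of another feature gets a relevance near zero, since its ghost reproduces it perfectly.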
```python
# Example 2: the famous diabetes dataset (n=442, p=10)
import pandas as pd
from sklearn import datasets

diabetes = datasets.load_diabetes()
Diabetes = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
Diabetes["Disease_marker"] = diabetes.target
Diabetes_X = Diabetes.iloc[:, :-1]
Diabetes_y = Diabetes.iloc[:, -1]

RLT_Diabetes = TRUST(max_depth=1)
RLT_Diabetes.fit(Diabetes_X, Diabetes_y)
y_pred_TRUST = RLT_Diabetes.predict(Diabetes_X)

# Tree plotting requires Graphviz to be installed in your system path
# You can use e.g. Homebrew: brew install graphviz or Conda: conda install -c conda-forge graphviz
RLT_Diabetes.plot_tree("Diabetes")  # saves "Diabetes.png" in your working directory
```
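Since tree plotting depends on a system-level Graphviz install, checking for it before calling `plot_tree` turns a confusing failure into a clear message. This sketch assumes that `plot_tree` relies on Graphviz's `dot` executable (the usual Graphviz entry point); `graphviz_available` is a hypothetical helper built on the standard library's `shutil.which`.

```python
import shutil


def graphviz_available() -> bool:
    """Return True if the Graphviz `dot` executable is found on PATH."""
    return shutil.which("dot") is not None


if __name__ == "__main__":
    if graphviz_available():
        print("Graphviz found: tree plotting should be able to render.")
    else:
        print("Graphviz not found: install it, e.g. `brew install graphviz` "
              "or `conda install -c conda-forge graphviz`.")
```

In the example above, one could then guard the call with `if graphviz_available(): RLT_Diabetes.plot_tree("Diabetes")`.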
```python
# Obtain a prediction explanation for the first observation
RLT_Diabetes.explain(Diabetes_X.iloc[0, :], y_pred_TRUST[0], actual=Diabetes_y.iloc[0])

# Obtain variable importance with two different methods: Ghost and permutation
RLT_Diabetes.varImp(Diabetes_X, Diabetes_y, RLT_Diabetes, corAnalysis=True)  # Ghost method
RLT_Diabetes.varImpPerm(Diabetes_X, Diabetes_y, RLT_Diabetes)  # permutation method
```
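Permutation importance, the second method in Example 2, follows a widely used idea: shuffle one feature at a time, which breaks its link with the target, and measure how much a score drops. The sketch below is a generic illustration using R² (`sklearn.metrics.r2_score`) and is not TRUST's `varImpPerm` implementation; its metric, number of repeats, and output format may differ. `permutation_importance_r2` is a hypothetical name, and the sketch assumes the model's `predict` accepts a NumPy array, as in Example 1. scikit-learn also ships a maintained version as `sklearn.inspection.permutation_importance`.

```python
import numpy as np
from sklearn.metrics import r2_score


def permutation_importance_r2(model, X, y, n_repeats=5, random_state=0):
    """Mean drop in R^2 when each column of X is randomly permuted
    (generic sketch; hypothetical helper, not TRUST's varImpPerm)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    rng = np.random.default_rng(random_state)
    baseline = r2_score(y, model.predict(X))
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffle column j only, breaking its link with y
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops.append(baseline - r2_score(y, model.predict(X_perm)))
        importances[j] = np.mean(drops)
    return importances
```

Unlike the Ghost method, permutation ignores correlations between features: a shuffled column can take values that never co-occur with the others, which is one reason to compare both measures.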
License
This software is provided under a Proprietary - Permissive Binary Only license. For detailed terms, please refer to the LICENSE file included with the distribution.
More Information
For more details, documentation, and information about the full (paid) version of the TRUST algorithm, please visit our official website:
https://adc-trust-ai.github.io/trust/
Further details can be found in our preprint on arXiv:
File details
Details for the file trust_free-0.9.2-cp311-cp311-macosx_11_0_arm64.whl.
File metadata
- Download URL: trust_free-0.9.2-cp311-cp311-macosx_11_0_arm64.whl
- Size: 606.3 kB
- Tags: CPython 3.11, macOS 11.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `7b65616c621889dd7fb3499248525c79f7ededa0faa3ad75ca642b17a17b855e` |
| MD5 | `d57d1f9e123668bb7e8e0f6b581f8625` |
| BLAKE2b-256 | `946e37a06d8e608af7db48b0df26e55d9137f25c1dec4da724b1b305c1b92f73` |
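The digests for this wheel let you confirm that a manually downloaded file was not corrupted or altered. The sketch below uses Python's standard `hashlib`; the wheel's location in the current directory is an assumption, and `sha256_of` / `verify_wheel` are hypothetical helper names. pip can also enforce digests automatically in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib
from pathlib import Path

# SHA256 digest listed for trust_free-0.9.2-cp311-cp311-macosx_11_0_arm64.whl
EXPECTED_SHA256 = "7b65616c621889dd7fb3499248525c79f7ededa0faa3ad75ca642b17a17b855e"


def sha256_of(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches the expected one."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    # Assumed download location: the current working directory
    wheel = Path("trust_free-0.9.2-cp311-cp311-macosx_11_0_arm64.whl")
    if wheel.exists():
        print("SHA256 OK" if verify_wheel(wheel) else "SHA256 MISMATCH")
    else:
        print(f"{wheel} not found in the current directory")
```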