
Build explainable ML models using surrogate models.


# SAFE - Surrogate Assisted Feature Extraction

SAFE is a Python library that you can use to enhance your simple ML models. The idea is to use a more complex model, called the surrogate model, to extract more information from the features; this information can then be used to fit a simpler model. The input data is divided into subsets determined by the surrogate model, and each point is then transformed based on the subset it belongs to. The library provides the `SafeTransformer` class, which implements scikit-learn's `TransformerMixin` interface, so it can be used as part of a scikit-learn pipeline.
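For a single continuous feature, once the subset boundaries are known, the transformation reduces to looking up which interval each value falls into. The sketch below assumes sorted numeric boundaries and left-closed intervals; the library's own boundary convention may differ, and `to_interval_labels` is a hypothetical helper, not part of the SAFE API.

```python
from bisect import bisect_right


def to_interval_labels(values, boundaries):
    """Return, for each value, the index of the interval it falls into.

    With sorted boundaries b_0 < b_1 < ... < b_k the intervals are
    (-inf, b_0), [b_0, b_1), ..., [b_k, +inf), labelled 0 .. k + 1.
    """
    cuts = sorted(boundaries)
    # bisect_right counts how many boundaries are <= value, which is
    # exactly the label of the left-closed interval that contains it.
    return [bisect_right(cuts, v) for v in values]


if __name__ == "__main__":
    print(to_interval_labels([1.0, 5.9, 6.0, 9.5], boundaries=[6.0]))  # [0, 0, 1, 1]
```

Each label then acts as a category, so values that fall into the same subset are treated identically by the simple model.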

## Requirements

To install the library, run:

`pip install safe-transformer`

The only requirement is to have Python 3 installed on your machine.

## Usage with example

Sample code using the SAFE transformer as part of a scikit-learn pipeline:

![](images/note.png)

As the example shows, you can improve the performance of a simple model with the help of a more complex one.

You can use any surrogate model you like, as long as it has `fit` and `predict` methods for regression, or `fit` and `predict_proba` methods for classification. The data used to fit the SAFE transformer must be a pandas DataFrame.
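The interface requirement above is plain duck typing. The sketch below checks whether a candidate surrogate exposes the methods named above; `check_surrogate` and the toy `MeanRegressor` are hypothetical illustrations, not part of the library, and the library itself may validate models differently (or not at all).

```python
def check_surrogate(model, task):
    """Return True if `model` has the methods required for `task`.

    "regression" needs fit + predict; "classification" needs
    fit + predict_proba.
    """
    required = {
        "regression": ("fit", "predict"),
        "classification": ("fit", "predict_proba"),
    }
    if task not in required:
        raise ValueError(f"unknown task: {task!r}")
    return all(callable(getattr(model, name, None)) for name in required[task])


class MeanRegressor:
    """Toy regressor that always predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


if __name__ == "__main__":
    print(check_surrogate(MeanRegressor(), "regression"))      # True
    print(check_surrogate(MeanRegressor(), "classification"))  # False
```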

You can also specify the penalty and the Pelt model arguments, which are used for changepoint detection.

The examples folder contains Jupyter notebooks with complete classification and regression examples.

## Algorithm

Our goal is to divide each feature into subsets and then transform feature values based on the subset they belong to. The division is based on the response of the surrogate model. For each continuous variable, we find changepoints: values of the variable at which the response of the surrogate model changes quickly. The intervals between changepoints are the basis of the transformation; for example, the feature is transformed into a categorical variable in which feature values from the same interval form the same category. To find changepoints we need partial dependence plots, which visualize the marginal effect of a given variable (or multiple variables) on the model outcome. For each categorical variable, we instead perform hierarchical clustering based on the surrogate model's responses.
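A one-dimensional partial dependence curve is commonly estimated by fixing the feature of interest at each grid value for every row, predicting, and averaging. The sketch below shows that estimate; `partial_dependence` and `ToySurrogate` are hypothetical names, rows are plain Python lists, and the library's own implementation (for instance how it chooses the grid) may differ.

```python
def partial_dependence(model, X, feature, grid):
    """Estimate a one-dimensional partial dependence curve.

    For every grid value v, column `feature` is set to v in all rows of X
    (other columns keep their observed values), the model predicts on the
    modified rows, and the predictions are averaged.
    """
    curve = []
    for v in grid:
        modified = [row[:feature] + [v] + row[feature + 1:] for row in X]
        predictions = model.predict(modified)
        curve.append(sum(predictions) / len(predictions))
    return curve


class ToySurrogate:
    """Hypothetical surrogate: adds 10 when x0 > 5, plus the value of x1."""

    def predict(self, X):
        return [(10.0 if row[0] > 5 else 0.0) + row[1] for row in X]


if __name__ == "__main__":
    X = [[1.0, 7.0], [4.0, 2.0], [9.0, 3.0]]
    print(partial_dependence(ToySurrogate(), X, feature=0, grid=[0, 2, 4, 6, 8, 10]))
    # [4.0, 4.0, 4.0, 14.0, 14.0, 14.0]
```

The jump between grid values 4 and 6 is the kind of quick change in the surrogate's response that a changepoint should capture.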

The algorithm performed by the `fit` method is illustrated below:

![Fit method algorithm](images/fl.svg)

Here is an example of a partial dependence plot, created for the Boston housing data set and the variable LSTAT. To find changepoints in partial dependence plots, we use the ruptures library and its Pelt model.

| | |
| - | - |
| ![Partial dependence plot](images/simple-plot.png) | ![Changepoints](images/changepoint.png) |
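Pelt (Pruned Exact Linear Time) finds the changepoints that minimize the total segment cost plus a fixed penalty per changepoint. The library delegates this search to ruptures; to make the procedure concrete, here is a from-scratch sketch with an L2 cost (squared deviations from each segment's mean). It is not the ruptures implementation and omits its options such as other cost models or a minimum segment length. Like the ruptures `predict` output, `pelt_l2` (a hypothetical name) returns segment end indices whose last element is the signal length.

```python
def pelt_l2(signal, penalty):
    """Penalized changepoint detection (PELT) with an L2 cost.

    Minimizes sum(segment costs) + penalty * number_of_changepoints and
    returns the segment end indices; the last entry is len(signal).
    """
    n = len(signal)
    # Prefix sums make every segment cost an O(1) computation.
    s1, s2 = [0.0], [0.0]
    for v in signal:
        s1.append(s1[-1] + v)
        s2.append(s2[-1] + v * v)

    def cost(a, b):
        # Sum of squared deviations from the mean of signal[a:b].
        total = s1[b] - s1[a]
        return (s2[b] - s2[a]) - total * total / (b - a)

    best = [0.0] * (n + 1)  # best[t]: optimal penalized cost of signal[:t]
    best[0] = -penalty      # k segments then pay (k - 1) penalties
    last = [0] * (n + 1)    # last[t]: start of the final segment ending at t
    candidates = [0]
    for t in range(1, n + 1):
        best[t], last[t] = min(
            (best[s] + cost(s, t) + penalty, s) for s in candidates
        )
        # Pruning: a start that cannot beat best[t] now never will (L2 cost).
        candidates = [s for s in candidates if best[s] + cost(s, t) <= best[t]]
        candidates.append(t)

    # Walk back through the stored segment starts.
    bkps, t = [], n
    while t > 0:
        bkps.append(t)
        t = last[t]
    return sorted(bkps)


if __name__ == "__main__":
    curve = [0.0] * 5 + [10.0] * 5  # a partial dependence curve with one jump
    print(pelt_l2(curve, penalty=1.0))     # [5, 10]
    print(pelt_l2(curve, penalty=1000.0))  # [10]
```

With a small penalty the jump is reported as a changepoint at index 5; with a very large penalty no split pays for itself. In a SAFE-like setting, the breakpoint indices would still need to be mapped back to feature values on the partial dependence grid.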

The algorithm works for both regression and classification problems. For regression, we use the surrogate model's response directly to build partial dependence plots and to perform hierarchical clustering. For classification, we use the predicted probabilities of each class.
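For a categorical feature, each level can be summarized by the surrogate's response when every row is set to that level: a single number for regression, or one predicted probability per class for classification. The sketch below groups levels with agglomerative (hierarchical) clustering. Single linkage, Euclidean distance and a fixed distance threshold as the stopping rule are assumptions chosen for illustration; the library may use a different linkage or select the number of clusters differently. `cluster_levels` is a hypothetical helper.

```python
from math import dist


def cluster_levels(responses, threshold):
    """Group categorical levels whose surrogate responses are close.

    responses: dict mapping level -> response vector (a 1-element list for
    regression, one probability per class for classification).
    Single-linkage agglomerative clustering with Euclidean distance; merging
    stops once the closest pair of clusters is farther apart than `threshold`.
    """
    clusters = [[level] for level in sorted(responses)]

    def linkage(a, b):
        # Single linkage: distance between the closest members of a and b.
        return min(dist(responses[x], responses[y]) for x in a for y in b)

    while len(clusters) > 1:
        d, i, j = min(
            (linkage(clusters[p], clusters[q]), p, q)
            for p in range(len(clusters))
            for q in range(p + 1, len(clusters))
        )
        if d > threshold:
            break
        clusters[i] = clusters[i] + clusters[j]
        del clusters[j]  # j > i, so index i is unaffected
    return clusters


if __name__ == "__main__":
    probs = {"a": [0.9, 0.1], "b": [0.85, 0.15], "c": [0.2, 0.8], "d": [0.25, 0.75]}
    print(cluster_levels(probs, threshold=0.2))  # [['a', 'b'], ['c', 'd']]
```

Levels in the same group would then share one category in the transformed data.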

## Model optimization

One of the parameters you can specify is the penalty, which controls the number of changepoints that are created: a higher penalty yields fewer changepoints. The plot below shows how the quality of the model changes with the penalty; results of the surrogate model and the basic model are included for reference.
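For intuition about the penalty, the penalized problem that Pelt solves can be written as follows, assuming the penalty is passed through to the changepoint search as $\beta$. Here $y_1, \dots, y_n$ are the partial dependence values on the grid, $\mathcal{C}$ is the segment cost, $y_{a:b}$ denotes the values between positions $a$ and $b$, and $K$ is the number of changepoints:

$$
\min_{K,\; 0 = \tau_0 < \tau_1 < \dots < \tau_K < \tau_{K+1} = n} \;\; \sum_{k=0}^{K} \mathcal{C}\left(y_{\tau_k:\tau_{k+1}}\right) + \beta K
$$

In an optimal segmentation, every changepoint $\tau$ separating neighbouring segments $y_{a:\tau}$ and $y_{\tau:b}$ must satisfy

$$
\mathcal{C}\left(y_{a:b}\right) - \mathcal{C}\left(y_{a:\tau}\right) - \mathcal{C}\left(y_{\tau:b}\right) \ge \beta,
$$

since otherwise removing it would lower the objective. Increasing $\beta$ therefore keeps only the changepoints where the surrogate's response shifts strongly, which produces fewer and coarser categories.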

<img src="images/pens.png" alt="Model performance" width="500"/>

With a well-chosen penalty, your simple model can achieve much better accuracy, close to that of the surrogate model.




Download files


Source Distribution

safe-transformer-0.0.1.tar.gz (37.1 kB)

Uploaded Source

Built Distribution

safe_transformer-0.0.1-py3-none-any.whl (7.6 kB)

Uploaded Python 3

File details

Details for the file safe-transformer-0.0.1.tar.gz.

File metadata

  • Download URL: safe-transformer-0.0.1.tar.gz
  • Upload date:
  • Size: 37.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.9.1 setuptools/39.1.0 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.5.2

File hashes

Hashes for safe-transformer-0.0.1.tar.gz
| Algorithm | Hash digest |
| - | - |
| SHA256 | a40e9166314172a6182c9ee21efeb19a51605aff0c996eda1f7c2394e173e832 |
| MD5 | 4796d2ad1a48937351a126840d4cac14 |
| BLAKE2b-256 | 05a173520215df36fa5e6fc406b86cd886e139bacbdf0e48df0606cd7bd0eeaf |


File details

Details for the file safe_transformer-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: safe_transformer-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 7.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.9.1 setuptools/39.1.0 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.5.2

File hashes

Hashes for safe_transformer-0.0.1-py3-none-any.whl
| Algorithm | Hash digest |
| - | - |
| SHA256 | 4dd1bb3781564adf395724192719c1d13eaaa93a0d385b33ff7180e93bbd5846 |
| MD5 | 5f830741f0fb237840a427157797cf80 |
| BLAKE2b-256 | 1f1df0ac676787363cf36f977ac8b284b4dac562738f82f08461198c1e02c021 |

