
Bayesian target encoding with scikit-learn and scipy

Project description


Overview

This package is a lightweight implementation of Bayesian target encoding. The implementation follows Slakey et al., with the ensemble methodology taken from Larionov.

The encoding proceeds as follows:

  1. Observe the target variable and choose a likelihood for it (e.g. Bernoulli for a binary classification problem).
  2. Using Fink's Compendium of Priors, derive the conjugate prior for the likelihood (e.g. Beta).
  3. Use the training data to initialize the hyperparameters of the prior distribution.
    • NOTE: This process generally relies on common interpretations of the hyperparameters.
  4. Using Fink's Compendium, derive how to compute the posterior distribution from the prior and the observed data.
  5. For each level of the categorical variable:
    1. Generate the posterior distribution from the observed target values for that level.
    2. Set the encoding value to a sample from the posterior distribution.
      • If a new level appears in the dataset, its encoding is sampled from the prior distribution. To disable this behaviour, initialize the encoder with handle_unknown="error".
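For the Bernoulli likelihood and its Beta conjugate prior, the steps above can be sketched with NumPy alone. This is an illustrative sketch, not bayte's code: the way the prior is initialized (prior mean equal to the global positive rate, scaled by a hypothetical `prior_strength` pseudo-count) is an assumption, and the function names are hypothetical. The conjugate update itself is standard: add the number of positives to alpha and the number of negatives to beta.

```python
import numpy as np


def fit_beta_prior(y, prior_strength=1.0):
    """Step 3: initialize Beta(alpha0, beta0) from the training target.

    Assumption: the prior mean is the global positive rate and the
    hypothetical ``prior_strength`` acts as a pseudo-count.
    """
    p = np.mean(y)
    return prior_strength * p, prior_strength * (1.0 - p)


def beta_posteriors(levels, y, alpha0, beta0):
    """Step 5.1: conjugate update per level (successes -> alpha, failures -> beta)."""
    levels = np.asarray(levels)
    y = np.asarray(y)
    posteriors = {}
    for level in np.unique(levels):
        y_level = y[levels == level]
        successes = y_level.sum()
        failures = y_level.size - successes
        posteriors[level] = (alpha0 + successes, beta0 + failures)
    return posteriors


def sample_encoding(levels, posteriors, prior, rng):
    """Step 5.2: one posterior draw per level; unseen levels draw from the prior."""
    draws = {level: rng.beta(a, b) for level, (a, b) in posteriors.items()}
    out = np.empty(len(levels))
    for i, level in enumerate(levels):
        if level not in draws:
            draws[level] = rng.beta(*prior)
        out[i] = draws[level]
    return out


# Hypothetical toy data: one categorical column and a binary target.
rng = np.random.default_rng(0)
colors = np.array(["red", "green", "blue", "red", "green", "red"])
y = np.array([1, 0, 1, 1, 0, 0])

prior = fit_beta_prior(y, prior_strength=2.0)  # (1.0, 1.0)
post = beta_posteriors(colors, y, *prior)      # red -> Beta(3, 2)
encoded = sample_encoding(colors, post, prior, rng)
```

Every row that shares a level receives the same draw, so the encoding stays a per-level mapping rather than per-row noise.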

Then, step 5.2 is repeated n_estimators times, generating n_estimators training datasets, each with its own sampled encodings. An estimator is fit to each sampled dataset, and the final model is a vote across these estimators.
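The ensemble idea can be sketched as follows. This is a NumPy-only illustration under stated assumptions, not bayte's `BayesianTargetClassifier`: `NearestMeanClassifier` is a hypothetical stand-in for a scikit-learn estimator, the Beta(1, 1) prior is assumed, prediction-time rows are encoded with the posterior mean, and the vote is a simple hard majority vote for a binary target.

```python
import numpy as np


class NearestMeanClassifier:
    """Hypothetical stand-in for a scikit-learn classifier (fit/predict only)."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.means_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        # Squared distance from every row to every class mean.
        dists = ((X[:, None, :] - self.means_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[dists.argmin(axis=1)]


def posterior_params(levels, y, alpha0=1.0, beta0=1.0):
    """Beta posterior (alpha, beta) for each level, assuming a Beta(1, 1) prior."""
    params = {}
    for level in np.unique(levels):
        y_level = y[levels == level]
        params[level] = (alpha0 + y_level.sum(), beta0 + y_level.size - y_level.sum())
    return params


def fit_ensemble(X_num, levels, y, n_estimators, rng):
    """Repeat step 5.2 n_estimators times and fit one model per sampled dataset."""
    params = posterior_params(levels, y)
    models = []
    for _ in range(n_estimators):
        enc = np.empty(len(levels))
        for level, (a, b) in params.items():
            enc[levels == level] = rng.beta(a, b)  # fresh draw per level
        models.append(NearestMeanClassifier().fit(np.column_stack([X_num, enc]), y))
    return models, params


def predict_vote(models, params, X_num, levels):
    """Encode with posterior means, then take a majority vote (ties -> 1)."""
    enc = np.array([params[level][0] / sum(params[level]) for level in levels])
    X = np.column_stack([X_num, enc])
    preds = np.array([m.predict(X) for m in models])
    return (preds.mean(axis=0) >= 0.5).astype(int)


rng = np.random.default_rng(42)
X_num = rng.normal(size=(200, 2))
levels = rng.choice(["red", "green", "blue"], size=200)
y = (X_num[:, 0] > 0).astype(int)

models, params = fit_ensemble(X_num, levels, y, n_estimators=5, rng=rng)
y_pred = predict_vote(models, params, X_num, levels)
```

Because each member sees a different draw of the encodings, the vote averages over the posterior uncertainty instead of committing to a single encoding.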

Because sampled encodings are random draws, they change from run to run. For reproducibility, you can instead set the encoding value to the mean of the posterior distribution.
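For a Beta(alpha, beta) posterior the mean has the closed form alpha / (alpha + beta), so the mean encoding is deterministic while sampled encodings depend on the random state. The short sketch below illustrates this for a hypothetical level with a Beta(1, 1) prior and 3 positives out of 5 observations; the numbers are made up for illustration.

```python
import numpy as np


def posterior_mean(alpha, beta):
    """Mean of a Beta(alpha, beta) distribution."""
    return alpha / (alpha + beta)


# Hypothetical level: Beta(1, 1) prior, then 3 positives out of 5 rows.
alpha, beta = 1 + 3, 1 + (5 - 3)  # posterior is Beta(4, 3)

# Sampled encodings differ between random states...
draw_a = np.random.default_rng(0).beta(alpha, beta)
draw_b = np.random.default_rng(1).beta(alpha, beta)

# ...while the posterior mean is the same every time: 4 / 7, about 0.571.
mean_encoding = posterior_mean(alpha, beta)

# Averaging many draws recovers the mean.
draws = np.random.default_rng(2).beta(alpha, beta, size=100_000)
```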

Installation

Install from PyPI:

python -m pip install bayte

Usage

Encoding

Let's create a binary classification dataset.

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=5, n_informative=2)
X = pd.DataFrame(X)

# Categorical data
X[5] = np.random.choice(["red", "green", "blue"], size=1000)

Import and fit the encoder:

import bayte as bt

encoder = bt.BayesianTargetEncoder(dist="bernoulli")
encoder.fit(X[[5]], y)

To encode your categorical data:

X[5] = encoder.transform(X[[5]])

Ensemble

To use the ensemble methodology described above, construct the same dataset

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=5, n_informative=2)
X = pd.DataFrame(X)

# Categorical data
X[5] = np.random.choice(["red", "green", "blue"], size=1000)

and import a classifier to supply to the ensemble class:

from sklearn.svm import SVC

import bayte as bt

ensemble = bt.BayesianTargetClassifier(
    base_estimator=SVC(kernel="linear"),
    encoder=bt.BayesianTargetEncoder(dist="bernoulli")
)

Fit the ensemble. NOTE: either supply an explicit list of categorical features to categorical_feature, or use a DataFrame with categorical data types.

ensemble.fit(X, y, categorical_feature=[5])
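As an alternative to passing categorical_feature, the note on fit says you can use a DataFrame with categorical data types. The pandas sketch below marks a column as categorical and shows how such columns can be found from the dtypes. The detection step is only an illustration of the idea; it is not necessarily how bayte identifies categorical columns internally.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(6, 2)))
X[2] = rng.choice(["red", "green", "blue"], size=6)

# Store the column with the pandas "category" dtype so that the dtype,
# rather than an explicit list of columns, identifies it as categorical.
X[2] = X[2].astype("category")

# Illustrative detection of categorical columns from the dtypes.
categorical_columns = list(X.select_dtypes(include="category").columns)
```

With the column stored this way, the note implies that `ensemble.fit(X, y)` can be called without `categorical_feature`.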

When you call predict on new data, the encoder transforms it at runtime, encoding each level with the mean of its posterior distribution:

ensemble.predict(X)
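The runtime transformation can be pictured as a lookup from each level to its posterior mean. The pandas sketch below assumes Beta posteriors stored as (alpha, beta) pairs in a hypothetical `posteriors` dict. It also assumes that levels unseen during fit fall back to the prior mean; that fallback is a choice made for this sketch, not documented bayte behaviour.

```python
import pandas as pd


def posterior_mean_transform(column, posteriors, prior):
    """Encode a categorical Series with each level's Beta posterior mean.

    ``posteriors`` maps level -> (alpha, beta); ``prior`` is (alpha0, beta0).
    Unseen levels fall back to the prior mean (an assumption of this sketch).
    """
    means = {level: a / (a + b) for level, (a, b) in posteriors.items()}
    prior_mean = prior[0] / (prior[0] + prior[1])
    # Series.map leaves NaN for levels missing from the dict.
    return column.map(means).fillna(prior_mean)


# Hypothetical fitted posteriors and a Beta(1, 1) prior.
posteriors = {"red": (4.0, 3.0), "green": (1.0, 3.0), "blue": (2.0, 1.0)}
prior = (1.0, 1.0)

new_data = pd.Series(["red", "blue", "purple", "green"])
encoded = posterior_mean_transform(new_data, posteriors, prior)
```

Here "purple" was never seen during fit, so it receives the prior mean of 0.5.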



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution


bayte-0.2.1-py3-none-any.whl (13.7 kB)

Uploaded Python 3

File details

Details for the file bayte-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: bayte-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 13.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for bayte-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 adad03fdcedf1b1001682c20c8af1f897e916018fd24eeb7418f5b8434557982
MD5 5e4e55f978f4e31a3d15746b874c935d
BLAKE2b-256 af4849f8d137d54f653bac3de0a44dd32f2ccbe49e4fca5e1e8d40da57e3a678

