
Sklearn-compatible model instance labelling tool to help validate models in situations involving data drift.

Project description

adversarial_labeller


Adversarial labeller is a scikit-learn-compatible labeller that scores each instance by how likely it is to belong to the test dataset, which helps with model selection under data drift. Adversarial labeller is distributed under the MIT license.
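The general idea behind scoring instances as "test or not" can be sketched with plain scikit-learn: label training rows 0 and test rows 1, fit a classifier to tell them apart, and read its cross-validated ROC AUC. An AUC near 0.5 suggests the two sets are hard to distinguish, while an AUC well above 0.5 points to drift. This is a minimal illustration of the adversarial-validation technique under stated assumptions, not the package's internal implementation; the helper `drift_auc`, the random forest settings, and the synthetic Gaussian data are all hypothetical choices for the example.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score


def drift_auc(X_train, X_test, random_state=0):
    """Hypothetical helper: mean cross-validated ROC AUC of a train-vs-test classifier."""
    X = np.vstack([X_train, X_test])
    # 0 = training instance, 1 = test instance
    y = np.concatenate([np.zeros(len(X_train)), np.ones(len(X_test))])
    clf = RandomForestClassifier(
        n_estimators=100,
        min_samples_leaf=20,  # larger leaves give smoother probability estimates
        random_state=random_state,
    )
    return cross_val_score(clf, X, y, cv=5, scoring="roc_auc").mean()


rng = np.random.default_rng(0)
X_train = rng.normal(0.0, 1.0, size=(700, 1))
X_same = rng.normal(0.0, 1.0, size=(300, 1))     # drawn from the training distribution
X_shifted = rng.normal(1.0, 1.0, size=(300, 1))  # mean shifted by one standard deviation

auc_same = drift_auc(X_train, X_same)
auc_shift = drift_auc(X_train, X_shifted)
print(f"no drift: AUC ~ {auc_same:.2f}")
print(f"drift:    AUC ~ {auc_shift:.2f}")
```

The same fitted train-vs-test classifier can then supply per-instance probabilities of "looking like test data", which is the kind of score an adversarial labeller works with.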

Installation

Dependencies

Adversarial labeller requires:

  • Python (>= 3.7)
  • scikit-learn (>= 0.21.0)
  • imbalanced-learn (>= 0.5.0)
  • pandas (>= 0.25.0)

User installation

The easiest way to install adversarial labeller is using pip:

pip install adversarial_labeller

Example Usage

import numpy as np
import pandas as pd
from sklearn.datasets import make_blobs
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from adversarial_labeller import AdversarialLabelerFactory, Scorer

scoring_metric = accuracy_score

# Our blob data generation parameters for this example
number_of_samples = 1000
number_of_test_samples = 300

# Generate 1d blob data and label a portion as test data
# ... 1d blob data can be visualized as a rug plot
variables, labels = make_blobs(
    n_samples=number_of_samples,
    centers=2,
    n_features=1,
    random_state=0
)

df = pd.DataFrame(
    {
        'independent_variable': variables.flatten(),
        'dependent_variable': labels,
        'label': 0  # default to train data
    }
)
test_indices = df.index[-number_of_test_samples:]
train_indices = df.index[:-number_of_test_samples]

df.loc[test_indices,'label'] = 1  # ... now we mark instances that are test data

# Now perturb the test samples to simulate data drift/different test distribution
df.loc[test_indices, "independent_variable"] += \
    np.std(df.independent_variable)

# ... now we have an example of data drift where adversarial labeling can be used to better estimate the actual test accuracy

features_for_labeller = df.independent_variable
labels_for_labeller = df.label

pipeline, flip_binary_predictions =\
    AdversarialLabelerFactory(
        features = features_for_labeller,
        labels = labels_for_labeller,
        run_pipeline = False
    ).fit_with_best_params()

scorer = Scorer(the_scorer=pipeline,
                flip_binary_predictions=flip_binary_predictions)

# Now we evaluate a classifier on training data only, but using
# our fancy adversarial labeller
# ... sklearn wants 2-D feature arrays, hence reshape(-1, 1)
_X = df.loc[train_indices]\
    .independent_variable\
    .values\
    .reshape(-1, 1)

_X_test = df.loc[test_indices]\
    .independent_variable\
    .values\
    .reshape(-1, 1)

clf_adver = RandomForestClassifier(n_estimators=100, random_state=1)
adversarial_scores =\
    cross_val_score(
        X=_X,
        y=df.loc[train_indices].dependent_variable,
        estimator=clf_adver,
        scoring=scorer.grade,
        cv=10,
        n_jobs=-1,
        verbose=1)
# ... and we get ~0.68 - 0.70
average_adversarial_score =\
    np.array(adversarial_scores).mean()

# ... let's see how this compares with normal cross validation
clf = RandomForestClassifier(n_estimators=100, random_state=1)
scores =\
    cross_val_score(
        X=_X,
        y=df.loc[train_indices].dependent_variable,
        estimator=clf,
        cv=10,
        n_jobs=-1,
        verbose=1)

# ... and we get ~ 0.92
average_score =\
    np.array(scores).mean()

# now let's see how this compares with the actual test score
clf_all = RandomForestClassifier(n_estimators=100, random_state=1)
clf_all.fit(_X,
            df.loc[train_indices].dependent_variable)

# ... actual test score is 0.70
actual_score = accuracy_score(
    df.loc[test_indices].dependent_variable,  # y_true
    clf_all.predict(_X_test)                  # y_pred
)

adversarial_result = abs(average_adversarial_score - actual_score)
print(f"... adversarial labelled cross validation was {adversarial_result:.2f} points different than actual.")  # ... 0.00 - 0.02 points

cross_val_result = abs(average_score - actual_score)
print(f"... regular validation was {cross_val_result:.2f} points different than actual.")  # ... 0.23 points

# See tests/ for additional examples, including runs against the Titanic dataset and stock market trading data
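To picture what an adversarially weighted cross-validation score can look like, note that `cross_val_score` accepts any callable with the signature `scorer(estimator, X, y)`, so a scorer can weight each validation instance by how test-like an adversarial classifier considers it. The sketch below reuses the blob setup from the example but swaps in plain scikit-learn pieces: `make_test_weighted_scorer` is a hypothetical helper, the logistic-regression adversary is an assumption, and this is not how `Scorer.grade` is implemented inside adversarial_labeller.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score


def make_test_weighted_scorer(adversary):
    """Hypothetical helper: accuracy weighted by the adversary's P(test | x)."""
    def scorer(estimator, X, y):
        weights = adversary.predict_proba(X)[:, 1]  # how test-like each row looks
        return accuracy_score(y, estimator.predict(X), sample_weight=weights)
    return scorer


# Same data as the example above: 1-d blobs, last 300 rows shifted by one std
X, y = make_blobs(n_samples=1000, centers=2, n_features=1, random_state=0)
X_train, y_train = X[:-300], y[:-300]
X_test, y_test = X[-300:] + X.std(), y[-300:]

# Adversary learns to tell training rows (0) from test rows (1)
adversary = LogisticRegression()
adversary.fit(np.vstack([X_train, X_test]),
              np.r_[np.zeros(len(X_train)), np.ones(len(X_test))])

clf = RandomForestClassifier(n_estimators=100, random_state=1)
plain = cross_val_score(clf, X_train, y_train, cv=10).mean()
weighted = cross_val_score(clf, X_train, y_train, cv=10,
                           scoring=make_test_weighted_scorer(adversary)).mean()
actual = accuracy_score(y_test, clf.fit(X_train, y_train).predict(X_test))

print(f"plain CV: {plain:.2f}  weighted CV: {weighted:.2f}  actual test: {actual:.2f}")
```

Weighting only re-emphasizes training rows that resemble the test set; it cannot by itself recreate errors caused by points that moved across the decision boundary, so how closely such a score tracks the true test accuracy depends on the kind of drift.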



Download files


Source Distribution

adversarial labeller-0.1.8.tar.gz (6.5 kB)

Uploaded Source

Built Distribution

adversarial_labeller-0.1.8-py3-none-any.whl (7.9 kB)

Uploaded Python 3

File details

Details for the file adversarial labeller-0.1.8.tar.gz.

File metadata

  • Download URL: adversarial labeller-0.1.8.tar.gz
  • Upload date:
  • Size: 6.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.41.0 CPython/3.7.2

File hashes

Hashes for adversarial labeller-0.1.8.tar.gz
Algorithm Hash digest
SHA256 5cdc9b594cf38e6e21eb4acc8d77dda6ef4882672ff790e67e3965e2319c76d7
MD5 ef04d8e76e952ea4c876a08b0e3c642e
BLAKE2b-256 5e260677577be4d9a6d9907cf32dc1c850ae15e4bb8e41bcf6583aaec14ade1e


File details

Details for the file adversarial_labeller-0.1.8-py3-none-any.whl.

File metadata

  • Download URL: adversarial_labeller-0.1.8-py3-none-any.whl
  • Upload date:
  • Size: 7.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.41.0 CPython/3.7.2

File hashes

Hashes for adversarial_labeller-0.1.8-py3-none-any.whl
Algorithm Hash digest
SHA256 281ff665aa6bbda4843ba734487860cdd48aa7bd49c0d15db40551d2167b0f54
MD5 0aa802592a8bdd4e55a3e468d17d348b
BLAKE2b-256 a629df7a3ec3812142575459f90a3278167757e3852f16ebdda4a3ab6864de63

