A metadata-saving proxy for scikit-learn estimators.

Project description

destimator makes it easy to store trained scikit-learn estimators together with their metadata (training data, package versions, performance numbers, etc.). This makes it much safer to store already-trained classifiers/regressors and allows for better reproducibility (see this talk by Alex Gaynor for some rationale).

Specifically, the DescribedEstimator class proxies most calls to the original Estimator it is wrapping, but also contains the following information:

  • training and test (validation) data (features_train, labels_train, features_test, labels_test)

  • creation date (created_at)

  • feature names (feature_names)

  • performance numbers on the test set (precision, recall, fscore, support, roc_auc; computed via scikit-learn)

  • distribution info (distribution_info; python distribution and versions of all installed packages)

  • VCS hash (vcs_hash, if used inside a git repository, otherwise an empty string)

An instantiated DescribedEstimator can be easily serialized using the .save() method and deserialized using either .from_file() or .from_url(). Did you ever want to store your models in S3? Now it’s easy!

DescribedEstimator can be used as follows:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split  # sklearn.model_selection on scikit-learn >= 0.18
from sklearn.metrics import precision_recall_fscore_support

from destimator import DescribedEstimator


# get some data
iris = load_iris()
features = iris.data
labels = iris.target
features_train, features_test, labels_train, labels_test = train_test_split(features, labels, test_size=0.1)

# train an estimator as usual (in this case a RandomForestClassifier)
clf = RandomForestClassifier(n_estimators=10, max_depth=None, min_samples_split=10, random_state=0)
clf.fit(features_train, labels_train)

# wrap the estimator in the DescribedEstimator class giving it all the training and test (validation) data
dclf = DescribedEstimator(
    clf,
    features_train,
    labels_train,
    features_test,
    labels_test,
    iris.feature_names
)

Now you can use the classifier as usual:

print(dclf.predict(features_test))
> [2 1 2 2 0 1 0 2 2 1 2 0 2 1 2]

and you can also access a bunch of other properties, such as the feature names and the test (validation) data you supplied:

print(dclf.feature_names)
> ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

print(dclf.features_test)
> [[ 6.3  2.8  5.1  1.5]
   [ 5.6  3.   4.5  1.5]
   [ 6.7  3.1  5.6  2.4]
   [ 6.   2.7  5.1  1.6]
   [ 4.9  3.1  1.5  0.1]
   [ 6.2  2.2  4.5  1.5]
   [ 4.7  3.2  1.6  0.2]
   [ 6.9  3.1  5.1  2.3]
   [ 7.7  2.6  6.9  2.3]
   [ 5.8  2.6  4.   1.2]
   [ 7.2  3.   5.8  1.6]
   [ 5.4  3.7  1.5  0.2]
   [ 7.2  3.2  6.   1.8]
   [ 6.3  3.3  4.7  1.6]
   [ 6.8  3.2  5.9  2.3]]

print(dclf.labels_test)
> [2 1 2 1 0 1 0 2 2 1 2 0 2 1 2]

the performance numbers:

print('precision: %s' % (dclf.precision))
> precision: [1.0, 1.0, 0.875]

print('recall:    %s' % (dclf.recall))
> recall:    [1.0, 0.8, 1.0]

print('fscore:    %s' % (dclf.fscore))
> fscore:    [1.0, 0.888888888888889, 0.9333333333333333]

print('support:   %s' % (dclf.support))
> support:   [3, 5, 7]

print('roc_auc:   %s' % (dclf.roc_auc))
> roc_auc:   0.5
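
These numbers come from scikit-learn, so they should agree with what precision_recall_fscore_support (already imported in the example above) reports when recomputed from the stored test data; a quick sanity check:

# Recompute the metrics from the stored test data; the per-class arrays
# should match the precision/recall/fscore/support properties shown above.
predictions = dclf.predict(dclf.features_test)
precision, recall, fscore, support = precision_recall_fscore_support(dclf.labels_test, predictions)
print(precision, recall, fscore, support)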

or information about the Python distribution used for training:

from pprint import pprint
pprint(dclf.distribution_info)

> {'packages': ['appnope==0.1.0',
                'decorator==4.0.4',
                'destimator==0.0.0.dev3',
                'gnureadline==6.3.3',
                'ipykernel==4.2.1',
                'ipython-genutils==0.1.0',
                'ipython==4.0.1',
                'ipywidgets==4.1.1',
                'jinja2==2.8',
                'jsonschema==2.5.1',
                'jupyter-client==4.1.1',
                'jupyter-console==4.0.3',
                'jupyter-core==4.0.6',
                'jupyter==1.0.0',
                'markupsafe==0.23',
                'mistune==0.7.1',
                'nbconvert==4.1.0',
                'nbformat==4.0.1',
                'notebook==4.0.6',
                'numpy==1.10.1',
                'path.py==8.1.2',
                'pexpect==4.0.1',
                'pickleshare==0.5',
                'pip==7.1.2',
                'ptyprocess==0.5',
                'pygments==2.0.2',
                'pyzmq==15.1.0',
                'qtconsole==4.1.1',
                'requests==2.8.1',
                'scikit-learn==0.17',
                'scipy==0.16.1',
                'setuptools==18.2',
                'simplegeneric==0.8.1',
                'terminado==0.5',
                'tornado==4.3',
                'traitlets==4.0.0',
                'wheel==0.24.0'],
   'python': '3.5.0 (default, Sep 14 2015, 02:37:27) \n'
             '[GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)]'}
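
The remaining metadata properties from the list above, such as the creation date and the VCS hash, are read the same way (the output depends on when and where the model was trained):

print(dclf.created_at)
print(dclf.vcs_hash)  # empty string if not trained inside a git repository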

Finally, the object can be serialized to a zip file containing all the above data:

dclf.save('./classifiers', 'dclf')

and deserialized either from a file,

dclf = DescribedEstimator.from_file('./classifiers/dclf.zip')

or from a URL:

dclf = DescribedEstimator.from_url('http://localhost/dclf.zip')
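
For the S3 use case mentioned above, one rough sketch is to upload the saved zip with boto3 and hand .from_url() a presigned link; the bucket and key names here are made up for illustration:

import boto3

s3 = boto3.client('s3')
# Upload the archive written by .save() (hypothetical bucket/key names).
s3.upload_file('./classifiers/dclf.zip', 'my-models-bucket', 'classifiers/dclf.zip')

# Generate a temporary download link and load the estimator straight from it.
url = s3.generate_presigned_url(
    'get_object',
    Params={'Bucket': 'my-models-bucket', 'Key': 'classifiers/dclf.zip'},
    ExpiresIn=3600,
)
dclf = DescribedEstimator.from_url(url)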

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

destimator-0.1.4.tar.gz (10.0 kB, Source)

Built Distribution

destimator-0.1.4-py2.py3-none-any.whl (12.3 kB, Python 2 / Python 3)

File details

Details for the file destimator-0.1.4.tar.gz.

File metadata

  • Download URL: destimator-0.1.4.tar.gz
  • Upload date:
  • Size: 10.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No

File hashes

Hashes for destimator-0.1.4.tar.gz:

  • SHA256: 501f495c92563c496eb98740acbc20fe5aadfdc82451dfa47064f3b0a8826ca6
  • MD5: 9fe366d05461f9f279119351219468e3
  • BLAKE2b-256: 3f3fe000d7b54eeb61d3acb299cf6ed6b7fa3ab71f96a4e0880842f9832eaf24


File details

Details for the file destimator-0.1.4-py2.py3-none-any.whl.

File hashes

Hashes for destimator-0.1.4-py2.py3-none-any.whl:

  • SHA256: 0f0cf7f9e29f528985ec229dd66a3ab1fd061b12b60d022eae9c4afccb12f7d2
  • MD5: da6e0f07259fbc72882d5cb55c6d9a58
  • BLAKE2b-256: f273bcda7d41a4148270528e95da820126ce7eca160a4ed930b28bf47a20af96

