Machine Learning Tool Box

This is the Machine Learning Tool Box, a collection of useful machine learning tools intended for reuse and extension. The toolbox contains the following modules:

  • hyperopt - tools to save and restart Hyperopt evaluations
  • keras - a Keras (tf.keras) callback for additional metrics, plus other Keras tools
  • lightgbm - custom metric functions for LightGBM
  • metrics - several metric implementations
  • plot - plotting and visualisation tools
  • tools - various tools (statistical tools, among others)

Module: hyperopt

This module contains a tool function to save and restart Hyperopt evaluations. This is done by saving and loading the hyperopt.Trials object. The usage looks like this:

from mltb.hyperopt import fmin
from hyperopt import tpe, hp, STATUS_OK


def objective(x):
    return {
        'loss': x ** 2,
        'status': STATUS_OK,
        'other_stuff': {'type': None, 'value': [0, 1, 2]},
        }


best, trials = fmin(objective,
    space=hp.uniform('x', -10, 10),
    algo=tpe.suggest,
    max_evals=100,
    filename='trials_file')

print('best:', best)
print('number of trials:', len(trials.trials))

Output of first run:

No trials file "trials_file" found. Created new trials object.
100%|██████████| 100/100 [00:00<00:00, 338.61it/s, best loss: 0.0007185087453453681]
best: {'x': 0.026805013436769026}
number of trials: 100

Output of second run:

100 evals loaded from trials file "trials_file".
100%|██████████| 100/100 [00:00<00:00, 219.65it/s, best loss: 0.00012259809712488858]
best: {'x': 0.011072402500130158}
number of trials: 200
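
Under the hood, this works by persisting the hyperopt.Trials object between runs. The following is a minimal sketch of the idea, assuming pickle-based storage; the actual mltb.hyperopt.fmin implementation may differ in its details:

import os
import pickle

import hyperopt


def fmin(objective, space, algo, max_evals, filename):
    # Load a previously saved Trials object if the file exists,
    # otherwise start with a fresh one.
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            trials = pickle.load(f)
        print(f'{len(trials.trials)} evals loaded from trials file "{filename}".')
    else:
        trials = hyperopt.Trials()
        print(f'No trials file "{filename}" found. Created new trials object.')

    # Run max_evals additional evaluations on top of the stored ones.
    best = hyperopt.fmin(objective, space=space, algo=algo,
                         max_evals=len(trials.trials) + max_evals,
                         trials=trials)

    # Save the updated Trials object for the next run.
    with open(filename, 'wb') as f:
        pickle.dump(trials, f)

    return best, trials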

Module: lightgbm

This module implements metric functions that are not included in LightGBM. At the moment these are the F1 score and accuracy score for binary and multi-class problems. The usage looks like this:

import lightgbm as lgb
import mltb.lightgbm

bst = lgb.train(param,
                train_data,
                valid_sets=[validation_data],
                early_stopping_rounds=10,
                evals_result=evals_result,
                feval=mltb.lightgbm.multi_class_f1_score_factory(num_classes, 'macro'),
               )
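
For reference, a LightGBM feval callable returns a tuple of (eval_name, eval_result, is_higher_better). A minimal sketch of what such a factory could produce, using a hypothetical helper (this is not mltb's actual implementation):

import numpy as np
import sklearn.metrics


def multi_class_f1_feval(num_classes, average='macro'):
    # Build a feval callable that computes the F1 score on a LightGBM Dataset.
    def feval(preds, eval_data):
        y_true = eval_data.get_label()
        # Depending on the LightGBM version, multi-class predictions arrive
        # either as a 2D (n_samples, n_classes) array or flattened class-major.
        if preds.ndim == 1:
            preds = preds.reshape(num_classes, -1).T
        y_pred = np.argmax(preds, axis=1)
        score = sklearn.metrics.f1_score(y_true, y_pred, average=average)
        # LightGBM expects (eval_name, eval_result, is_higher_better).
        return f'f1_{average}', score, True
    return feval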

Module: keras (for tf.keras)

BinaryClassifierMetricsCallback

This module provides custom metrics in the form of a callback. Because the callback adds these values to the internal logs dictionary, it is possible to use the EarlyStopping callback to do early stopping on these metrics.
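
A simplified sketch of this mechanism, i.e. a callback that evaluates a metric on the validation data in on_epoch_end and writes it into logs (hypothetical code, not mltb's implementation):

import sklearn.metrics
import tensorflow as tf


class RocAucCallback(tf.keras.callbacks.Callback):
    def __init__(self, val_data, val_label):
        super().__init__()
        self.val_data = val_data
        self.val_label = val_label

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        y_score = self.model.predict(self.val_data)
        # Writing into logs makes 'val_roc_auc' visible to other callbacks
        # such as EarlyStopping(monitor='val_roc_auc').
        logs['val_roc_auc'] = sklearn.metrics.roc_auc_score(self.val_label, y_score)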

Parameters

  • val_data (list) - validation input
  • val_label (list) - validation output
  • pos_label (Optional[int], default 1) - index of the positive label
  • metrics (List[Union[str, Callable]], default ['val_roc_auc', 'val_average_precision', 'val_f1', 'val_acc']) - list of supported metric names or custom metric functions

Available metrics

  • val_roc_auc: ROC AUC
  • val_f1: F1 score
  • val_acc: Accuracy
  • val_average_precision: Average precision
  • val_mcc: Matthews correlation coefficient

The usage looks like this:

import mltb.keras
from tensorflow.keras import callbacks

bcm_callback = mltb.keras.BinaryClassifierMetricsCallback(val_data, val_labels)
es_callback = callbacks.EarlyStopping(monitor='val_roc_auc', patience=5, mode='max')

history = network.fit(train_data, train_labels,
                      epochs=1000,
                      batch_size=128,

                      # do not pass validation_data here or validation will be done twice
                      # validation_data=(val_data, val_labels),

                      # always list BinaryClassifierMetricsCallback before the EarlyStopping callback
                      callbacks=[bcm_callback, es_callback],
)

You can also define your own custom metric:

import numpy as np
import sklearn.metrics


def custom_average_recall_score(y_true, y_pred, pos_label):
    rounded_pred = np.rint(y_pred)
    return sklearn.metrics.recall_score(y_true, rounded_pred, pos_label=pos_label)


bcm_callback = mltb.keras.BinaryClassifierMetricsCallback(val_data, val_labels, metrics=[custom_average_recall_score])
es_callback = callbacks.EarlyStopping(monitor='custom_average_recall_score', patience=5, mode='max')

history = network.fit(train_data, train_labels,
                      epochs=1000,
                      batch_size=128,

                      # do not pass validation_data here or validation will be done twice
                      # validation_data=(val_data, val_labels),

                      # always list BinaryClassifierMetricsCallback before the EarlyStopping callback
                      callbacks=[bcm_callback, es_callback],
)

