A library to construct concept drift adaptive model services.

musc

musc (model update strategy construction) is a Python library that helps users construct concept drift adaptive model services.

Note: SemVer is not yet followed in 0.1.x versions.

Preliminary

Q: What is concept drift?

A: Concept drift is the phenomenon in which the distribution of the data received by a deployed model changes over time. Machine learning algorithms usually assume that data received online follows the same distribution as the training set, so concept drift degrades model performance.
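
As a toy illustration (not part of musc), the sketch below fits a one-parameter model on data where y = 2x and then scores it on data where the relationship has drifted to y = 3x. The helper names fit_slope and mse are made up for this example.

```python
# Toy illustration of concept drift; none of these names come from musc.

def fit_slope(xs, ys):
    """Least-squares slope w for the model y = w * x (no intercept)."""
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)

def mse(w, xs, ys):
    """Mean squared error of the model y = w * x on the given data."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0]
train_ys = [2.0 * x for x in xs]    # concept seen during training: y = 2x
drifted_ys = [3.0 * x for x in xs]  # concept after drift: y = 3x

w = fit_slope(xs, train_ys)
error_before = mse(w, xs, train_ys)    # model matches the training concept
error_after = mse(w, xs, drifted_ys)   # same model, drifted data
print(w, error_before, error_after)
```

The model itself never changes; only the data does, which is why the error grows until the model is updated.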

Q: What is a model update strategy?

A: In this library, a model update strategy determines when and how the model should be updated to adapt to the new data distribution. It is an essential part of a concept drift adaptive model service. For the "when" problem, the model can be updated periodically or when a concept drift detection algorithm reports a drift. For the "how" problem, the model can be updated by retraining or fine-tuning on the latest labelled data received online.

Usage

musc can be installed with pip install musc. It requires Python 3.11 or above.

Constructing model service

The simplest way to construct a concept drift adaptive model service is to use the musc.run_service_http function:

import musc.high_level as musc
import torch

model = torch.nn.Linear(2, 2)
musc.run_service_http(model, musc.UpdateStrategyBlank())

A model object and a model update strategy object are required. For the model, the currently supported types are PyTorch's torch.nn.Module and scikit-learn's sklearn.base.BaseEstimator. If your model is neither of these, you can implement your own musc.BaseModel subclass and use it. For the model update strategy, a musc.UpdateStrategy(Blank|ByPeriodicallyUpdate|ByDriftDetection) object is expected. A "blank" strategy is used here for simplicity, which means that the model will never be updated. Model update strategies are covered in more detail later.

The constructed model service can be called by cURL:

$ curl -X POST localhost -H 'Content-Type: application/json' -d '{ "x": [0, 1], "id": 0 }'
{"y_pred":[-0.9937736988067627,0.016456782817840576]}
$ curl -X POST localhost -H 'Content-Type: application/json' -d '{ "y": [0, 0], "id": 0 }'
{}

If the service is called with an input value, it returns the prediction. If it is called with a ground truth label, it passes the label to the model update strategy. The "id" argument is required so that the model service can match a label to the corresponding input value.
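
To make the role of "id" concrete, here is a minimal sketch of id-based pairing. It is a hypothetical illustration, not musc's actual implementation: the PairingBuffer class and its behaviour for unknown ids are assumptions.

```python
# Hypothetical sketch of id-based pairing; not musc's actual implementation.

class PairingBuffer:
    def __init__(self):
        self._pending = {}  # id -> input value still waiting for its label
        self.labelled = []  # (x, y) pairs ready for a model update strategy

    def recv_x(self, x, id_):
        """Remember the input so that a later label can be matched to it."""
        self._pending[id_] = x

    def recv_y(self, y, id_):
        """Pair the label with its input; return False if the id is unknown."""
        if id_ not in self._pending:
            return False
        self.labelled.append((self._pending.pop(id_), y))
        return True

buf = PairingBuffer()
buf.recv_x([0, 1], id_=0)
buf.recv_x([1, 0], id_=1)
matched = buf.recv_y([0, 0], id_=0)    # id 0 was seen, so the pair is stored
unmatched = buf.recv_y([1, 1], id_=7)  # id 7 was never seen, so it is ignored
```

The input with id 1 stays pending here, which mirrors the case where a ground truth label never arrives.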

If you wish to embed the constructed model service into another HTTP/gRPC/... service instead of running it as an independent one, you can use the musc.Service class instead of the musc.run_service_http function:

svc = musc.Service(model, musc.UpdateStrategyBlank())

y_pred = svc.recv_x(torch.tensor([0.0, 1.0]), id_=0)  # tensor([0.4906, 0.5624])
_      = svc.recv_y(torch.tensor([0.0, 0.0]), id_=0)

Constructing model update strategy

In addition to the "blank" strategy, an easy-to-understand kind of model update strategy is the periodic update strategy:

def updator(model, x_arr, y_arr, lr=1e-4):
    ...  # In-place retraining or fine-tuning

strategy = musc.UpdateStrategyByPeriodicallyUpdate(period=32, updator=updator)

To construct this kind of strategy, you need to specify the update period as a number of samples, together with a "model updator" object. The updator accepts three arguments: the model object, which should be updated in place, and the x and y arrays of the latest labelled data received by the model service, whose length equals the update period. See examples/svc_basic.py for an example of a model updator.
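
The following is a minimal sketch of what an in-place updator and a periodic trigger could look like. It uses a hypothetical pure-Python stand-in model (TinyLinear) instead of a torch.nn.Module, and the buffering loop is an assumption about the general idea, not musc's internal code.

```python
# Sketch only: TinyLinear and the buffering loop are illustrative assumptions.

class TinyLinear:
    """Hypothetical stand-in for a real model: y = w * x."""
    def __init__(self, w=0.0):
        self.w = w

    def predict(self, x):
        return self.w * x

def updator(model, x_arr, y_arr, lr=1e-4):
    """Fine-tune the model in place: one squared-error gradient step per sample."""
    for x, y in zip(x_arr, y_arr):
        grad = 2.0 * (model.predict(x) - y) * x
        model.w -= lr * grad

model = TinyLinear(w=0.0)
period = 2
buffer_x, buffer_y = [], []
num_updates = 0

for x, y in [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]:
    buffer_x.append(x)
    buffer_y.append(y)
    if len(buffer_x) == period:  # a full period of labelled data is available
        updator(model, buffer_x, buffer_y, lr=0.1)
        buffer_x, buffer_y = [], []
        num_updates += 1
```

With three labelled samples and a period of 2, the updator runs once and one sample stays in the buffer for the next period.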

A more complex kind of model update strategy is the drift detection update strategy:

from river.drift import ADWIN  # A drift detection algorithm

metric = musc.Metric(torch.nn.functional.mse_loss, pred_first=True)

strategy = musc.UpdateStrategyByDriftDetection(
    drift_detector=ADWIN(),
    metric=metric,
    updator=updator,
    data_amount_required=32,
)

To construct this kind of strategy, instead of specifying the update period, you need to specify a concept drift detector from River and a musc.Metric object. The metric computes a value for each pair of prediction and label, and the drift detector consumes these values. You also need to tell the model service how much data a model update needs through the data_amount_required argument, so that after the drift detector reports a drift, the model service can collect an appropriate amount of labelled data for your model updator.
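
The control flow described above can be sketched as follows. This is an illustration under stated assumptions, not musc's implementation: ThresholdDetector is a hypothetical stand-in that only mimics the update / drift_detected interface of River detectors, and whether the sample that triggers the drift is included in the collected batch is a choice made for this example.

```python
# Sketch only: ThresholdDetector and collect_update_batches are hypothetical.

class ThresholdDetector:
    """Stand-in for a River detector such as ADWIN; flags any large metric value."""
    def __init__(self, threshold):
        self.threshold = threshold
        self.drift_detected = False

    def update(self, value):
        self.drift_detected = value > self.threshold

def collect_update_batches(stream, predict, detector, data_amount_required):
    """Return the batches of labelled data that would be handed to the updator."""
    collecting = False
    batch, update_batches = [], []
    for x, y in stream:
        if collecting:
            batch.append((x, y))
            if len(batch) == data_amount_required:
                update_batches.append(batch)  # enough data for one model update
                batch, collecting = [], False
        else:
            detector.update((predict(x) - y) ** 2)  # metric value for this pair
            if detector.drift_detected:
                collecting = True  # start gathering data for the update
    return update_batches

# The model expects y = 2x; the stream drifts to y = 3x after three samples.
stream = [(1, 2), (2, 4), (3, 6), (1, 3), (2, 6), (3, 9), (1, 3)]
batches = collect_update_batches(
    stream, predict=lambda x: 2 * x,
    detector=ThresholdDetector(threshold=0.5), data_amount_required=2,
)
```

A real service would call the updator on each batch, after which the metric values fed to the detector would reflect the updated model.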

Evaluation of model update strategy

To compare model update strategies before constructing a model service, you can evaluate your candidate strategies:

import pandas as pd

test_data = pd.read_csv(...)
x_arr, y_arr, t_x_arr, t_y_arr = test_data[...].to_numpy(), ...

eval_ = musc.Evaluator(model, strategy)
mse, stats = eval_.evaluate(x_arr, y_arr, t_x_arr, t_y_arr, metric)

print(mse)
print(stats.model_update_cost_by_time())
print(stats.model_update_cost_by_num_updates())
print(stats.model_update_cost_by_num_samples())

To perform evaluation, you need to provide test data that can represent what the model service will receive. Test data should contain input values and ground truth labels, along with their arrival timestamps. Evaluation will give you information about model performance and model update cost, allowing you to make a good trade-off between the two. See examples/eval.py for more details about model update strategy evaluation.
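
One plausible reason the arrival timestamps matter is that an evaluation has to replay inputs and labels in the order the service would receive them. The sketch below shows such a replay for a fixed model. It is an assumption about the general technique, not a description of musc.Evaluator; the replay function is hypothetical and assumes every label arrives no earlier than its input.

```python
# Sketch only: replay() is hypothetical and unrelated to musc.Evaluator internals.

def replay(x_arr, y_arr, t_x_arr, t_y_arr, predict):
    """Replay inputs and labels in timestamp order; return (mse, event order)."""
    events = [(t, 0, i) for i, t in enumerate(t_x_arr)]   # kind 0: input arrives
    events += [(t, 1, i) for i, t in enumerate(t_y_arr)]  # kind 1: label arrives
    events.sort()  # on equal timestamps, inputs sort before labels

    preds, errors, order = {}, [], []
    for _, kind, i in events:
        if kind == 0:
            preds[i] = predict(x_arr[i])  # predict when the input arrives
            order.append(('x', i))
        else:
            errors.append((preds[i] - y_arr[i]) ** 2)  # score when the label arrives
            order.append(('y', i))
    return sum(errors) / len(errors), order

score, order = replay(
    x_arr=[1.0, 2.0], y_arr=[2.0, 5.0],
    t_x_arr=[0, 1], t_y_arr=[2, 3],
    predict=lambda x: 2.0 * x,
)
```

In this run both inputs arrive before either label, so any update strategy would have to predict the second sample without having seen the first label.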

Hyperparameter optimization of model update strategy

If you are not familiar with the field of concept drift, you may not know how to choose a suitable model update strategy type (for example, periodic update or drift detection, retraining or fine-tuning) or its parameter values (such as the update period or the fine-tuning learning rate). In that case, you can perform hyperparameter optimization to find model update strategies that work well in the target scenario, without needing to understand how drift detectors work:

search_space = {
    'type': musc.UpdateStrategyByDriftDetection,
    'metric': metric,
    'updator': {
        'base_fn': updator,
        'lr': [1e-4, 1e-3, 1e-2, 1e-1],
    },
    'data_amount_required': [16, 32, 64, 128],
}

if __name__ == '__main__':
    search = musc.UpdateStrategySearch(
        search_space, model, x_arr, y_arr, t_x_arr, t_y_arr,
        metric=[metric, 'time'], optim_mode=['min', 'min'],
        top_k_scores_csv_path='top_k_scores.csv',
        top_k_samples_file_path='top_k_samples.txt',
    )
    search.search(10000)

Note that the above example does not specify a search space for the drift detector. In this case a reasonable "default search space" is used, so that a suitable drift detector can be found without user knowledge. During the search, the evaluated model update strategy samples and their metric values are written to files, sorted by the first metric specified. See examples/hpo.py for more details about model update strategy hyperparameter optimization.
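
To give a feel for how a search space of this shape can be turned into concrete candidates, here is a minimal random-sampling sketch. It is an assumption for illustration only: the sample function is hypothetical, the updator body is a placeholder, and musc's actual search algorithm may work differently.

```python
# Sketch only: sample() is hypothetical and is not musc's search algorithm.
import functools
import random

def updator(model, x_arr, y_arr, lr=1e-4):
    return lr  # placeholder body so that the example is runnable

search_space = {
    'updator': {'base_fn': updator, 'lr': [1e-4, 1e-3, 1e-2, 1e-1]},
    'data_amount_required': [16, 32, 64, 128],
}

def sample(space, rng):
    """Draw one candidate: lists are treated as sets of alternative values."""
    candidate = {}
    for key, value in space.items():
        if isinstance(value, dict):
            # Pick one value per parameter and bind it to the base function.
            params = {k: rng.choice(v) for k, v in value.items() if k != 'base_fn'}
            candidate[key] = functools.partial(value['base_fn'], **params)
        elif isinstance(value, list):
            candidate[key] = rng.choice(value)
        else:
            candidate[key] = value  # fixed setting, used as given
    return candidate

rng = random.Random(0)
candidate = sample(search_space, rng)
chosen_lr = candidate['updator'].keywords['lr']
```

Each sampled candidate could then be scored with an evaluation run, and the best candidates kept according to the chosen metrics.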

Highlights

  • The constructed model service is robust against abnormal cases such as missing ground truth labels.

  • The evaluation is more accurate than existing approaches.

  • The hyperparameter optimization supports utilizing multiple GPUs.

Acknowledgement

This project is supported by the National Key R&D Program of China (2021YFB1715200).
