A simple tool to embed scikit-learn models into microcontrollers
Machine Learning for Embedded Devices
sklearn2c is a tool that converts trained scikit-learn models (classification, regression, and clustering) to C code. The generated code can run on microcontrollers or other embedded systems, for real-time tasks where computational resources are limited.
Supported Models
Classification
- Bayes Classifier*
- Decision Trees
- KNN Classifier
- C-SVC**
*: sklearn2c does not use scikit-learn's GaussianNB(); instead, it uses the following cases to compute the decision function.
**: linear, poly and rbf kernels are supported.
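To make the C-SVC note concrete: for a binary problem, the decision value is a weighted sum of kernel evaluations against the support vectors, plus a bias. The sketch below is an illustration using only scikit-learn, not sklearn2c's own code. It recomputes SVC.decision_function from the fitted support_vectors_, dual_coef_ and intercept_ attributes, using scikit-learn's definitions of the three supported kernels. The helpers kernel and decision_function are hypothetical, and gamma is fixed explicitly (an assumption) so its value is known outside scikit-learn.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import SVC


def kernel(kind, sv, x, gamma, coef0, degree):
    """Evaluate a scikit-learn/libsvm kernel between every support vector and one sample."""
    if kind == "linear":
        return sv @ x
    if kind == "poly":
        return (gamma * (sv @ x) + coef0) ** degree
    if kind == "rbf":
        return np.exp(-gamma * np.sum((sv - x) ** 2, axis=1))
    raise ValueError(f"unsupported kernel: {kind}")


def decision_function(clf, x, gamma):
    """Binary C-SVC decision value: sum_i dual_coef_i * K(sv_i, x) + intercept."""
    k = kernel(clf.kernel, clf.support_vectors_, x, gamma, clf.coef0, clf.degree)
    return clf.dual_coef_[0] @ k + clf.intercept_[0]


X, y = make_classification(n_samples=200, n_features=4, random_state=0)
gamma = 0.25  # fixed explicitly so the same value can be reused outside scikit-learn

for kind in ("linear", "poly", "rbf"):
    clf = SVC(kernel=kind, gamma=gamma, coef0=1.0, degree=3).fit(X, y)
    manual = np.array([decision_function(clf, x, gamma) for x in X[:5]])
    print(kind, np.allclose(manual, clf.decision_function(X[:5])))
```

Multiclass SVC in scikit-learn combines one-vs-one sub-classifiers, which this binary sketch does not cover.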
Regression
- Linear Regression
- Polynomial Regression
- KNN
- Decision Trees
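For linear and polynomial regression, prediction reduces to an intercept plus a dot product, which is cheap to evaluate on a microcontroller. A minimal scikit-learn-only sketch, assuming a single input feature and a PolynomialFeatures(include_bias=False) followed by LinearRegression pipeline (this is not necessarily how sklearn2c builds its polynomial model); poly_predict is a hypothetical helper:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# Synthetic data: y = 0.5 - 2x + 3x^2 plus small noise
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(100, 1))
y = 0.5 - 2.0 * X[:, 0] + 3.0 * X[:, 0] ** 2 + rng.normal(0.0, 0.05, size=100)

model = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False), LinearRegression()
).fit(X, y)
lin = model.named_steps["linearregression"]


def poly_predict(x, coef, intercept):
    """Evaluate intercept + c1*x + c2*x^2 + ... for one scalar input."""
    # PolynomialFeatures(include_bias=False) on one feature yields [x, x^2, ...]
    powers = np.array([x ** (k + 1) for k in range(len(coef))])
    return intercept + coef @ powers


manual = np.array([poly_predict(x, lin.coef_, lin.intercept_) for x in X[:5, 0]])
print(np.allclose(manual, model.predict(X[:5])))
```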
Clustering
- kmeans
- DBSCAN
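Similarly, k-means inference only needs the fitted centroids: a new sample is assigned to the nearest one. The sketch below uses scikit-learn's KMeans and a hypothetical nearest_centroid helper; it is not sklearn2c's exported format. DBSCAN is not shown because scikit-learn's DBSCAN has no predict method for new samples.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, n_features=2, random_state=0)
km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)


def nearest_centroid(x, centers):
    """Return the index of the closest centroid by squared Euclidean distance."""
    d2 = np.sum((centers - x) ** 2, axis=1)
    return int(np.argmin(d2))


labels = np.array([nearest_centroid(x, km.cluster_centers_) for x in X])
print(np.array_equal(labels, km.predict(X)))
```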
Installation
You can install the library via pip using either:
pip install sklearn2c
or
pip install git+https://github.com/EmbeddedML/sklearn2c.git
Alternatively, you can install the conda package:
conda install sklearn2c
or
mamba install sklearn2c
Usage
Please check the examples directory in this repository. For example, a decision tree classifier is created, trained, and exported as shown below, using these methods:
- train: trains the model and optionally saves the model file to save_path. This method is fully compatible with the scikit-learn library.
- predict: runs the model on the given data.
- load (static method): loads the model from the saved path.
- export: generates the model parameters as C functions.
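dtc = DTClassifier()
# Train the model and save it to disk
dtc.train(train_samples, train_labels, save_path="<path/to/model>")
# Run inference in Python
dtc.predict(test_samples)
# Reload the saved model and export its parameters for C inference
dtc2 = DTClassifier.load("<path/to/model>")
dtc2.export("<path/to/config_dir>")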
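The export step turns a trained tree into parameters that C code can walk. As an illustration of that idea using only scikit-learn (sklearn2c's actual export layout may differ), the sketch below reads the fitted tree_ arrays and reproduces predict with a plain loop that has a direct C equivalent. The helper tree_predict is hypothetical.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X = X.astype(np.float32)  # scikit-learn evaluates trees on float32 inputs
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
t = clf.tree_

# Flat per-node arrays like these map directly onto static C arrays
left, right = t.children_left, t.children_right
feature, threshold = t.feature, t.threshold
leaf_class = np.argmax(t.value[:, 0, :], axis=1)  # majority class index per node


def tree_predict(x):
    """Walk from the root until reaching a leaf (children_left == -1)."""
    node = 0
    while left[node] != -1:
        node = left[node] if x[feature[node]] <= threshold[node] else right[node]
    return clf.classes_[leaf_class[node]]


manual = np.array([tree_predict(x) for x in X])
print(np.array_equal(manual, clf.predict(X)))
```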
For inference in C (specifically on STM32 boards), see the STM32_inference directory for the corresponding model.
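On the C side, an exported model is essentially a set of constant arrays plus an inference loop. A minimal sketch of generating such arrays from Python, assuming a header layout invented here for illustration: the macro and array names are hypothetical and are not what sklearn2c or the STM32_inference code use.

```python
import numpy as np
from sklearn.cluster import KMeans


def to_c_array(name, values):
    """Format values as a flattened, row-major `const float` C array definition."""
    flat = np.asarray(values, dtype=np.float64).ravel()
    # Exponent notation always yields a valid C float literal, e.g. 1.00000000e+00f
    body = ", ".join(f"{v:.8e}f" for v in flat)
    return f"const float {name}[{flat.size}] = {{{body}}};"


def kmeans_header(km, prefix="kmeans"):
    """Hypothetical header layout: sizes as macros plus a centroid table."""
    n_clusters, n_features = km.cluster_centers_.shape
    lines = [
        f"#define {prefix.upper()}_N_CLUSTERS {n_clusters}",
        f"#define {prefix.upper()}_N_FEATURES {n_features}",
        to_c_array(f"{prefix}_centers", km.cluster_centers_),
    ]
    return "\n".join(lines) + "\n"


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
print(kmeans_header(km))
```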
License