EBM model serialization to ONNX
Project description
Ebm2onnx converts EBM (Explainable Boosting Machine) models to ONNX, so you can run an EBM model on any ONNX-compliant runtime.
Features
- Binary classification
- Regression
- Continuous variables
- Categorical variables
- Interactions
- Multi-class classification (support is still experimental in EBM)
- Expose local explanations
Exported models are tested against ONNX Runtime.
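An EBM is a generalized additive model with optional pairwise interactions, and an exported graph has to reproduce that additive structure. The sketch below is a conceptual, pure-Python illustration of binary-classification scoring, not the graph that ebm2onnx actually emits. Each term looks up a score from the bin (or category) its input falls into, the scores are summed with an intercept, and a logistic function turns the sum into a probability. The per-term scores are what a local explanation reports. All bin edges, scores and names (`AGE_EDGES`, `explain`, `predict_proba`, ...) are made up for illustration, and the edge convention (`bisect_right`) is an assumption.

```python
import bisect
import math

# Toy values, made up for illustration; a fitted EBM learns its own.
INTERCEPT = -0.4

# Continuous term: the edges split Age into len(AGE_EDGES) + 1 bins,
# and AGE_SCORES holds one additive score per bin.
AGE_EDGES = [18.0, 40.0, 60.0]
AGE_SCORES = [0.6, 0.1, -0.2, -0.5]

# Categorical term: one additive score per category.
EMBARKED_SCORES = {"C": 0.3, "Q": -0.1, "S": -0.2}

# Pairwise interaction term: a 2-D lookup keyed by (Age bin, Embarked).
# Cells not listed contribute 0.0 in this sketch.
AGE_X_EMBARKED = {(0, "C"): 0.2, (3, "S"): -0.1}


def age_bin(age):
    # Index of the bin that `age` falls into (a value equal to an edge goes right).
    return bisect.bisect_right(AGE_EDGES, age)


def explain(age, embarked):
    """Per-term contributions for one sample, i.e. a local explanation."""
    b = age_bin(age)
    return {
        "Age": AGE_SCORES[b],
        "Embarked": EMBARKED_SCORES[embarked],
        "Age x Embarked": AGE_X_EMBARKED.get((b, embarked), 0.0),
    }


def predict_proba(age, embarked):
    """Positive-class probability: logistic of intercept + sum of term scores."""
    logit = INTERCEPT + sum(explain(age, embarked).values())
    return 1.0 / (1.0 + math.exp(-logit))


print(explain(10, "C"))
print(predict_proba(10, "C"))
```

For regression, the summed score would itself be the prediction, without the logistic step.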
Get Started
Train an EBM model:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from interpret.glassbox import ExplainableBoostingClassifier

# prepare dataset
df = pd.read_csv('titanic_train.csv')
df = df.dropna()
feature_columns = ['Age', 'Fare', 'Pclass', 'Embarked']
label_column = "Survived"
y = df[label_column]  # 1-D Series, as LabelEncoder expects
le = LabelEncoder()
y_enc = le.fit_transform(y)
x = df[feature_columns]
x_train, x_test, y_train, y_test = train_test_split(x, y_enc)

# train an EBM model
model = ExplainableBoostingClassifier(
    feature_types=['continuous', 'continuous', 'continuous', 'categorical'],
)
model.fit(x_train, y_train)
Then you can convert it to ONNX in a single function call:
import onnx
import ebm2onnx

onnx_model = ebm2onnx.to_onnx(
    model,
    ebm2onnx.get_dtype_from_pandas(x_train),
)
onnx.save_model(onnx_model, 'ebm_model.onnx')
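Once saved, the model is scored by handing the runtime its input arrays. The helper below is a hypothetical sketch (`to_feed` and `NUMPY_TYPES` are not part of ebm2onnx) that turns row-oriented records into a feed dict with one 1-D array per feature. The one-input-per-feature layout, the NumPy dtypes chosen for 'double', 'int' and 'str', and the 'str' type for `Embarked` are assumptions, so check them against your exported model.

```python
import numpy as np

# Assumed mapping from the dtype strings used with ebm2onnx to NumPy dtypes.
# String features are passed as object arrays of Python str.
NUMPY_TYPES = {"double": np.float64, "int": np.int64, "str": object}


def to_feed(rows, dtype):
    """Convert a list of row dicts into {feature name: 1-D array} (hypothetical helper)."""
    return {
        name: np.array([row[name] for row in rows], dtype=NUMPY_TYPES[kind])
        for name, kind in dtype.items()
    }


rows = [
    {"Age": 22.0, "Fare": 7.25, "Pclass": 3, "Embarked": "S"},
    {"Age": 38.0, "Fare": 71.2833, "Pclass": 1, "Embarked": "C"},
]
dtype = {"Age": "double", "Fare": "double", "Pclass": "int", "Embarked": "str"}
feed = to_feed(rows, dtype)
```

With ONNX Runtime, you can open `ebm_model.onnx` in an `onnxruntime.InferenceSession`, compare the feed keys with `[i.name for i in sess.get_inputs()]`, and then call `sess.run(None, feed)`.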
If your dataset is not a pandas DataFrame, you can specify the feature types directly:
import onnx
import ebm2onnx

onnx_model = ebm2onnx.to_onnx(
    model,
    dtype={
        'Age': 'double',
        'Fare': 'double',
        'Pclass': 'int',
        'Embarked': 'str',
    },
)
onnx.save_model(onnx_model, 'ebm_model.onnx')
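When the data comes as plain Python rows rather than a DataFrame, you can derive the dtype dict from one sample row. `infer_dtype` and `TYPE_NAMES` below are hypothetical helpers, not part of ebm2onnx. The mapping float to 'double', int to 'int' and str to 'str' mirrors the strings used above, and 'str' for categorical text columns is an assumption.

```python
# Assumed mapping from Python value types to the dtype strings used above.
TYPE_NAMES = {float: "double", int: "int", str: "str"}


def infer_dtype(feature_names, sample_row):
    """Build a to_onnx-style dtype dict from one sample row (hypothetical helper)."""
    if len(feature_names) != len(sample_row):
        raise ValueError("feature_names and sample_row differ in length")
    dtype = {}
    for name, value in zip(feature_names, sample_row):
        # Exact type lookup: bool (a subclass of int) and NumPy scalars
        # are rejected instead of being silently mapped.
        kind = TYPE_NAMES.get(type(value))
        if kind is None:
            raise TypeError(f"unsupported type for {name!r}: {type(value).__name__}")
        dtype[name] = kind
    return dtype


dtype = infer_dtype(["Age", "Fare", "Pclass", "Embarked"], [22.0, 7.25, 3, "S"])
```

Inferring from a single row is fragile when a column mixes ints and floats. Declaring the dict by hand avoids that.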
Try it live
You can test the model conversion live.
You can test local explanations live.
Download files
Source Distribution
File details
Details for the file ebm2onnx-1.2.0.tar.gz
File metadata
- Download URL: ebm2onnx-1.2.0.tar.gz
- Upload date:
- Size: 9.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 099f79676ec86d255c55cbaad1c9e1ad8435b5962cca6965bf0155790e228ed2
MD5 | 2dc3a59b1757023e900a581ceb198b81
BLAKE2b-256 | 559f3d7a6525ec449af0689c8282d3e7dcfd07233d8c789174f519003aff44d3
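The SHA256 digest lets you check a downloaded archive before installing it. Below is a minimal sketch using Python's standard `hashlib`. `sha256_of` and `verify` are illustrative names, and the path is whatever file you downloaded.

```python
import hashlib

# SHA256 digest listed for ebm2onnx-1.2.0.tar.gz.
EXPECTED_SHA256 = "099f79676ec86d255c55cbaad1c9e1ad8435b5962cca6965bf0155790e228ed2"


def sha256_of(path, chunk_size=65536):
    """Hash a file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 hex digest matches `expected` (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Example: verify("ebm2onnx-1.2.0.tar.gz")
```

pip can enforce the same check. In hash-checking mode, a requirements line such as `ebm2onnx==1.2.0 --hash=sha256:<digest>` makes pip refuse any archive whose digest differs.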