Interpretable machine learning model for binary classification combining deep learning and rule learning
Developed by M.J. van der Zwart as an MSc thesis project (c) 2023
Preparing data
import torch
from r2ntab import transform_dataset, kfold_dataset
name = 'adult.data'
X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare',
                                               negations=False, labels='binary')
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]
train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()),
                                           torch.Tensor(Y_train))
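To make the fold structure concrete, here is a minimal sketch of k-fold splitting in plain Python. It is not the implementation of `kfold_dataset`; the helper name `kfold_indices`, the default of `k=5` folds, and the use of the shuffle value as a random seed are all assumptions made for illustration. It only shows the general idea that each fold serves once as the test set while the remaining folds form the training set, which is consistent with `datasets[0]` unpacking into a train/test split.

```python
import random


def kfold_indices(n_samples, k=5, seed=1):
    """Hypothetical helper: return k (train_idx, test_idx) pairs."""
    indices = list(range(n_samples))
    # Shuffle reproducibly so the same seed always gives the same folds.
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices into k folds of near-equal size.
    folds = [indices[i::k] for i in range(k)]
    splits = []
    for i, test_idx in enumerate(folds):
        # Every fold except the i-th one goes into the training set.
        train_idx = [j for m, fold in enumerate(folds) if m != i for j in fold]
        splits.append((train_idx, test_idx))
    return splits


splits = kfold_indices(10, k=5)
train_idx, test_idx = splits[0]
print(len(train_idx), len(test_idx))
```

Taking the first split, as the example above does with `datasets[0]`, gives a single train/test partition; iterating over all splits would give a full cross-validation.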
Creating and training the model
from r2ntab import R2NTab
model = R2NTab(len(X_headers), 10, 1)
model.fit(train_set, epochs=1000)
Y_pred = model.predict(X_test)
Extracting the results
rules = model.extract_rules(X_headers, print_rules=True)
print(f'AUC: {model.score(Y_pred, Y_test, metric="auc")}')
print(f'# Rules: {len(rules)}')
print(f'# Conditions: {sum(map(len, rules))}')
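The two complexity counts above can be checked by hand on a small example. The sketch below assumes that each extracted rule is a sequence of conditions, which is what `sum(map(len, rules))` implies; the rule contents are made up for illustration and are not output from the model.

```python
# Hypothetical rules: each rule is a list of condition strings (assumed structure).
example_rules = [
    ['age > 37', 'education == Bachelors'],
    ['hours-per-week > 45'],
    ['marital-status == Married-civ-spouse', 'capital-gain > 5000', 'age > 29'],
]

# Number of rules: the length of the outer list.
n_rules = len(example_rules)
# Number of conditions: the lengths of all rules added together.
n_conditions = sum(map(len, example_rules))

print(f'# Rules: {n_rules}')
print(f'# Conditions: {n_conditions}')
```

Fewer rules and fewer conditions per rule generally mean a rule set that is easier to read, so these two numbers are useful to report alongside the AUC.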
Contact
For any questions or problems, please open an issue here on GitHub.