
Interpretable machine learning model for binary classification combining deep learning and rule learning

Project description

Developed by M.J. van der Zwart as an MSc thesis project, (c) 2023.

Installation

pip install r2ntab

Preparing data using a sample dataset

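import torch

from r2ntab import transform_dataset, kfold_dataset

# Load the 'adult' sample dataset and encode its features for rule learning
name = 'adult.data'
X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare',
                                               negations=False, labels='binary')
# Split into cross-validation folds and take the first train/test split
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]
# Wrap the training split as a PyTorch dataset of float tensors
train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()),
                                           torch.Tensor(Y_train))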
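The README only shows that `datasets[0]` unpacks into `X_train, X_test, Y_train, Y_test`. Assuming `kfold_dataset` returns a list of such per-fold splits, the structure can be sketched in plain Python as below. `kfold_split`, its `k=5` default and its `seed` parameter are hypothetical and not part of r2ntab; in particular, whether `shuffle=1` in the real call is a flag or a random seed is not stated here.

```python
import random


def kfold_split(rows, labels, k=5, shuffle=True, seed=1):
    """Return k (X_train, X_test, Y_train, Y_test) tuples, one per fold.

    Hypothetical helper for illustration only; not the r2ntab implementation.
    """
    if len(rows) != len(labels):
        raise ValueError("rows and labels must have the same length")
    indices = list(range(len(rows)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Deal indices round-robin so fold sizes differ by at most one.
    fold_indices = [indices[i::k] for i in range(k)]
    splits = []
    for test in fold_indices:
        test_set = set(test)
        # Every row not in this fold's test part goes into its training part.
        train = [j for j in indices if j not in test_set]
        splits.append((
            [rows[j] for j in train],
            [rows[j] for j in test],
            [labels[j] for j in train],
            [labels[j] for j in test],
        ))
    return splits


X_demo = [[i] for i in range(10)]
Y_demo = [i % 2 for i in range(10)]
folds = kfold_split(X_demo, Y_demo, k=5)
X_train, X_test, Y_train, Y_test = folds[0]  # first fold, as with datasets[0]
print(len(X_train), len(X_test))  # 8 2
```

Each row lands in exactly one test fold, so looping over all folds (rather than only `datasets[0]`) gives a full cross-validation estimate.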

Creating and training the model

from r2ntab import R2Ntab

# len(X_headers) is the number of encoded input features
model = R2Ntab(len(X_headers), 10, 1)
model.fit(train_set, epochs=1000)
Y_pred = model.predict(X_test)
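Since the model is meant to be interpretable, it helps to see how a set of rules turns into a prediction. The sketch below assumes the rules form an unordered rule set in disjunctive normal form, as is common for rule-set learners: each rule is a conjunction (AND) of binary conditions, and an example is predicted positive if any rule fires (OR). The `Rule` representation, the helper names and the feature names are hypothetical; the actual format returned by `extract_rules` may differ.

```python
from typing import Dict, List, Sequence

# Hypothetical representation: a rule is a list of binary feature names
# that must all equal 1 for the rule to fire.
Rule = List[str]


def rule_fires(rule: Rule, row: Dict[str, int]) -> bool:
    # A rule is a conjunction: every condition must hold.
    return all(row.get(cond, 0) == 1 for cond in rule)


def predict_rule_set(rules: Sequence[Rule], row: Dict[str, int]) -> int:
    # The rule set is a disjunction: predict 1 if any rule fires, else 0.
    return int(any(rule_fires(rule, row) for rule in rules))


# Hypothetical rules and feature names, for illustration only.
rules = [
    ["education==Masters", "hours-per-week>40"],
    ["capital-gain>5000"],
]
row = {"education==Masters": 1, "hours-per-week>40": 0, "capital-gain>5000": 1}
print(predict_rule_set(rules, row))  # 1, because the second rule fires
print(sum(map(len, rules)))          # 3 conditions in total
```

Under this representation, `sum(map(len, rules))` counts conditions across all rules, which is how the condition count in the results snippet is computed.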

Extracting the results

rules = model.extract_rules(X_headers, print_rules=True)
print(f'AUC: {model.score(Y_pred, Y_test, metric="auc")}')
print(f'# Rules: {len(rules)}')
print(f'# Conditions: {sum(map(len, rules))}')
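To sanity-check the AUC reported by `model.score`, the metric can be recomputed independently. The sketch below uses the standard rank interpretation of ROC AUC: the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one, with ties counting as half. It is not r2ntab's implementation, and `auc_score` is a hypothetical name. This README does not say whether `Y_pred` holds scores or hard 0/1 labels; with hard labels, the AUC reduces to the mean of the true positive rate and the true negative rate.

```python
from typing import Sequence


def auc_score(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC AUC via pairwise comparison of positive and negative scores.

    O(n_pos * n_neg), which is fine for a quick check on a test fold.
    """
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half
    return wins / (len(pos) * len(neg))


print(auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```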

Contact

For any questions or problems, please open an issue on the project's GitHub repository.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

r2ntab-1.0.2.tar.gz (13.5 kB)

Uploaded Source

Built Distribution

r2ntab-1.0.2-py3-none-any.whl (15.2 kB)

Uploaded Python 3

File details

Details for the file r2ntab-1.0.2.tar.gz.

File metadata

  • Download URL: r2ntab-1.0.2.tar.gz
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for r2ntab-1.0.2.tar.gz
Algorithm Hash digest
SHA256 8d914f8dd8622a43e32182d7d89f33af125de526d328700a96a6f10ade9cc32f
MD5 f0233cda1076860af371292058d40ff3
BLAKE2b-256 c52266163e5a9f8c21b3b2de41b41c1f98de5441918d983a8971ea704de7c83f


File details

Details for the file r2ntab-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: r2ntab-1.0.2-py3-none-any.whl
  • Size: 15.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for r2ntab-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 c3f085e17df8ead77ce8651a689f188658c98bd0e71a22b12580c5791c84758a
MD5 35dda02a733abe07dee4157be30708a7
BLAKE2b-256 ba8081b205bf3cff632a247f772b27bc709a4568bf53030c501f240a36427531
