
Interpretable machine learning model for binary classification combining deep learning and rule learning

Project description

Developed by M.J. van der Zwart as an MSc thesis project (c) 2023.

Installation

pip install r2ntab

Preparing data using a sample dataset

import torch

from r2ntab import transform_dataset, kfold_dataset

# Load the Adult census dataset and binarize its features
name = 'adult.data'
X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare',
                                               negations=False, labels='binary')
# Split the data into folds and take the first train/test split
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]
# Wrap the training data in a PyTorch dataset for model.fit
train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()),
                                           torch.Tensor(Y_train))
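The `method='onehot-compare'` option appears to binarize the data so that every model input is a true/false condition that a rule can test: categorical columns are one-hot encoded and numeric columns are compared against thresholds. As a rough illustration of that general idea only, the plain-Python sketch below encodes two toy columns this way. The helper names `binarize_column` and `onehot_column`, the threshold values, and the reading of `negations` as adding complementary `<=` features are hypothetical assumptions, not r2ntab's actual implementation.

```python
from typing import Dict, List, Sequence

# Hypothetical sketch of threshold/one-hot binarization; NOT the r2ntab implementation.


def binarize_column(name: str, values: Sequence[float], thresholds: Sequence[float],
                    negations: bool = False) -> Dict[str, List[int]]:
    """Turn one numeric column into binary 'compare' features.

    Each threshold t yields a feature '<name> > t'. With negations=True
    (an assumed meaning), a complementary '<name> <= t' feature is added.
    """
    features: Dict[str, List[int]] = {}
    for t in thresholds:
        features[f"{name} > {t}"] = [int(v > t) for v in values]
        if negations:
            features[f"{name} <= {t}"] = [int(v <= t) for v in values]
    return features


def onehot_column(name: str, values: Sequence[str]) -> Dict[str, List[int]]:
    """One-hot encode a categorical column: one '<name> == <category>' feature per category."""
    categories = sorted(set(values))
    return {f"{name} == {c}": [int(v == c) for v in values] for c in categories}


if __name__ == "__main__":
    ages = [25, 38, 52, 61]
    workclass = ["Private", "State-gov", "Private", "Self-emp"]
    # {'age > 30': [0, 1, 1, 1], 'age > 50': [0, 0, 1, 1]}
    print(binarize_column("age", ages, thresholds=[30, 50]))
    print(onehot_column("workclass", workclass))
```

Because every input column is then a single binary condition, a learned rule can be read back directly as a conjunction of such conditions, for example `age > 30 AND workclass == Private`.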

Creating and training the model

from r2ntab import R2NTab

# Arguments (assumed order): number of input features, number of rules, number of outputs
model = R2NTab(len(X_headers), 10, 1)
model.fit(train_set, epochs=1000)
Y_pred = model.predict(X_test)

Extracting the results

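# Extract (and print) the learned rules, then evaluate on the test fold
rules = model.extract_rules(X_headers, print_rules=True)
print(f'AUC: {model.score(Y_pred, Y_test, metric="auc")}')
# Model complexity: number of rules and total number of conditions
print(f'# Rules: {len(rules)}')
print(f'# Conditions: {sum(map(len, rules))}')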
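The complexity measures above treat `rules` as a list of rules, each of which is a list of conditions. Assuming the extracted rule set is read as a disjunction of conjunctions (a record is classified positive if every condition of at least one rule holds), as is common for rule-learning models of this kind, a toy rule set can be applied and summarised as in the sketch below. The `(feature, operator, value)` condition format, the feature values, and the helper names `rule_fires` and `predict_rule_set` are hypothetical; the objects returned by `extract_rules` may be represented differently.

```python
import operator
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical condition format: (feature name, operator symbol, value).
Condition = Tuple[str, str, Any]
Rule = List[Condition]

OPS: Dict[str, Callable[[Any, Any], bool]] = {
    ">": operator.gt,
    "<=": operator.le,
    "==": operator.eq,
}


def rule_fires(rule: Rule, record: Dict[str, Any]) -> bool:
    """A rule is a conjunction: it fires only if every condition holds."""
    return all(OPS[op](record[feature], value) for feature, op, value in rule)


def predict_rule_set(rules: List[Rule], record: Dict[str, Any]) -> int:
    """A rule set is a disjunction: predict 1 if at least one rule fires, else 0."""
    return int(any(rule_fires(rule, record) for rule in rules))


# Toy rule set (hypothetical, not learned by R2NTab).
toy_rules: List[Rule] = [
    [("age", ">", 30), ("education-num", ">", 12)],
    [("capital-gain", ">", 5000)],
]

# Same complexity measures as in the README snippet.
n_rules = len(toy_rules)                  # 2
n_conditions = sum(map(len, toy_rules))   # 3

print(n_rules, n_conditions)
print(predict_rule_set(toy_rules, {"age": 45, "education-num": 13, "capital-gain": 0}))  # 1
print(predict_rule_set(toy_rules, {"age": 25, "education-num": 9, "capital-gain": 0}))   # 0
```

Under this reading, fewer rules and fewer conditions per rule make the model easier to inspect, which is why both counts are reported alongside the AUC.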

Contact

For questions or problems, please open an issue on the project's GitHub repository.



Download files


Source Distribution

r2ntab-1.0.3.tar.gz (13.5 kB)

File details

Details for the file r2ntab-1.0.3.tar.gz.

File metadata

  • Download URL: r2ntab-1.0.3.tar.gz
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for r2ntab-1.0.3.tar.gz
Algorithm Hash digest
SHA256 5f4ed3a21f7409104e13451f86b5d2985a783553f2ce21d75ca0fe4cad4906ae
MD5 bc05b67ba1b66c9668b88ab64e77cdff
BLAKE2b-256 b27d3b02a5235ccbc8466f1e524abc1035b094bdba1c64742e1bc50a011fec98

