
Interpretable machine learning model for binary classification combining deep learning and rule learning

Project description

Developed by M.J. van der Zwart as an MSc thesis project, (c) 2023.

Preparing data using a sample dataset

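```python
import torch

from r2ntab import transform_dataset, kfold_dataset

# Load the sample dataset (adult.data) and encode it with binary labels
name = 'adult.data'
X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare',
                                               negations=False, labels='binary')

# Build the cross-validation splits and take the first train/test split
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]

# Wrap the training split as a PyTorch dataset
train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()),
                                           torch.Tensor(Y_train))
```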
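Each element of `datasets` unpacks into `X_train, X_test, Y_train, Y_test`, so `kfold_dataset` appears to return one train/test split per fold, and `datasets[0]` selects the first one. The number of folds and the exact meaning of `shuffle=1` are not stated here. As a rough illustration only, the hypothetical helper `kfold_indices` below produces shuffled k-fold index splits in plain Python; the default of 5 folds and the seeded shuffle are assumptions, not the library's documented behavior.

```python
import random
from typing import List, Tuple


def kfold_indices(n_samples: int, k: int = 5,
                  seed: int = 1) -> List[Tuple[List[int], List[int]]]:
    """Return (train_idx, test_idx) pairs for k folds.

    Hypothetical helper for illustration; not part of r2ntab.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # deterministic shuffle for reproducibility
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds = []
    start = 0
    for size in fold_sizes:
        test_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        folds.append((train_idx, test_idx))
        start += size
    return folds


splits = kfold_indices(10, k=5)
train_idx, test_idx = splits[0]  # analogous to datasets[0]
print(len(splits), len(train_idx), len(test_idx))  # 5 8 2
```

Every sample lands in exactly one test fold, so averaging a metric over all folds tests on each row exactly once.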

Creating and training the model

```python
from r2ntab import R2NTab

# len(X_headers) is the number of encoded input columns
model = R2NTab(len(X_headers), 10, 1)
model.fit(train_set, epochs=1000)
Y_pred = model.predict(X_test)
```
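`Y_pred` is later scored against `Y_test` with `metric="auc"`. To illustrate what that metric measures, the sketch below computes the area under the ROC curve from its probabilistic definition: the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counted as one half. `roc_auc` is a hypothetical helper; R2NTab's own `score` method may compute AUC differently, for example by delegating to scikit-learn.

```python
from typing import Sequence


def roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win. Quadratic in the number of samples, so this is
    meant for illustration, not for large datasets.
    """
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 1]
print(roc_auc(y_true, [0.1, 0.4, 0.35, 0.8]))  # 0.75 with continuous scores
print(roc_auc(y_true, [0, 1, 1, 1]))           # 0.75 with hard 0/1 labels
```

If the predictions are hard 0/1 labels rather than scores, the formula reduces to the average of the true positive and true negative rates; in the second call that is (1.0 + 0.5) / 2 = 0.75.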

Extracting the results

```python
rules = model.extract_rules(X_headers, print_rules=True)
print(f'AUC: {model.score(Y_pred, Y_test, metric="auc")}')
print(f'# Rules: {len(rules)}')
print(f'# Conditions: {sum(map(len, rules))}')
```
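The extraction step counts rules with `len(rules)` and conditions with `sum(map(len, rules))`, which suggests each rule is a collection of conditions. Assuming, as in other neural rule learners, that the rule set is read as a disjunction of conjunctions (an instance is predicted positive when every condition of at least one rule holds), applying such rules can be sketched as below. The rule format, the condition strings, and `predict_dnf` are hypothetical illustrations, not R2NTab's documented output.

```python
from typing import Dict, List

# Hypothetical rule format: each rule is a list of binary feature names that
# must all equal 1 (a conjunction); the rule set is their disjunction.
Rule = List[str]


def predict_dnf(rules: List[Rule], row: Dict[str, int]) -> int:
    """Return 1 if every condition of at least one rule holds, else 0.

    Features missing from `row` are treated as 0 (condition not met).
    """
    return int(any(all(row.get(cond, 0) == 1 for cond in rule) for rule in rules))


# Illustrative condition names only; not actual R2NTab output.
rules = [
    ["age > 40", "education == Masters"],
    ["capital-gain > 5000"],
]
row = {"age > 40": 1, "education == Masters": 0, "capital-gain > 5000": 1}

print(predict_dnf(rules, row))  # 1 (the second rule fires)
print(len(rules))               # 2 rules
print(sum(map(len, rules)))     # 3 conditions
```

Under this reading, the number of rules and the total number of conditions are a direct measure of how much a person has to read to understand the model, which is why the README reports both.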

Contact

For any questions or problems, please open an issue in the project's GitHub repository.


Download files

Download the file for your platform.

Source Distribution

r2ntab-1.0.1.tar.gz (13.5 kB)

Uploaded Source

Built Distribution

r2ntab-1.0.1-py3-none-any.whl (15.2 kB)

Uploaded Python 3

File details

Details for the file r2ntab-1.0.1.tar.gz.

File metadata

  • Download URL: r2ntab-1.0.1.tar.gz
  • Upload date:
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for r2ntab-1.0.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 75da51408436031621d620c3e66b4f586a71bd467a8f0c92d36846312dca7c65 |
| MD5 | 366830c51633e0a31a8070be9af680b1 |
| BLAKE2b-256 | 1e4273d1a2de072b6f1eac32c10bd6aee0bb413029af54f39f3f015f0bfdce0a |
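The published digests let you confirm that a downloaded archive is intact before installing it. A minimal sketch using Python's standard `hashlib` follows; the helper name `compute_digest` and the local file path are illustrative assumptions, and "BLAKE2b-256" is taken to mean BLAKE2b with a 32-byte digest.

```python
import hashlib
from pathlib import Path


def compute_digest(path: Path, algorithm: str = "sha256",
                   chunk_size: int = 1 << 16) -> str:
    """Hex digest of a file, read in chunks so large files use little memory.

    `algorithm` may be "sha256", "md5", or "blake2b-256" (BLAKE2b with a
    32-byte digest). Hypothetical helper for illustration.
    """
    if algorithm == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)
    else:
        h = hashlib.new(algorithm)
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


EXPECTED_SHA256 = "75da51408436031621d620c3e66b4f586a71bd467a8f0c92d36846312dca7c65"

archive = Path("r2ntab-1.0.1.tar.gz")  # assumes the sdist was saved here
if archive.exists():
    ok = compute_digest(archive) == EXPECTED_SHA256
    print("SHA256 matches" if ok else "SHA256 MISMATCH: do not install")
```

pip can enforce the same check at install time through its hash-checking mode, with one `--hash=sha256:...` option per acceptable file on the requirement line.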


File details

Details for the file r2ntab-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: r2ntab-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 15.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10
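The wheel's filename encodes the tags listed above. Under the wheel naming convention (PEP 427), `r2ntab-1.0.1-py3-none-any.whl` reads as distribution `r2ntab`, version `1.0.1`, Python tag `py3`, ABI tag `none`, and platform tag `any`: a pure-Python wheel for any Python 3 interpreter on any platform, consistent with the "Python 3" tag. The sketch below splits such a name by hand; `parse_wheel_filename` here is a simplified hypothetical helper that skips name normalization and validation. For real use, `packaging.utils.parse_wheel_filename` from the `packaging` project is the more robust choice.

```python
from typing import NamedTuple, Optional


class WheelName(NamedTuple):
    distribution: str
    version: str
    build: Optional[str]
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelName:
    """Split a wheel filename into its PEP 427 components (simplified)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:    # no optional build tag
        dist, version, py, abi, plat = parts
        build = None
    elif len(parts) == 6:  # includes a build tag
        dist, version, build, py, abi, plat = parts
    else:
        raise ValueError(f"unexpected number of components: {filename}")
    return WheelName(dist, version, build, py, abi, plat)


w = parse_wheel_filename("r2ntab-1.0.1-py3-none-any.whl")
print(w.python_tag, w.abi_tag, w.platform_tag)  # py3 none any
```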

File hashes

Hashes for r2ntab-1.0.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4d57c07074b583ed351403a46615986df5ad155070fb998d722d86a3462803c9 |
| MD5 | 320828104a102ed7c49fbb21b3f0e088 |
| BLAKE2b-256 | 6235f24482f44e7943fefa45aa39e3b745b16a1178d6dc1bbe7cd661713a8d36 |

