Interpretable machine learning model for binary classification combining deep learning and rule learning

Project description

Developed by M.J. van der Zwart as an MSc thesis project, (c) 2023.

Preparing data

import torch

from r2ntab import transform_dataset, kfold_dataset

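# Load the dataset and binarize its features
name = 'adult.data'
X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare',
                                               negations=False, labels='binary')
# Split into folds and take the first (train, test) split
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]
# Wrap the training split as a PyTorch dataset for model.fit
train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()),
                                           torch.Tensor(Y_train))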
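The `method='onehot-compare'` option turns the raw table into binary features. Binary inputs are the usual starting point for rule learning, because each rule condition is then a single true/false test. The exact encoding is defined inside `transform_dataset`; the sketch below only illustrates the general idea under our own assumptions: categorical columns become `column == value` indicators and numeric columns become `column > threshold` indicators for a few hand-picked thresholds. The helpers `binarize_compare` and `one_hot`, the thresholds, and the toy rows are all hypothetical and not part of r2ntab.

```python
def binarize_compare(rows, column, thresholds):
    """Turn one numeric column into binary 'column > t' features.

    Hypothetical helper illustrating threshold-based binarization;
    not r2ntab's actual transform_dataset code.
    """
    headers = [f"{column} > {t}" for t in thresholds]
    encoded = [[1 if row[column] > t else 0 for t in thresholds] for row in rows]
    return headers, encoded


def one_hot(rows, column):
    """One-hot encode a categorical column as 'column == value' features."""
    values = sorted({row[column] for row in rows})
    headers = [f"{column} == {v}" for v in values]
    encoded = [[1 if row[column] == v else 0 for v in values] for row in rows]
    return headers, encoded


# Toy rows loosely modelled on two Adult columns (made-up values)
rows = [
    {"age": 25, "workclass": "Private"},
    {"age": 47, "workclass": "Self-emp"},
    {"age": 38, "workclass": "Private"},
]
num_headers, num_bits = binarize_compare(rows, "age", thresholds=[30, 40])
cat_headers, cat_bits = one_hot(rows, "workclass")

# Concatenate numeric and categorical indicators column-wise
X_headers = num_headers + cat_headers
X = [n + c for n, c in zip(num_bits, cat_bits)]
print(X_headers)  # ['age > 30', 'age > 40', 'workclass == Private', 'workclass == Self-emp']
print(X)          # [[0, 0, 1, 0], [1, 1, 0, 1], [1, 0, 1, 0]]
```

Each header then reads directly as a rule condition, which is why the extracted rules can be printed in terms of `X_headers`.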

Creating and training the model

from r2ntab import R2Ntab

# Build the model on the binarized features, train it, and predict on the test split
model = R2Ntab(len(X_headers), 10, 1)
model.fit(train_set, epochs=1000)
Y_pred = model.predict(X_test)
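R2N-Tab is described as combining deep learning with rule learning. One generic way to make a rule set trainable by gradient descent is to relax AND and OR into products over values in [0, 1], as sketched below. This is a swapped-in, textbook-style relaxation for illustration only; it is not taken from the r2ntab source, and the function names, rules, and weights are hypothetical. With binary inputs and binary weights the relaxation reduces to exact Boolean logic, while fractional values give smooth outputs that a network can learn.

```python
import math


def soft_and(x, w):
    """Relaxed conjunction: prod_j (1 - w_j * (1 - x_j)).

    x: feature values in [0, 1]; w: membership weights in [0, 1].
    A feature with w_j = 0 is ignored; with w_j = 1 it must equal 1.
    """
    return math.prod(1.0 - wj * (1.0 - xj) for xj, wj in zip(x, w))


def soft_or(z, v):
    """Relaxed disjunction: 1 - prod_k (1 - v_k * z_k)."""
    return 1.0 - math.prod(1.0 - vk * zk for zk, vk in zip(z, v))


def predict_rule_set(x, rule_weights, rule_mask):
    """Score one sample as an OR over AND-rules (a DNF rule set)."""
    rule_outputs = [soft_and(x, w) for w in rule_weights]
    return soft_or(rule_outputs, rule_mask)


# Two hypothetical rules over 4 binary features:
#   rule 1: x0 AND x2      rule 2: x1 AND x3
rules = [[1, 0, 1, 0], [0, 1, 0, 1]]
active = [1, 1]
print(predict_rule_set([1, 0, 1, 0], rules, active))  # 1.0 (rule 1 fires)
print(predict_rule_set([1, 1, 0, 0], rules, active))  # 0.0 (no rule fires)
```

For instance, `soft_and([1, 0.5], [1, 1])` returns 0.5, so a partially satisfied conjunction yields an intermediate score instead of a hard 0 or 1.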

Extracting the results

rules = model.extract_rules(X_headers, print_rules=True)
print(f'AUC: {model.score(Y_pred, Y_test, metric="auc")}')
print(f'# Rules: {len(rules)}')
print(f'# Conditions: {sum(map(len, rules))}')
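The snippet above reports AUC and counts rules and conditions. `sum(map(len, rules))` assumes each extracted rule is a collection of conditions. The sketch below reproduces that counting on a hypothetical rule list and computes AUC in pure Python via its rank interpretation: the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counted as half. It is offered as a way to sanity-check reported numbers, not as the implementation behind `model.score`.

```python
def auc(scores, labels):
    """ROC AUC as the Mann-Whitney U statistic divided by n_pos * n_neg.

    Ties count as half a win. O(n_pos * n_neg), fine for small checks.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


# Hypothetical extracted rules: each rule is a list of conditions (a conjunction)
rules = [
    ["age > 30", "workclass == Private"],
    ["capital-gain > 5000"],
]
print(f"# Rules: {len(rules)}")                 # prints "# Rules: 2"
print(f"# Conditions: {sum(map(len, rules))}")  # prints "# Conditions: 3"

print(auc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0]))  # 0.75
```

If the predictions are hard 0/1 labels rather than scores, this AUC reduces to (TPR + TNR) / 2, so passing probabilities gives a more informative value.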

Contact

For any questions or problems, please open an issue on GitHub.

Download files

Source Distribution

r2ntab-1.0.0.tar.gz (13.4 kB)

Built Distribution

r2ntab-1.0.0-py3-none-any.whl (15.2 kB)

File details

Details for the file r2ntab-1.0.0.tar.gz.

File metadata

  • Download URL: r2ntab-1.0.0.tar.gz
  • Size: 13.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for r2ntab-1.0.0.tar.gz
Algorithm Hash digest
SHA256 0bf8fb1acb52e7ee90e308a08faa09e7d75f0a29cfc36d6d201ec4eb20260013
MD5 711e50d79c89faa79bba56476774cc1b
BLAKE2b-256 afd95bf7a36537edec1244eb82ccfa894d41b2ca556fa4d733feacdc3215091b

File details

Details for the file r2ntab-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: r2ntab-1.0.0-py3-none-any.whl
  • Size: 15.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for r2ntab-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b5d256d5d0f698db9974b4bff41de11eff115e0da277d0f04d6ad1555975724a
MD5 dae0f3a409034884377e953431d8da4e
BLAKE2b-256 8783a90a185162b0b403ce7dc4982d09083102e28ec19b92cec78961547a7394