OBAN Classifier: A Skorch-based flexible neural network for binary and multiclass classification

OBAN Classifier

OBAN Classifier is a flexible neural-network classifier built on top of PyTorch and Skorch. It supports both binary and multiclass classification and lets users set parameters such as the number of hidden units, the activation function, and the dropout rate.

Features

  • Supports binary and multiclass classification.
  • Allows user-defined parameters for hidden units, activation functions, dropout, and more.
  • Built using Skorch and PyTorch for easy integration with scikit-learn pipelines.
  • Provides detailed performance metrics including accuracy, precision, recall, F1-score, and confusion matrix.
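
To make the user-defined parameters above concrete, here is a minimal sketch of how the number of hidden units, the activation function, the dropout rate, and the number of classes could fit together in a PyTorch module. `SketchNet` is a hypothetical name used only for illustration; the actual architecture inside `oban_classifier` may differ (number of layers, output activation, and so on).

```python
import torch
from torch import nn


class SketchNet(nn.Module):
    """Hypothetical feed-forward classifier; NOT the library's actual module."""

    def __init__(self, num_features, num_units=128, num_classes=2,
                 nonlin=None, dropout_rate=0.5):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_features, num_units),           # hidden dense layer
            nonlin if nonlin is not None else nn.ReLU(),  # user-chosen activation
            nn.Dropout(dropout_rate),                     # regularization
            nn.Linear(num_units, num_classes),            # one logit per class
        )

    def forward(self, x):
        return self.layers(x)


torch.manual_seed(0)
model = SketchNet(num_features=30, num_units=64, num_classes=3,
                  nonlin=nn.Tanh(), dropout_rate=0.2)
model.eval()  # turn dropout off for inference

batch = torch.randn(4, 30)  # 4 synthetic rows with 30 features each
with torch.no_grad():
    probs = torch.softmax(model(batch), dim=1)  # logits -> class probabilities

print(tuple(probs.shape))
```

A module of this shape can be wrapped by Skorch so that it exposes the scikit-learn style `fit`, `predict`, and `predict_proba` methods used in the usage example.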

Installation

Install the library from PyPI with pip:

pip install oban_classifier


### Usage Example

```python

from oban_classifier import oban_classifier, post_classification_analysis, plot_lime_importance
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler
import pandas as pd


# Load the Breast Cancer dataset

data = load_breast_cancer()

X = pd.DataFrame(data.data, columns=data.feature_names)

y = pd.Series(data.target)

# Train and evaluate the model with num_classes explicitly defined (2 for binary classification)
netv, X_test, y_test = oban_classifier(X, y, num_units=128, num_classes=2, max_epochs=80, lr=0.001)

# Convert X_test to DataFrame with feature names
X_test_df = pd.DataFrame(X_test, columns=X.columns)

# Predict probabilities
y_proba = netv.predict_proba(X_test_df.to_numpy())

# Perform post-classification analysis
post_classification_analysis(X_test_df, y_test, y_proba, threshold=0.5)

# Explain predictions using LIME with correct feature names
plot_lime_importance(netv, X_test_df, y_test, feature_names=X.columns)

# Assume this is a new data point (must have the same number of features as the original training data)
new_data = pd.DataFrame([[15.0, 20.0, 85.0, 60.0, 0.5, 1.5, 3.0, 0.02, 0.2, 0.3,
                          0.1, 25.0, 50.0, 150.0, 100.0, 0.1, 0.5, 2.5, 0.01, 0.1,
                          15.0, 20.0, 85.0, 60.0, 0.5, 1.5, 3.0, 0.02, 0.2, 0.3]], 
                        columns=X.columns)

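# Normalize the new data before prediction.
# Note: this fits a fresh StandardScaler on the original feature matrix X. It is assumed
# here (not confirmed by this README) to match the scaling applied during training.
scaler = StandardScaler()
scaler.fit(X)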
new_data_scaled = scaler.transform(new_data)

# Predict the class for the new data
predicted_class = netv.predict(new_data_scaled)

print(f"Predicted class: {predicted_class}")

# If you want to predict probabilities for the new data
predicted_probabilities = netv.predict_proba(new_data_scaled)

print(f"Predicted probabilities: {predicted_probabilities}")
```

#### oban_classifier Parameters

X (pd.DataFrame): The feature matrix. Should be a Pandas DataFrame where each row is an instance and each column is a feature.

y (pd.Series): The target variable. Should be a Pandas Series where each value corresponds to the target class of a given row in X.

num_units (int, optional, default=128): The number of hidden units in the dense layers of the neural network.

num_classes (int, required): The number of classes for classification. For binary classification, set this to 2. For multiclass problems, set this to the total number of classes.

nonlin (torch.nn.Module, optional, default=nn.ReLU()): The non-linear activation function to apply after each dense layer. Default is ReLU, but can be changed to other functions like nn.Sigmoid() or nn.Tanh().

dropout_rate (float, optional, default=0.5): The dropout rate applied to the layers to prevent overfitting. Should be between 0 and 1.

max_epochs (int, optional, default=10): The maximum number of epochs to train the model.

lr (float, optional, default=0.01): The learning rate for the optimizer.

test_size (float, optional, default=0.2): The proportion of the dataset to be used for testing. Should be between 0 and 1.

random_state (int, optional, default=42): The seed for the random number generator to ensure reproducible results during dataset splitting.
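
The effect of `test_size` and `random_state` can be sketched with scikit-learn directly. The example below uses synthetic data and assumes a stratified `train_test_split` followed by a `StandardScaler` fitted on the training rows only; whether `oban_classifier` stratifies or scales in exactly this way is an assumption, not something this README confirms.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data: 100 rows, 4 features, balanced binary target.
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 4))
y = np.array([0, 1] * 50)

# test_size=0.2 holds out 20% of the rows; random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Fit the scaler on the training rows only, then reuse it for the test rows
# (and for any new data) so that no test-set statistics leak into training.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Repeating the split with the same random_state reproduces it exactly.
X_train_again, _, _, _ = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)
print(np.array_equal(X_train, X_train_again))
```

Keeping the fitted scaler around and applying it to new rows keeps predictions consistent with training.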


#### post_classification_analysis Parameters

X (pd.DataFrame): The feature matrix used during testing.

y_true (pd.Series): The true class labels for the test set.

y_proba (np.ndarray): The predicted probabilities for each class.

threshold (float, optional, default=0.5): The decision threshold used for binary classification. Predictions with probabilities greater than or equal to the threshold are classified as 1, otherwise as 0. This parameter is ignored in multiclass classification.
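
The thresholding rule described above, and the metrics listed under Features, can be sketched as follows. `labels_from_proba` is a hypothetical helper written for this illustration and is not part of the library; it assumes that column 1 of `y_proba` holds the probability of class 1 in the binary case.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)


def labels_from_proba(y_proba, threshold=0.5):
    """Hypothetical helper: turn class probabilities into predicted labels."""
    y_proba = np.asarray(y_proba)
    if y_proba.shape[1] == 2:
        # Binary: class 1 when its probability reaches the threshold.
        return (y_proba[:, 1] >= threshold).astype(int)
    # Multiclass: the threshold is ignored, pick the most probable class.
    return y_proba.argmax(axis=1)


y_true = np.array([0, 0, 1, 1, 1])
y_proba = np.array([
    [0.90, 0.10],
    [0.40, 0.60],
    [0.35, 0.65],
    [0.20, 0.80],
    [0.70, 0.30],
])

y_pred = labels_from_proba(y_proba, threshold=0.5)
y_pred_strict = labels_from_proba(y_proba, threshold=0.7)  # fewer positives

print("predictions:", y_pred.tolist())
print("accuracy:", accuracy_score(y_true, y_pred))
print("precision:", precision_score(y_true, y_pred))
print("recall:", recall_score(y_true, y_pred))
print("f1:", f1_score(y_true, y_pred))
print(confusion_matrix(y_true, y_pred))

# Multiclass probabilities fall back to argmax.
multi = labels_from_proba([[0.2, 0.5, 0.3], [0.6, 0.3, 0.1]])
```

Raising the threshold trades recall for precision: with `threshold=0.7` only the fourth row is still predicted as class 1.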
