
OBAN Classifier

Oban Classifier is a flexible neural network-based classifier built on top of PyTorch and Skorch. It supports both binary and multiclass classification, and allows users to define parameters such as the number of units, activation function, dropout rate, and more.

Features

  • Supports binary and multiclass classification (a minimal multiclass sketch follows this list).
  • Allows user-defined parameters for hidden units, activation functions, dropout, and more.
  • Built using Skorch and PyTorch for easy integration with scikit-learn pipelines.
  • Provides detailed performance metrics including accuracy, precision, recall, F1-score, and confusion matrix.
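
As a quick illustration of the multiclass case, here is a minimal sketch on the Iris dataset. It assumes only the `oban_classifier(X, y, num_units=..., max_epochs=..., lr=...)` call shown in the usage example below; the exact console output and any further options depend on the library itself.

```python
from oban_classifier import oban_classifier
from sklearn.datasets import load_iris
import pandas as pd

# Iris has three classes, so this exercises the multiclass path.
data = load_iris()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# Same call signature as in the binary usage example further down; only
# num_units, max_epochs and lr are shown there, so no other keyword
# arguments are assumed here. The chosen values are arbitrary.
netv, X_test, y_test = oban_classifier(X, y, num_units=64, max_epochs=50, lr=0.001)
```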

Installation

Install the library from PyPI with pip:

pip install oban_classifier


Usage Example

```python

from oban_classifier import oban_classifier, post_classification_analysis, plot_lime_importance
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler
import pandas as pd


# Load the Breast Cancer dataset
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# Train and evaluate the model
netv, X_test, y_test = oban_classifier(X, y, num_units=128, max_epochs=80, lr=0.001)

# Convert X_test to a DataFrame with the original feature names
X_test_df = pd.DataFrame(X_test, columns=X.columns)

# Predict class probabilities (X_test_df is passed as a NumPy array)
y_proba = netv.predict_proba(X_test_df.to_numpy())
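
# (Optional, illustrative) Derive hard class labels from the probabilities
# with NumPy; this is not part of the library's API. For this binary problem,
# taking the argmax over the two probability columns is equivalent to the
# 0.5 threshold used by post_classification_analysis below.
import numpy as np
y_pred_from_proba = np.argmax(y_proba, axis=1)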


# Perform post-classification analysis
post_classification_analysis(X_test_df, y_test, y_proba, threshold=0.5)

# Explain predictions using LIME with the original feature names
plot_lime_importance(netv, X_test_df, y_test, feature_names=X.columns)


# A new data point (it must have the same number of features as the original training data)
new_data = pd.DataFrame([[15.0, 20.0, 85.0, 60.0, 0.5, 1.5, 3.0, 0.02, 0.2, 0.3,
                          0.1, 25.0, 50.0, 150.0, 100.0, 0.1, 0.5, 2.5, 0.01, 0.1,
                          15.0, 20.0, 85.0, 60.0, 0.5, 1.5, 3.0, 0.02, 0.2, 0.3]],
                        columns=X.columns)

# Normalize the new data with a scaler fitted on the original features,
# matching the scaling applied during training
scaler = StandardScaler()
scaler.fit(X)  # fit on the original feature matrix, not on the new sample
new_data_scaled = scaler.transform(new_data)

# Predict the class for the new data
predicted_class = netv.predict(new_data_scaled)
print(f"Predicted class: {predicted_class}")

# Predict class probabilities for the new data
predicted_probabilities = netv.predict_proba(new_data_scaled)
print(f"Predicted probabilities: {predicted_probabilities}")
```
