OBAN Classifier: A Skorch-based flexible neural network for binary and multiclass classification
OBAN Classifier
OBAN Classifier is a flexible neural-network classifier built on PyTorch and Skorch. It supports both binary and multiclass classification and lets users set parameters such as the number of hidden units, the activation function, and the dropout rate, among others.
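To illustrate what these parameters control, here is a minimal, dependency-free sketch of one dense hidden layer. The number of hidden units is the number of weight rows, the activation is looked up by name, and dropout is applied only during training. This is not OBAN's implementation (which uses PyTorch modules); the helper `hidden_layer` and the activation names are hypothetical.

```python
import math
import random

# Hypothetical activation names for illustration; the names OBAN accepts may differ.
ACTIVATIONS = {
    "relu": lambda z: max(0.0, z),
    "tanh": math.tanh,
}


def hidden_layer(x, weights, biases, activation="relu", dropout=0.0,
                 training=False, rng=None):
    """Hypothetical helper (not part of oban_classifier): activation(W @ x + b).

    The number of hidden units equals len(weights). During training, inverted
    dropout zeroes each unit with probability `dropout` and rescales the
    surviving units by 1 / (1 - dropout); at inference time it does nothing.
    """
    if not 0.0 <= dropout < 1.0:
        raise ValueError("dropout must be in [0, 1)")
    act = ACTIVATIONS[activation]
    rng = rng if rng is not None else random.Random()
    outputs = []
    for w_row, b in zip(weights, biases):
        h = act(sum(w * xi for w, xi in zip(w_row, x)) + b)
        if training and dropout > 0.0:
            h = 0.0 if rng.random() < dropout else h / (1.0 - dropout)
        outputs.append(h)
    return outputs


# Two input features, three hidden units
W = [[1.0, 1.0], [0.5, 0.0], [-1.0, 2.0]]
b = [0.0, 0.0, 0.5]
print(hidden_layer([2.0, 1.0], W, b))  # → [3.0, 1.0, 0.5]
```

Raising the number of units adds rows to `W`; raising the dropout rate zeroes more units per training step but leaves predictions unchanged at inference time.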
Features
- Supports binary and multiclass classification.
- Allows user-defined parameters for hidden units, activation functions, dropout, and more.
- Built using Skorch and PyTorch for easy integration with scikit-learn pipelines.
- Provides detailed performance metrics, including accuracy, precision, recall, F1-score, and a confusion matrix.
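For the binary case, the metrics listed above all follow from the confusion matrix once predicted probabilities are turned into labels at a threshold. The plain-Python sketch below shows that arithmetic; the function `binary_metrics` is hypothetical and not part of OBAN's API, and treating the threshold as `p >= threshold` is an assumption about how a threshold is applied.

```python
def binary_metrics(y_true, p_positive, threshold=0.5):
    """Hypothetical helper: threshold positive-class probabilities, then score them.

    Returns the confusion matrix as [[TN, FP], [FN, TP]] (rows = true class,
    columns = predicted class, the layout scikit-learn uses) plus accuracy,
    precision, recall and F1. Undefined ratios (0/0) are reported as 0.0.
    """
    y_pred = [1 if p >= threshold else 0 for p in p_positive]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    def ratio(num, den):
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    return {
        "confusion_matrix": [[tn, fp], [fn, tp]],
        "accuracy": ratio(tp + tn, len(y_true)),
        "precision": precision,
        "recall": recall,
        "f1": ratio(2 * precision * recall, precision + recall),
    }


print(binary_metrics([1, 0, 1, 1, 0], [0.9, 0.6, 0.4, 0.8, 0.1]))
```

Raising the threshold usually trades recall for precision, which is why the threshold is exposed as a parameter rather than fixed.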
Installation
Install the library from PyPI with pip:
```bash
pip install oban_classifier
```
### Usage Example
```python
from oban_classifier import oban_classifier, post_classification_analysis, plot_lime_importance
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd

# Load the Breast Cancer dataset (30 numeric features, binary target)
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# Train and evaluate the model; returns the fitted net and the held-out test split
netv, X_test, y_test = oban_classifier(X, y, num_units=128, max_epochs=80, lr=0.001)

# Wrap X_test in a DataFrame so the feature names are kept
X_test_df = pd.DataFrame(X_test, columns=X.columns)

# Predict class probabilities; pass a NumPy array, not the DataFrame, to the skorch model
y_proba = netv.predict_proba(X_test_df.to_numpy())

# Perform post-classification analysis at a 0.5 decision threshold
post_classification_analysis(X_test_df, y_test, y_proba, threshold=0.5)

# Explain predictions with LIME, using the original feature names
plot_lime_importance(netv, X_test_df, y_test, feature_names=X.columns)

# A new data point: it must have the same 30 features, in the same order, as the training data
new_data = pd.DataFrame([[15.0, 20.0, 85.0, 60.0, 0.5, 1.5, 3.0, 0.02, 0.2, 0.3,
                          0.1, 25.0, 50.0, 150.0, 100.0, 0.1, 0.5, 2.5, 0.01, 0.1,
                          15.0, 20.0, 85.0, 60.0, 0.5, 1.5, 3.0, 0.02, 0.2, 0.3]],
                        columns=X.columns)

# Standardize the new data. This refits a StandardScaler on the full feature set X,
# so it only matches the model's inputs if oban_classifier scaled its training data the same way.
scaler = StandardScaler()
scaler.fit(X)
# Cast to float32 (PyTorch's default parameter dtype) in case the model does not cast inputs itself
new_data_scaled = scaler.transform(new_data).astype(np.float32)

# Predict the class for the new data
predicted_class = netv.predict(new_data_scaled)
print(f"Predicted class: {predicted_class}")

# Predict class probabilities for the new data
predicted_probabilities = netv.predict_proba(new_data_scaled)
print(f"Predicted probabilities: {predicted_probabilities}")
```
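The `StandardScaler` step in the usage example computes z = (x − mean) / std for each feature, with the mean and standard deviation learned by `fit`. The dependency-free sketch below (hypothetical helpers `fit_standardizer` and `standardize`, not part of OBAN or scikit-learn) shows why new rows must be transformed with the statistics learned from the training data rather than with statistics of the new rows themselves. Like scikit-learn, it uses the population standard deviation and leaves constant features unscaled.

```python
import statistics


def fit_standardizer(rows):
    """Hypothetical helper: learn per-feature mean and population std from training rows."""
    columns = list(zip(*rows))
    means = [statistics.fmean(col) for col in columns]
    # pstdev is the population standard deviation (ddof=0), as StandardScaler uses.
    stds = [statistics.pstdev(col) for col in columns]
    # Leave constant features unscaled instead of dividing by zero.
    stds = [s if s > 0 else 1.0 for s in stds]
    return means, stds


def standardize(rows, means, stds):
    """Hypothetical helper: apply z = (x - mean) / std with previously learned statistics."""
    if any(len(row) != len(means) for row in rows):
        raise ValueError("each row must have the same number of features as the training data")
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


train = [[1.0, 10.0], [2.0, 10.0], [3.0, 10.0]]
means, stds = fit_standardizer(train)
# New rows reuse the training statistics, even when they fall outside the training range.
print(standardize([[2.0, 10.0], [4.0, 12.0]], means, stds))
```

Fitting a fresh scaler on the new rows instead would map any single new row to all zeros, discarding exactly the information the model needs.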
Download files
- Source distribution: oban_classifier-0.1.19.7.tar.gz
- Built distribution: oban_classifier-0.1.19.7-py3-none-any.whl
File details
Details for the file oban_classifier-0.1.19.7.tar.gz.
File metadata
- Download URL: oban_classifier-0.1.19.7.tar.gz
- Upload date:
- Size: 5.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3377376b04167525db5b8a8b1aeaf3117b0b88cf7124e59b363368968643f8b9
MD5 | cdbf507c2a9247b8d414a154575dffb0
BLAKE2b-256 | 45fbea9e142b3300950cff9280c49273929ce6ae410ffd355276ada3cbc4c4bb
File details
Details for the file oban_classifier-0.1.19.7-py3-none-any.whl.
File metadata
- Download URL: oban_classifier-0.1.19.7-py3-none-any.whl
- Upload date:
- Size: 5.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 31070b1f30351bb4ef57872b6a916ffc3e96fac7392597e06dcd3210736a8ce3
MD5 | 747750c4439a2e5051f475fe76375b14
BLAKE2b-256 | b4dca1ca41561242d10e8c59f0ec7dceb400c31985ab9548bc823ac207745bbc
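To check that a downloaded distribution file matches the SHA256 digest listed in its file-hashes table, you can hash it locally with Python's standard `hashlib` module. The helper name `sha256_matches` is hypothetical; the file name and digest in the commented usage line are the wheel's, taken from the table above, and assume the wheel was downloaded into the current directory.

```python
import hashlib


def sha256_matches(path, expected_hex, chunk_size=1 << 16):
    """Hypothetical helper: True if the SHA256 digest of the file at `path` equals `expected_hex`."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read in chunks so large files are never loaded into memory at once.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.strip().lower()


# Usage (assumes the wheel is in the current directory):
# sha256_matches("oban_classifier-0.1.19.7-py3-none-any.whl",
#                "31070b1f30351bb4ef57872b6a916ffc3e96fac7392597e06dcd3210736a8ce3")
```

pip can also enforce digests itself through its hash-checking mode, using `--hash=sha256:...` entries in a requirements file.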