WHFDL for data classification

CWHFDL

A package for data classification with WHFDL.

This package accompanies the experiments in the article "WHFDL: an explainable method based on World Hyper-heuristic and Fuzzy Deep Learning approaches for gastric cancer detection using metabolomics data".

Example:

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report

import CWHFDL as cw

cw.set_seed(42)

df = pd.read_csv('CGMAIN1.csv')
X = df.drop(columns='state').values
y = df['state'].values

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)

dataset = TensorDataset(X_tensor, y_tensor)
# Hold out 20% of the samples; this held-out split serves as the test set below.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

model = cw.FDNN(
    in_features=X.shape[1],
    hidden_dim=128,
    num_memberships=3,
    num_classes=len(np.unique(y)),
    dropout_rate=0.1
)

cw.initialize_fuzzy_layer(model.fuzzy, train_loader)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(100):
    total_loss = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}: Loss = {total_loss:.4f}")

model.eval()
all_preds, all_probs, all_labels = [], [], []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        outputs = model(x_batch)
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)
        all_preds.extend(preds.numpy())
        # Column 1 is the positive-class probability (assumes binary labels 0/1).
        all_probs.extend(probs[:, 1].numpy())
        all_labels.extend(y_batch.numpy())

print(
    classification_report(
        all_labels,
        all_preds,
        target_names=[
            'Negative',
            'Positive']))

conf_matrix = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Predicted Negative', 'Predicted Positive'],
            yticklabels=['Actual Negative', 'Actual Positive'])
plt.title('WHFDL Confusion Matrix', fontsize=16, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=14)
plt.ylabel('True Label', fontsize=14)
plt.show()

fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='red', linestyle='--', label='Random Guessing')
plt.xlabel('False Positive Rate', fontsize=14)
plt.ylabel('True Positive Rate', fontsize=14)
plt.title('WHFDL (ROC)', fontsize=16, fontweight='bold')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()
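
The ROC step above relies on scikit-learn's roc_curve and auc. As a rough illustration of what those two calls compute, here is a minimal pure-Python sketch. The helper names roc_points and trapezoid_auc are hypothetical and are not part of CWHFDL or scikit-learn. The sketch assumes binary labels where 1 marks the positive class.

```python
def roc_points(labels, scores):
    """Return (FPR, TPR) points obtained by sweeping a threshold over the scores."""
    # Sort samples from the highest score to the lowest.
    pairs = sorted(zip(scores, labels), key=lambda p: -p[0])
    n_pos = sum(1 for _, label in pairs if label == 1)
    n_neg = len(pairs) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both classes must be present to draw a ROC curve")

    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Samples with tied scores cross the threshold together.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def trapezoid_auc(points):
    """Integrate the ROC curve with the trapezoidal rule."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(trapezoid_auc(roc_points(labels, scores)))  # 0.75
```

An AUC of 0.5 corresponds to the "Random Guessing" diagonal in the plot, and 1.0 corresponds to perfect separation of the two classes.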

You can find "CGMAIN1.csv" here.
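
One thing to be aware of in the example: StandardScaler is fitted on every row before the split, so the held-out rows contribute to the scaling statistics. A common alternative, sketched below under stated assumptions, is to split first and fit the scaler on the training rows only. The data here is synthetic and hypothetical, not the CGMAIN1.csv dataset, and the sketch uses scikit-learn's train_test_split instead of random_split so that the split can be stratified by class.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the feature matrix and binary labels (hypothetical data).
rng = np.random.default_rng(42)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 4))
y = np.array([0, 1] * 50)

# Split first, keeping the class ratio the same in both parts.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Fit the scaler on the training rows only, then reuse its statistics
# for the held-out rows so no information leaks from them into training.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_val_scaled = scaler.transform(X_val)

print(X_train_scaled.shape, X_val_scaled.shape)
```

The scaled arrays can then be converted with torch.tensor and wrapped in TensorDataset objects in the same way as in the example above.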


Source Distributions

No source distribution files available for this release.

Built Distribution

cwhfdl-0.0.2-py3-none-any.whl (4.7 kB)

Uploaded Python 3

File details

Details for the file cwhfdl-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: cwhfdl-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 4.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for cwhfdl-0.0.2-py3-none-any.whl

  • SHA256: b9209118510343b1501bbc8234a49b362fdc7e23f06546cb912204fada483063
  • MD5: a92225f55453dd3d32b54e65487e67b2
  • BLAKE2b-256: 657ddf36c6797c423e8e81d025a3e0e5751d810f127c32124881879f34c62894
