
=============================================

vaganboostktf: VAE-GAN Boost by TF

=============================================

Hybrid VAE-GAN with LightGBM for class-imbalanced classification

Python 3 | License: MIT

A hybrid generative-classification framework that combines a Conditional VAE, a Conditional GAN, and LightGBM to handle class-imbalanced classification tasks.

Features

  • 🧬 Hybrid architecture combining generative and discriminative models
  • ⚖️ Effective handling of class imbalance through synthetic data generation
  • 🔄 Iterative training process with automatic model refinement
  • 📊 Comprehensive evaluation metrics and visualizations
  • 💾 Model persistence and reproducibility features
  • 🖥️ Command-line interface for easy operation

Installation

pip install vaganboostktf

For a development installation:

git clone https://github.com/yourusername/vaganboostktf.git
cd vaganboostktf
pip install -e .

Usage

Basic Python API

import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from vaganboostktf.data_preprocessor import DataPreprocessor
from vaganboostktf.trainer import HybridModelTrainer
from vaganboostktf.utils import plot_confusion_matrix, plot_roc_curves, plot_pr_curves

# Load and prepare data
df = pd.read_csv("Input.csv")

# Identify features and target
feature_columns = [f"ClE{i}" for i in range(1, 26)]  # 25 features
target_column = "label"

# Initialize the preprocessor and split/scale the data in one step
preprocessor = DataPreprocessor()
X_train_scaled, X_test_scaled, y_train, y_test = preprocessor.prepare_data(
    df,
    feature_columns=feature_columns,
    target_column=target_column
)

# Initialize and train hybrid model
trainer = HybridModelTrainer(config={
    'num_classes': 4,
    'cvae_params': {
        'input_dim': 25,
        'latent_dim': 8,
        'num_classes': 4,
        'learning_rate': 0.01
    },
    'cgan_params': {
        'input_dim': 25,
        'latent_dim': 8,
        'num_classes': 4,
        'generator_lr': 0.0002,
        'discriminator_lr': 0.0002
    },
    'model_dir': 'trained_models',
    'cvae_epochs': 10,
    'cgan_epochs': 10,
    'lgbm_iterations': 10,
    'samples_per_class': 50
})
trainer.training_loop(
    X_train_scaled, y_train,
    X_test_scaled, y_test,
    iterations=5
)

print("Training completed! Best models saved in 'trained_models' directory")

# Visualization of the metrics and results

# Load trained LightGBM model
lgbm_model = joblib.load("trained_models/lgbm_model.pkl")

# Get predictions
y_pred = lgbm_model.predict(X_test_scaled)
y_proba = lgbm_model.predict_proba(X_test_scaled)

# Define class names
class_names = [str(i) for i in np.unique(y_test)]

# Plot and save confusion matrix
conf_matrix_fig = plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True)
conf_matrix_fig.savefig("trained_models/confusion_matrix.png")

# Plot and save ROC curves
roc_curve_fig = plot_roc_curves(y_test, y_proba, classes=class_names)
roc_curve_fig.savefig("trained_models/roc_curves.png")

# Plot and save Precision-Recall curves
pr_curve_fig = plot_pr_curves(y_test, y_proba, classes=class_names)
pr_curve_fig.savefig("trained_models/pr_curves.png")

# Extract feature importances and corresponding feature names
feature_importance = lgbm_model.feature_importances_
feature_names = [f"ClE{i}" for i in range(1, 26)]  # 25 features

# Create a DataFrame for sorting
importance_df = pd.DataFrame({'Feature': feature_names, 'Importance': feature_importance})
importance_df = importance_df.sort_values(by='Importance', ascending=False).head(10)  # Top 10 features

# Plot feature importance
plt.figure(figsize=(10, 6))
sns.barplot(x='Importance', y='Feature', data=importance_df, palette='Blues_r')
plt.xlabel("Importance Score")
plt.ylabel("Feature")
plt.title("Top 10 Most Important Features")
plt.grid(axis='x', linestyle='--', alpha=0.6)

# Save and show plot
plt.savefig("trained_models/top_10_features.png", bbox_inches='tight')
plt.show()
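
Beyond the bundled plotting helpers, standard scikit-learn metrics give a quick per-class summary, which is what matters under class imbalance. A minimal addition, assuming scikit-learn is installed (a common companion to LightGBM); it reuses the y_test, y_pred, y_proba, and class_names variables from the script above:

from sklearn.metrics import classification_report, roc_auc_score

# Per-class precision, recall, and F1: the per-class view is the one
# that reveals how well minority classes are handled
print(classification_report(y_test, y_pred, target_names=class_names))

# Macro-averaged one-vs-rest AUC from the predicted probabilities
print("Macro OvR AUC:", roc_auc_score(y_test, y_proba, multi_class="ovr", average="macro"))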

Architecture

The iterative training pipeline, in Mermaid notation:

graph TD
    A[Input Data] --> B[Data Preprocessing]
    B --> C[CVAE Training]
    B --> D[CGAN Training]
    C --> E[Synthetic Data Generation]
    D --> E
    E --> F[Data Augmentation]
    F --> G[LightGBM Training]
    G --> H[Evaluation]
    H --> I[Model Persistence]

Key Components

  • Conditional VAE: Generates class-conditioned synthetic samples
  • Conditional GAN: Produces additional class-specific synthetic data
  • LightGBM Tuner: Optimized gradient boosting with automated hyperparameter search
  • Hybrid Trainer: Orchestrates the iterative training process (a minimal sketch of the augmentation step follows)
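
To make the augmentation step concrete, here is a minimal sketch of class-balanced oversampling. This is not the package's internal implementation: a per-class Gaussian sampler stands in for the trained CVAE/CGAN generators, and the helper simply tops every class up to the majority-class count.

import numpy as np

def augment_to_balance(X, y, sample_fn):
    """Top up every class to the majority-class count.

    sample_fn(class_label, n) must return n synthetic rows for that class;
    in vaganboostktf this role would be played by the trained CVAE/CGAN
    generators (a stand-in Gaussian sampler is used in the demo below).
    """
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    X_parts, y_parts = [X], [y]
    for cls, count in zip(classes, counts):
        deficit = int(target - count)
        if deficit > 0:
            X_parts.append(sample_fn(cls, deficit))
            y_parts.append(np.full(deficit, cls))
    return np.vstack(X_parts), np.concatenate(y_parts)

# Demo: a per-class Gaussian stands in for the trained generators
rng = np.random.default_rng(1)
X = rng.normal(size=(120, 25))
y = np.repeat([0, 1, 2, 3], [60, 30, 20, 10])  # imbalanced 4-class toy data

def gaussian_sampler(cls, n):
    # Hypothetical stand-in: sample around the class's empirical mean/std
    rows = X[y == cls]
    return rows.mean(axis=0) + rng.normal(size=(n, 25)) * rows.std(axis=0)

X_bal, y_bal = augment_to_balance(X, y, gaussian_sampler)
print(np.unique(y_bal, return_counts=True))  # all four classes now at 60 rows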

Configuration

Default parameters can be modified through:

  • Command-line arguments
  • JSON configuration files (see the example below)
  • Python API parameters
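
For example, a JSON file can mirror the config dict from the Usage section and be passed straight to the trainer. The file layout and the name config.json are assumptions based on the keys accepted by HybridModelTrainer, not a documented schema:

import json
from vaganboostktf.trainer import HybridModelTrainer

# config.json holds the same keys as the dict in the Usage section, e.g.
# {"num_classes": 4, "cvae_epochs": 10, "cgan_epochs": 10,
#  "lgbm_iterations": 10, "samples_per_class": 50, "model_dir": "trained_models", ...}
with open("config.json") as f:
    config = json.load(f)

trainer = HybridModelTrainer(config=config)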

License

This project is licensed under the MIT License - see the LICENSE file for details.

=============================================

ALI BAVARCHIEE

=============================================

GitHub: https://github.com/AliBavarchee/

LinkedIn: https://www.linkedin.com/in/ali-bavarchee-qip/
