
Hybrid VAE-GAN with LightGBM for class-imbalanced classification


vaganboostktf: VAE-GAN Boost by TF

=============================================

vaganboostktf

=============================================

License: MIT

VaganBoostKTF

VaganBoostKTF is a hybrid machine learning package that combines generative modeling, using a conditional variational autoencoder (CVAE) and a conditional generative adversarial network (CGAN), with a LightGBM classification pipeline. It provides data preprocessing, custom sampling strategies for imbalanced data, automated hyperparameter tuning (including dimensionality reduction via PCA, LDA, or TruncatedSVD), and built-in plots of evaluation metrics (confusion matrices, ROC curves, and precision-recall curves).

Features

  • 🧬 Hybrid architecture combining generative and discriminative models
  • ⚖️ Effective handling of class imbalance through synthetic data generation
  • 🔄 Iterative training process with automatic model refinement
  • 📊 Comprehensive evaluation metrics and visualizations
  • 💾 Model persistence and reproducibility features
  • 🖥️ Command-line interface for easy operation
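The iterative training with automatic refinement can be pictured as a loop that retrains for a fixed number of rounds and keeps whichever round scores best on held-out data. The sketch below is a minimal illustration under that assumption: `keep_best` and the `train_round` callable are hypothetical names, and the package's own training loop may select or save models differently.

```python
from typing import Any, Callable, Tuple


def keep_best(
    train_round: Callable[[int], Tuple[Any, float]], iterations: int = 5
) -> Tuple[Any, float, int]:
    """Run `train_round` once per iteration and keep the highest-scoring result.

    `train_round` is a hypothetical callable: it takes the iteration index and
    returns (model, validation_score), where a higher score is better.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    best_model, best_score, best_iter = None, float("-inf"), -1
    for i in range(iterations):
        model, score = train_round(i)
        # Replace the stored model only when this round strictly improves on it.
        if score > best_score:
            best_model, best_score, best_iter = model, score, i
    return best_model, best_score, best_iter


# Usage with a toy stand-in for one training round:
scores = [0.71, 0.78, 0.74, 0.80, 0.79]
model, score, it = keep_best(lambda i: (f"model_{i}", scores[i]), iterations=5)
print(model, score, it)  # model_3 0.8 3
```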

Installation

Install the required dependencies:

pip install dill
pip install "dask[dataframe]"
pip install umap-learn

Additional dependencies (if not already installed) include:

  • scikit-learn
  • imbalanced-learn
  • lightgbm
  • tensorflow
  • seaborn
  • matplotlib
  • joblib

Then install the package itself:

pip install vaganboostktf

For development installation:

git clone https://github.com/yourusername/vaganboostktf.git
cd vaganboostktf
pip install -e .
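After installing, a quick way to confirm that every dependency is importable is to look each one up with the standard-library `importlib.util.find_spec`, which locates a top-level module without importing it. This is a small sketch; the import names are assumptions where they differ from the pip names (scikit-learn is `sklearn`, imbalanced-learn is `imblearn`, umap-learn is `umap`).

```python
import importlib.util

# Import names (not pip names) for the dependencies listed above.
# The mapping from pip names to import names is an assumption for some entries.
REQUIRED_MODULES = [
    "vaganboostktf", "dill", "dask", "umap", "sklearn", "imblearn",
    "lightgbm", "tensorflow", "seaborn", "matplotlib", "joblib",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found on the current path."""
    # find_spec on a top-level name only locates the module, so heavy
    # packages such as tensorflow are not actually loaded by this check.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dependencies found.")
```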

Modules

  • data_preprocessor.py: Provides consistent data preprocessing (scaling, handling missing values, and encoding).
  • trainer.py: Orchestrates the hybrid training workflow combining generative models (CVAE, CGAN) and the LightGBM classifier.
  • lgbm_tuner.py: Implements hyperparameter tuning for the advanced LightGBM pipeline.
  • lgbm_classifier.py: Contains the full LightGBM classifier pipeline that integrates preprocessing, feature selection, dimensionality reduction, SMOTE balancing (with custom sampling strategies), and hyperparameter tuning.
  • utils.py: Provides utility functions for visualization (confusion matrix, ROC curves, precision-recall curves) and helper classes like DecompositionSwitcher.

Usage Example

Below is a sample script demonstrating how to use VaganBoostKTF:

import pandas as pd
import numpy as np
from vaganboostktf.data_preprocessor import DataPreprocessor
from vaganboostktf.trainer import HybridModelTrainer
from vaganboostktf.lgbm_tuner import LightGBMTuner
from vaganboostktf.utils import plot_confusion_matrix, plot_roc_curves, plot_pr_curves

# ===========================
# 1. Load and Prepare Data
# ===========================
df = pd.read_csv("input.csv")

# Identify features and target
feature_columns = [col for col in df.columns if col != "label"]
target_column = "label"

# Initialize data preprocessor
preprocessor = DataPreprocessor()

# Preprocess data (handling missing values, scaling, encoding)
X_train_scaled, X_test_scaled, y_train, y_test = preprocessor.prepare_data(
    df, feature_columns, target_column
)

# ===========================
# 2. Train Hybrid Model (CVAE, CGAN + LGBM)
# ===========================
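# Derive dimensions from the data instead of hard-coding them: input_dim must
# match the number of preprocessed feature columns, and num_classes the number
# of distinct labels.
n_features = X_train_scaled.shape[1]
n_classes = len(np.unique(y_train))

trainer = HybridModelTrainer(config={
    'num_classes': n_classes,
    'cvae_params': {
        'input_dim': n_features,
        'latent_dim': 10,
        'num_classes': n_classes,
        'learning_rate': 0.01
    },
    'cgan_params': {
        'input_dim': n_features,
        'latent_dim': 10,
        'num_classes': n_classes,
        'generator_lr': 0.0002,
        'discriminator_lr': 0.0002
    },
    'input_path': 'input.csv',
    'model_dir': 'trained_models',
    'cvae_epochs': 100,
    'cgan_epochs': 100,
    'lgbm_iterations': 100,
    'samples_per_class': 50
})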

# Run hybrid training (Generative + LGBM)
trainer.training_loop(X_train_scaled, y_train, X_test_scaled, y_test, iterations=5)
print("\nHybrid training completed! Models saved in 'trained_models/'")

# ===========================
# 3. Load and Evaluate LightGBM Model
# ===========================
lgbm_tuner = LightGBMTuner(input_path="input.csv", output_path="trained_models")

# Train the LightGBM model; hyperparameter tuning is handled inside `lgbm_classifier`
lgbm_tuner.tune()

# Predict on test data
y_pred = lgbm_tuner.predict(X_test_scaled)
y_proba = lgbm_tuner.predict_proba(X_test_scaled)

# ===========================
# 4. Visualize Results
# ===========================
class_names = [str(i) for i in np.unique(y_test)]

# Plot Confusion Matrix
conf_matrix_fig = plot_confusion_matrix(y_test, y_pred, class_names, normalize=True)
conf_matrix_fig.savefig("trained_models/confusion_matrix.png")

# Plot ROC Curves
roc_curve_fig = plot_roc_curves(y_test, y_proba, class_names)
roc_curve_fig.savefig("trained_models/roc_curve.png")

# Plot Precision-Recall Curves
pr_curve_fig = plot_pr_curves(y_test, y_proba, class_names)
pr_curve_fig.savefig("trained_models/pr_curve.png")

print("\nEvaluation completed! Check 'trained_models/' for plots.")
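With `normalize=True`, a confusion matrix is usually normalized by row, dividing each row by the number of true samples in that class so the diagonal reads as per-class recall. Assuming `plot_confusion_matrix` follows that convention (not confirmed by the source), the stdlib sketch below computes the same matrix directly, which can help when checking the plotted values. `normalized_confusion_matrix` is a hypothetical helper, not part of the package.

```python
from typing import List, Sequence


def normalized_confusion_matrix(
    y_true: Sequence, y_pred: Sequence, labels: Sequence
) -> List[List[float]]:
    """Row-normalized confusion matrix: entry [i][j] is the fraction of samples
    whose true label is labels[i] that were predicted as labels[j]."""
    index = {label: k for k, label in enumerate(labels)}
    counts = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        counts[index[t]][index[p]] += 1
    matrix = []
    for row in counts:
        total = sum(row)
        # A class with no true samples yields an all-zero row instead of dividing by zero.
        matrix.append([c / total if total else 0.0 for c in row])
    return matrix


y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 0]
cm = normalized_confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
# Diagonal = per-class recall: 2/3 for class 0, 1.0 for class 1, 0.0 for class 2.
```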

Architecture

graph TD
    A["Raw Data (CSV)"] --> B["DataPreprocessor"]
    B --> C["Preprocessed Data"]
    C --> D["CVAE"]
    C --> E["CGAN"]
    D --> F["Synthetic Data (CVAE)"]
    E --> G["Synthetic Data (CGAN)"]
    F --> H["Combined Real & Synthetic Data"]
    G --> H
    H --> I["LightGBM Classifier Pipeline"]
    I --> J["Evaluation (Confusion Matrix, ROC, PR Curves)"]
    J --> K["Best Models Saved"]
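The "Combined Real & Synthetic Data" step in the diagram can be sketched as follows. This is a minimal stdlib illustration, assuming each trained generator is wrapped as a callable that returns `n` feature rows for a requested class; `combine_real_and_synthetic` and that generator signature are hypothetical, and `samples_per_class` mirrors the config key of the same name in the usage example.

```python
import random
from typing import Callable, List, Sequence, Tuple

Row = List[float]


def combine_real_and_synthetic(
    X: Sequence[Row],
    y: Sequence[int],
    generators: Sequence[Callable[[int, int], List[Row]]],
    samples_per_class: int,
    seed: int = 0,
) -> Tuple[List[Row], List[int]]:
    """Append `samples_per_class` synthetic rows per class from each generator,
    then shuffle real and synthetic rows together.

    Each generator is a hypothetical callable (class_label, n) -> n feature rows,
    standing in for sampling from a trained CVAE or CGAN.
    """
    X_all, y_all = [list(r) for r in X], list(y)
    for label in sorted(set(y)):
        for generate in generators:
            rows = generate(label, samples_per_class)
            X_all.extend(rows)
            y_all.extend([label] * len(rows))
    # Shuffle features and labels with the same permutation so pairs stay aligned.
    order = list(range(len(y_all)))
    random.Random(seed).shuffle(order)
    return [X_all[i] for i in order], [y_all[i] for i in order]


# Toy generators that emit constant rows tagged with the class label.
cvae_like = lambda label, n: [[float(label), 0.0] for _ in range(n)]
cgan_like = lambda label, n: [[float(label), 1.0] for _ in range(n)]
X_comb, y_comb = combine_real_and_synthetic(
    [[0.1, 0.2], [1.1, 1.2]], [0, 1], [cvae_like, cgan_like], samples_per_class=50
)
print(len(y_comb))  # 2 real + 2 classes * 2 generators * 50 = 202
```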

Key Components

  • Conditional VAE: Generates class-conditioned synthetic samples
  • Conditional GAN: Produces additional class-specific synthetic data
  • LightGBM Tuner: Optimized gradient boosting with automated hyperparameter search
  • Hybrid Trainer: Orchestrates iterative training process
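Two ideas set a conditional VAE apart from a plain autoencoder: the encoder and decoder see the class label, and the latent code is sampled with the reparameterization trick, z = mu + sigma * eps. The plain-Python sketch below shows both, assuming conditioning by concatenating a one-hot label, which is the common formulation but is not confirmed from the package source; a CGAN generator is typically conditioned the same way, on noise concatenated with the one-hot label. All function names are illustrative.

```python
import math
import random
from typing import List, Sequence


def one_hot(label: int, num_classes: int) -> List[float]:
    """Encode a class label as a one-hot vector."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def condition(x: Sequence[float], label: int, num_classes: int) -> List[float]:
    """Concatenate features with the one-hot label (assumed conditioning scheme)."""
    return list(x) + one_hot(label, num_classes)


def reparameterize(
    mu: Sequence[float], log_var: Sequence[float], rng: random.Random
) -> List[float]:
    """Sample z = mu + sigma * eps with sigma = exp(0.5 * log_var), eps ~ N(0, 1).

    In an autodiff framework this form lets gradients flow to mu and log_var,
    because the randomness is isolated in eps.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


rng = random.Random(42)
encoder_input = condition([0.3, -1.2, 0.8], label=2, num_classes=4)
# encoder_input == [0.3, -1.2, 0.8, 0.0, 0.0, 1.0, 0.0]
z = reparameterize(mu=[0.0, 1.0], log_var=[-20.0, -20.0], rng=rng)
# With a very small variance, z stays very close to mu.
```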

Additional Information

  • Hybrid Workflow: The training loop in trainer.py first trains generative models (CVAE and CGAN) to create synthetic data, which is then combined with real data to train a robust LightGBM classifier.
  • Custom Sampling Strategies: lgbm_classifier.py integrates a function to generate sampling strategies for SMOTE to address severe class imbalance.
  • Visualization: Evaluation plots are generated and saved in the output directory to help assess model performance.
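The sampling-strategy function in `lgbm_classifier.py` is not shown here, so the following only illustrates the general idea: compute per-class target counts from the label distribution and hand them to SMOTE as a dict mapping each class to its desired sample count. The `ratio` rule and the name `minority_targets` are hypothetical.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable


def minority_targets(y: Iterable[Hashable], ratio: float = 0.5) -> Dict[Hashable, int]:
    """Build a dict for an over-sampler's `sampling_strategy`.

    Every class smaller than `ratio` * (majority class count) is raised to that
    target; classes already at or above it are left out of the dict.
    """
    if not 0 < ratio <= 1:
        raise ValueError("ratio must be in (0, 1]")
    counts = Counter(y)
    majority = max(counts.values())
    target = int(round(ratio * majority))
    # Only list classes that need new samples: an over-sampler rejects
    # targets below a class's current size.
    return {cls: target for cls, n in counts.items() if n < target}


y_example = [0] * 900 + [1] * 80 + [2] * 20
print(minority_targets(y_example))  # {1: 450, 2: 450}
```

With imbalanced-learn installed, the result could be passed as `SMOTE(sampling_strategy=minority_targets(y_train))`. Very small classes may also need a smaller `k_neighbors` than SMOTE's default.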

Configuration

Default parameters can be modified through:

  • Command-line arguments
  • JSON configuration files
  • Python API parameters
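The package's actual command-line flags and JSON schema are not documented here. As a hedged sketch of how these three sources commonly combine, the code below layers defaults, an optional JSON file, and command-line overrides using only `json` and `argparse`. The keys are borrowed from the usage example's config; `load_config`, `DEFAULTS`, and flag names such as `--cvae-epochs` are hypothetical.

```python
import argparse
import json
from typing import Any, Dict, Optional, Sequence

# Hypothetical defaults mirroring keys from the usage example's config dict.
DEFAULTS: Dict[str, Any] = {
    "model_dir": "trained_models",
    "cvae_epochs": 100,
    "cgan_epochs": 100,
    "lgbm_iterations": 100,
    "samples_per_class": 50,
}


def load_config(
    json_path: Optional[str] = None, argv: Optional[Sequence[str]] = None
) -> Dict[str, Any]:
    """Merge settings with increasing precedence: defaults < JSON file < command line."""
    config = dict(DEFAULTS)
    if json_path:
        with open(json_path, encoding="utf-8") as fh:
            config.update(json.load(fh))
    parser = argparse.ArgumentParser()
    for key, value in DEFAULTS.items():
        # default=None marks "not given", so only explicitly passed flags override.
        parser.add_argument(
            f"--{key.replace('_', '-')}", dest=key, type=type(value), default=None
        )
    args = parser.parse_args(argv)
    config.update({k: v for k, v in vars(args).items() if v is not None})
    return config


# Example: JSON sets cvae_epochs, the command line overrides samples_per_class.
# load_config("config.json", ["--samples-per-class", "120"])
```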

License

This project is licensed under the MIT License - see the LICENSE file for details.

=============================================

ALI BAVARCHIEE

=============================================

GitHub: https://github.com/AliBavarchee/

LinkedIn: https://www.linkedin.com/in/ali-bavarchee-qip/



Download files


Source Distribution

vaganboostktf-1.1.0.tar.gz (21.3 kB)

Uploaded Source

Built Distribution


vaganboostktf-1.1.0-py3-none-any.whl (22.7 kB)

Uploaded Python 3

File details

Details for the file vaganboostktf-1.1.0.tar.gz.

File metadata

  • Download URL: vaganboostktf-1.1.0.tar.gz
  • Upload date:
  • Size: 21.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.7

File hashes

Hashes for vaganboostktf-1.1.0.tar.gz
Algorithm Hash digest
SHA256 a79ea4f34c20822ec7d0330c07099bb5bc8f27d09368992057482688721de76f
MD5 d128b4e606e6236d42de4e7e01ccd23f
BLAKE2b-256 8c1cd7b52c59ab7b092c8d335167ddff7615cb0b7cfc6a9755c435273f9629ad


File details

Details for the file vaganboostktf-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: vaganboostktf-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 22.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.7

File hashes

Hashes for vaganboostktf-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 09c67a39a4cdc336dd57027b6e5449e411a8f32a3b410367e63112e3e7da42dd
MD5 863428c76f8b236ccbd3028c890fb435
BLAKE2b-256 f1622faba26c01c56836d34c684a658137a8194126a20d996495bed022b8a258
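To check a downloaded file against the SHA256 digests listed above, the standard-library `hashlib` module is enough. The sketch below assumes the file was saved under its published name; pip can also enforce digests itself through `--hash=sha256:...` entries in a requirements file combined with `--require-hashes`.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_sha256: str) -> bool:
    """Compare a file's digest with a published value (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Example, assuming the wheel has been downloaded to the current directory:
# verify("vaganboostktf-1.1.0-py3-none-any.whl",
#        "09c67a39a4cdc336dd57027b6e5449e411a8f32a3b410367e63112e3e7da42dd")
```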

