Hybrid VAE-GAN with LightGBM for class-imbalanced classification
Project description
vaganboostktf: VAE-GAN Boost with TensorFlow
=============================================
VaganBoostKTF
VaganBoostKTF is a hybrid machine learning package that integrates generative modeling (CVAE and CGAN) with an advanced LightGBM classifier pipeline. It provides robust data preprocessing, custom sampling strategies for imbalanced data, automated hyperparameter tuning (including dimensionality reduction via PCA, LDA, or TruncatedSVD), and built-in visualization of model evaluation metrics (confusion matrices, ROC curves, and precision-recall curves).
Features
- 🧬 Hybrid architecture combining generative and discriminative models
- ⚖️ Effective handling of class imbalance through synthetic data generation
- 🔄 Iterative training process with automatic model refinement
- 📊 Comprehensive evaluation metrics and visualizations
- 💾 Model persistence and reproducibility features
- 🖥️ Command-line interface for easy operation
Installation

Install the package from PyPI:

```bash
pip install vaganboostktf
```

Install the required dependencies:

```bash
pip install dill
pip install dask[dataframe]
pip install umap-learn
```

Additional dependencies (if not already installed) include:
- scikit-learn
- imbalanced-learn
- lightgbm
- tensorflow
- seaborn
- matplotlib
- joblib
For development installation:

```bash
git clone https://github.com/AliBavarchee/vaganboostktf.git
cd vaganboostktf
pip install -e .
```
Modules
- data_preprocessor.py: Provides consistent data preprocessing (scaling, handling missing values, and encoding).
- trainer.py: Orchestrates the hybrid training workflow combining generative models (CVAE, CGAN) and the LightGBM classifier.
- lgbm_tuner.py: Implements hyperparameter tuning for the advanced LightGBM pipeline.
- lgbm_classifier.py: Contains the full LightGBM classifier pipeline that integrates preprocessing, feature selection, dimensionality reduction, SMOTE balancing (with custom sampling strategies), and hyperparameter tuning.
- utils.py: Provides utility functions for visualization (confusion matrix, ROC curves, precision-recall curves) and helper classes such as DecompositionSwitcher.
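The general pattern behind a helper like DecompositionSwitcher is a pipeline step whose wrapped estimator can be swapped during hyperparameter search. The sketch below is illustrative only and is not the package's exact implementation; the class name is reused purely to show the idea.

```python
# Illustrative sketch: a pipeline step whose decomposition estimator can be
# swapped by GridSearchCV (e.g. PCA vs. TruncatedSVD). Not the package's
# actual code -- it just demonstrates the pattern.
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


class DecompositionSwitcher(BaseEstimator, TransformerMixin):
    """Wraps an arbitrary decomposition estimator so a grid search can swap it."""

    def __init__(self, estimator=None):
        self.estimator = estimator

    def fit(self, X, y=None):
        # Clone so repeated CV fits do not share state; default to PCA.
        self.estimator_ = clone(self.estimator) if self.estimator is not None else PCA()
        self.estimator_.fit(X, y)
        return self

    def transform(self, X):
        return self.estimator_.transform(X)


X, y = make_classification(n_samples=200, n_features=25, random_state=0)

pipe = Pipeline([
    ("decomp", DecompositionSwitcher()),
    ("clf", LogisticRegression(max_iter=500)),
])

# The search swaps entire decomposition estimators via the `estimator` param.
grid = GridSearchCV(
    pipe,
    {"decomp__estimator": [PCA(n_components=10), TruncatedSVD(n_components=10)]},
    cv=3,
)
grid.fit(X, y)
print(type(grid.best_params_["decomp__estimator"]).__name__)
```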
Usage Example
Below is a sample script demonstrating how to use VaganBoostKTF:

```python
import pandas as pd
import numpy as np

from vaganboostktf.data_preprocessor import DataPreprocessor
from vaganboostktf.trainer import HybridModelTrainer
from vaganboostktf.lgbm_tuner import LightGBMTuner
from vaganboostktf.utils import plot_confusion_matrix, plot_roc_curves, plot_pr_curves

# ===========================
# 1. Load and Prepare Data
# ===========================
df = pd.read_csv("input.csv")

# Identify features and target
feature_columns = [col for col in df.columns if col != "label"]
target_column = "label"

# Initialize data preprocessor
preprocessor = DataPreprocessor()

# Preprocess data (handling missing values, scaling, encoding)
X_train_scaled, X_test_scaled, y_train, y_test = preprocessor.prepare_data(
    df, feature_columns, target_column
)

# ===========================
# 2. Train Hybrid Model (CVAE, CGAN + LGBM)
# ===========================
trainer = HybridModelTrainer(config={
    'num_classes': 4,
    'cvae_params': {
        'input_dim': 25,
        'latent_dim': 10,
        'num_classes': 4,
        'learning_rate': 0.01
    },
    'cgan_params': {
        'input_dim': 25,
        'latent_dim': 10,
        'num_classes': 4,
        'generator_lr': 0.0002,
        'discriminator_lr': 0.0002
    },
    'input_path': 'input.csv',
    'model_dir': 'trained_models',
    'cvae_epochs': 100,
    'cgan_epochs': 100,
    'lgbm_iterations': 100,
    'samples_per_class': 50
})

# Run hybrid training (generative models + LGBM)
trainer.training_loop(X_train_scaled, y_train, X_test_scaled, y_test, iterations=5)

print("\nHybrid training completed! Models saved in 'trained_models/'")

# ===========================
# 3. Load and Evaluate LightGBM Model
# ===========================
lgbm_tuner = LightGBMTuner(input_path="input.csv", output_path="trained_models")

# Train the LightGBM model (already tuned within `lgbm_classifier`)
lgbm_tuner.tune()

# Predict on test data
y_pred = lgbm_tuner.predict(X_test_scaled)
y_proba = lgbm_tuner.predict_proba(X_test_scaled)

# ===========================
# 4. Visualize Results
# ===========================
class_names = [str(i) for i in np.unique(y_test)]

# Plot confusion matrix
conf_matrix_fig = plot_confusion_matrix(y_test, y_pred, class_names, normalize=True)
conf_matrix_fig.savefig("trained_models/confusion_matrix.png")

# Plot ROC curves
roc_curve_fig = plot_roc_curves(y_test, y_proba, class_names)
roc_curve_fig.savefig("trained_models/roc_curve.png")

# Plot precision-recall curves
pr_curve_fig = plot_pr_curves(y_test, y_proba, class_names)
pr_curve_fig.savefig("trained_models/pr_curve.png")

print("\nEvaluation completed! Check 'trained_models/' for plots.")
```
Architecture
```mermaid
graph TD
    A["Raw Data (CSV)"] --> B["DataPreprocessor"]
    B --> C["Preprocessed Data"]
    C --> D["CVAE"]
    C --> E["CGAN"]
    D --> F["Synthetic Data (CVAE)"]
    E --> G["Synthetic Data (CGAN)"]
    F --> H["Combined Real & Synthetic Data"]
    G --> H
    H --> I["LightGBM Classifier Pipeline"]
    I --> J["Evaluation (Confusion Matrix, ROC, PR Curves)"]
    J --> K["Best Models Saved"]
```
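The data-combination step in this flow can be sketched minimally: class-conditioned synthetic samples are stacked with the real training data before the LightGBM stage. In the sketch below a seeded random generator stands in for the trained CVAE/CGAN (whose outputs would be decoded from latent vectors); shapes and `samples_per_class` mirror the usage example above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Real data: class 3 is severely under-represented (stand-in for a real dataset).
X_real = rng.normal(size=(300, 25))
y_real = rng.choice(4, size=300, p=[0.5, 0.3, 0.15, 0.05])

# Stand-in for the trained CVAE/CGAN generators: in the package these would be
# class-conditioned samples produced by the generative models.
def fake_generator(n, n_features=25):
    return rng.normal(size=(n, n_features))

samples_per_class = 50
X_synth = np.vstack([fake_generator(samples_per_class) for _ in range(4)])
y_synth = np.repeat(np.arange(4), samples_per_class)

# Combined training set fed to the LightGBM classifier pipeline
X_combined = np.vstack([X_real, X_synth])
y_combined = np.concatenate([y_real, y_synth])

# Every class now has at least `samples_per_class` examples.
print(dict(zip(*np.unique(y_combined, return_counts=True))))
```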
Key Components
- Conditional VAE: Generates class-conditioned synthetic samples
- Conditional GAN: Produces additional class-specific synthetic data
- LightGBM Tuner: Optimized gradient boosting with automated hyperparameter search
- Hybrid Trainer: Orchestrates iterative training process
Additional Information
- Hybrid Workflow: The training loop in trainer.py first trains the generative models (CVAE and CGAN) to create synthetic data, which is then combined with real data to train a robust LightGBM classifier.
- Custom Sampling Strategies: lgbm_classifier.py integrates a function that generates sampling strategies for SMOTE to address severe class imbalance.
- Visualization: Evaluation plots are generated and saved in the output directory to help assess model performance.
Configuration
Default parameters can be modified through:
- Command-line arguments
- JSON configuration files
- Python API parameters
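For instance, a JSON configuration mirroring the Python config dict in the usage example might look like the following (field names are taken from that example; the exact schema the CLI accepts is defined by the package):

```json
{
  "num_classes": 4,
  "cvae_params": {
    "input_dim": 25,
    "latent_dim": 10,
    "num_classes": 4,
    "learning_rate": 0.01
  },
  "cgan_params": {
    "input_dim": 25,
    "latent_dim": 10,
    "num_classes": 4,
    "generator_lr": 0.0002,
    "discriminator_lr": 0.0002
  },
  "input_path": "input.csv",
  "model_dir": "trained_models",
  "cvae_epochs": 100,
  "cgan_epochs": 100,
  "lgbm_iterations": 100,
  "samples_per_class": 50
}
```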
License
This project is licensed under the MIT License - see the LICENSE file for details.
=============================================
https://github.com/AliBavarchee/
Project details
Download files
Source Distribution
Built Distribution
File details
Details for the file vaganboostktf-1.1.0.tar.gz.
File metadata
- Download URL: vaganboostktf-1.1.0.tar.gz
- Upload date:
- Size: 21.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a79ea4f34c20822ec7d0330c07099bb5bc8f27d09368992057482688721de76f |
| MD5 | d128b4e606e6236d42de4e7e01ccd23f |
| BLAKE2b-256 | 8c1cd7b52c59ab7b092c8d335167ddff7615cb0b7cfc6a9755c435273f9629ad |
File details
Details for the file vaganboostktf-1.1.0-py3-none-any.whl.
File metadata
- Download URL: vaganboostktf-1.1.0-py3-none-any.whl
- Upload date:
- Size: 22.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 09c67a39a4cdc336dd57027b6e5449e411a8f32a3b410367e63112e3e7da42dd |
| MD5 | 863428c76f8b236ccbd3028c890fb435 |
| BLAKE2b-256 | f1622faba26c01c56836d34c684a658137a8194126a20d996495bed022b8a258 |