Hybrid VAE-GAN with LightGBM for class-imbalanced classification
vaganboostktf: VAE-GAN Boost by TF
=============================================
A hybrid generative-classification framework that combines a Conditional VAE (CVAE), a Conditional GAN (CGAN), and LightGBM to handle class-imbalanced classification tasks.
Features
- 🧬 Hybrid architecture combining generative and discriminative models
- ⚖️ Effective handling of class imbalance through synthetic data generation
- 🔄 Iterative training process with automatic model refinement
- 📊 Comprehensive evaluation metrics and visualizations
- 💾 Model persistence and reproducibility features
- 🖥️ Command-line interface for easy operation
Installation
pip install vaganboostktf
For development installation:
git clone https://github.com/yourusername/vaganboostktf.git
cd vaganboostktf
pip install -e .
Usage
Basic Python API
import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from vaganboostktf.data_preprocessor import DataPreprocessor
from vaganboostktf.trainer import HybridModelTrainer
from vaganboostktf.utils import plot_confusion_matrix, plot_roc_curves, plot_pr_curves
# Load and prepare data
df = pd.read_csv("Input.csv")
# Identify features and target
feature_columns = [f"ClE{i}" for i in range(1, 26)] # 25 features
target_column = "label"
# Initialize data preprocessor
preprocessor = DataPreprocessor()
# Split and scale the data
X_train_scaled, X_test_scaled, y_train, y_test = preprocessor.prepare_data(
    df,
    feature_columns=feature_columns,
    target_column=target_column
)
# Initialize and train hybrid model
trainer = HybridModelTrainer(config={
'num_classes': 4,
'cvae_params': {
'input_dim': 25,
'latent_dim': 8,
'num_classes': 4,
'learning_rate': 0.01
},
'cgan_params': {
'input_dim': 25,
'latent_dim': 8,
'num_classes': 4,
'generator_lr': 0.0002,
'discriminator_lr': 0.0002
},
'model_dir': 'trained_models',
'cvae_epochs': 10,
'cgan_epochs': 10,
'lgbm_iterations': 10,
'samples_per_class': 50
})
trainer.training_loop(
X_train_scaled, y_train,
X_test_scaled, y_test,
iterations=5
)
print("Training completed! Best models saved in 'trained_models' directory")
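Conceptually, each iteration of the loop generates `samples_per_class` synthetic rows per class from the CVAE/CGAN and appends them to the real data before refitting LightGBM. A sketch of that augmentation step with NumPy (the `sample_fn(label, n)` signature is a hypothetical stand-in for the generators, not the package's API):

```python
import numpy as np

def augment_per_class(X, y, sample_fn, samples_per_class=50, num_classes=4):
    """Append samples_per_class synthetic rows per class to the real data.
    sample_fn(label, n) stands in for CVAE/CGAN sampling (hypothetical signature)."""
    X_parts, y_parts = [X], [y]
    for c in range(num_classes):
        X_parts.append(sample_fn(c, samples_per_class))
        y_parts.append(np.full(samples_per_class, c))
    return np.vstack(X_parts), np.concatenate(y_parts)

# Toy stand-in generator: Gaussian noise centred on the class index
rng = np.random.default_rng(0)
fake_gen = lambda c, n: rng.normal(loc=c, size=(n, 25))

X_real = rng.normal(size=(100, 25))
y_real = rng.integers(0, 4, size=100)
X_aug, y_aug = augment_per_class(X_real, y_real, fake_gen)
```

Because every class receives the same number of synthetic rows, minority classes gain proportionally more support, which is the mechanism that counteracts the imbalance.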
# Visualization of the metrics and results
# Load trained LightGBM model
lgbm_model = joblib.load("trained_models/lgbm_model.pkl")
# Get predictions
y_pred = lgbm_model.predict(X_test_scaled)
y_proba = lgbm_model.predict_proba(X_test_scaled)
# Define class names
class_names = [str(i) for i in np.unique(y_test)]
# Plot and save confusion matrix
conf_matrix_fig = plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True)
conf_matrix_fig.savefig("trained_models/confusion_matrix.png")
# Plot and save ROC curves
roc_curve_fig = plot_roc_curves(y_test, y_proba, classes=class_names)
roc_curve_fig.savefig("trained_models/roc_curves.png")
# Plot and save Precision-Recall curves
pr_curve_fig = plot_pr_curves(y_test, y_proba, classes=class_names)
pr_curve_fig.savefig("trained_models/pr_curves.png")
# Extract feature importances and corresponding feature names
feature_importance = lgbm_model.feature_importances_
feature_names = [f"ClE{i}" for i in range(1, 26)] # 25 features
# Create a DataFrame for sorting
importance_df = pd.DataFrame({'Feature': feature_names, 'Importance': feature_importance})
importance_df = importance_df.sort_values(by='Importance', ascending=False).head(10) # Top 10 features
# Plot feature importance
plt.figure(figsize=(10, 6))
sns.barplot(x='Importance', y='Feature', data=importance_df, palette='Blues_r')
plt.xlabel("Importance Score")
plt.ylabel("Feature")
plt.title("Top 10 Most Important Features")
plt.grid(axis='x', linestyle='--', alpha=0.6)
# Save and show plot
plt.savefig("trained_models/top_10_features.png", bbox_inches='tight')
plt.show()
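The plotting utilities above wrap standard one-vs-rest curves. If you only need the scalar metrics, a scikit-learn sketch on toy data (macro averaging is an assumption; the package may aggregate differently) is:

```python
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, average_precision_score

# Toy 4-class scores; rows of y_proba sum to 1
rng = np.random.default_rng(0)
y_true = rng.integers(0, 4, size=200)
logits = rng.normal(size=(200, 4))
logits[np.arange(200), y_true] += 2.0  # make the scores informative
y_proba = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# One-vs-rest targets, then macro-averaged ROC AUC and average precision
Y = label_binarize(y_true, classes=[0, 1, 2, 3])
macro_auc = roc_auc_score(Y, y_proba, average="macro")
macro_ap = average_precision_score(Y, y_proba, average="macro")
```

Macro averaging weights each class equally, which is usually the right choice under class imbalance; micro averaging would let the majority class dominate the score.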
Architecture
graph TD
A[Input Data] --> B[Data Preprocessing]
B --> C[CVAE Training]
B --> D[CGAN Training]
C --> E[Synthetic Data Generation]
D --> E
E --> F[Data Augmentation]
F --> G[LightGBM Training]
G --> H[Evaluation]
H --> I[Model Persistence]
Key Components
- Conditional VAE: Generates class-conditioned synthetic samples
- Conditional GAN: Produces additional class-specific synthetic data
- LightGBM Tuner: Optimized gradient boosting with automated hyperparameter search
- Hybrid Trainer: Orchestrates iterative training process
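A common way CVAE/CGAN-style models implement class conditioning is to append a one-hot label vector to each input or latent row; whether this package uses exactly this mechanism is an assumption, but the idea can be sketched as:

```python
import numpy as np

def condition_on_label(x, labels, num_classes=4):
    """Append a one-hot class label to each row (a common conditioning
    convention for CVAEs/CGANs, not necessarily this package's exact one)."""
    one_hot = np.eye(num_classes)[labels]
    return np.concatenate([x, one_hot], axis=1)

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 8))             # latent_dim = 8, as in the config above
labels = rng.integers(0, 4, size=8)
z_cond = condition_on_label(z, labels)  # latent + one-hot -> width 12
```

Feeding the conditioned vector to the generator lets a single network produce samples for any requested class.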
Configuration
Default parameters can be modified through:
- Command-line arguments
- JSON configuration files
- Python API parameters
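A JSON configuration file would mirror the keys passed to `HybridModelTrainer` above; the exact schema the CLI accepts is an assumption, but a plausible file and how to load it:

```python
import json

# Hypothetical JSON config mirroring the HybridModelTrainer keys shown earlier
config_text = """
{
  "num_classes": 4,
  "cvae_params": {"input_dim": 25, "latent_dim": 8, "num_classes": 4,
                  "learning_rate": 0.01},
  "cgan_params": {"input_dim": 25, "latent_dim": 8, "num_classes": 4,
                  "generator_lr": 0.0002, "discriminator_lr": 0.0002},
  "model_dir": "trained_models",
  "cvae_epochs": 10,
  "cgan_epochs": 10,
  "lgbm_iterations": 10,
  "samples_per_class": 50
}
"""
config = json.loads(config_text)
```

The parsed dictionary can then be passed straight to `HybridModelTrainer(config=config)`.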
License
This project is licensed under the MIT License - see the LICENSE file for details.
Repository: https://github.com/AliBavarchee/
Download files
Source Distribution: vaganboostktf-0.9.1.tar.gz
Built Distribution: vaganboostktf-0.9.1-py3-none-any.whl
File details
Details for the file vaganboostktf-0.9.1.tar.gz.
File metadata
- Download URL: vaganboostktf-0.9.1.tar.gz
- Upload date:
- Size: 18.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `a59a0352af1074d37954724ec834b3e6121d3e0b3139e81333df7bdcf703d30c` |
| MD5 | `402dfafa24b22ab67e2561bf20b27669` |
| BLAKE2b-256 | `610d0332e9fcd712113a41fe6cc2b7bf1bf340a858caf0c58ce8c81ab8f03b9a` |
File details
Details for the file vaganboostktf-0.9.1-py3-none-any.whl.
File metadata
- Download URL: vaganboostktf-0.9.1-py3-none-any.whl
- Upload date:
- Size: 19.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `6ad99ba11e87a3a552044030601129130d4a037102d95478d941164a2723d351` |
| MD5 | `373863284df3f36a5fdc28182bebc13d` |
| BLAKE2b-256 | `cf47e7ac11c7a6478383b6b35b6c1cd1e92b351240b7677a1e1a1d2834f2fef1` |