Hybrid VAE-GAN with LightGBM for class-imbalanced regression tasks
striXhooT: Hybrid Generative Model for Regression Tasks
=============================================
striXhooT
striXhooT is a hybrid machine learning package that combines conditional generative models (CVAE and CGAN) with LightGBM: the generative models produce synthetic training samples, and a LightGBM regressor is trained on the augmented data. The package is designed for structured (tabular) datasets where adding synthetic samples can improve regression performance.
Features
- Train Conditional Variational Autoencoder (CVAE) and Conditional Generative Adversarial Network (CGAN) models.
- Augment the training data with samples from the generative models to improve regression performance.
- Train a LightGBM regressor, with optional hyperparameter tuning.
- Apply PCA or SVD for dimensionality reduction.
- Orchestrate model training and evaluation end to end.
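The augmentation idea in the feature list can be illustrated without any deep-learning dependency. The sketch below is not striXhooT's implementation: it bins the continuous target into equal-width bins, finds sparsely populated bins, and tops them up with synthetic rows. `fake_conditional_sampler` (Gaussian jitter around real rows from the same bin) is a hypothetical placeholder for a trained CVAE decoder or CGAN generator, and `bins` / `target_per_bin` are illustrative parameters.

```python
import random
from collections import defaultdict


def bin_index(y, y_min, y_max, bins):
    """Map a target value to one of `bins` equal-width bins over [y_min, y_max]."""
    if y_max == y_min:
        return 0
    idx = int((y - y_min) / (y_max - y_min) * bins)
    return min(idx, bins - 1)  # y == y_max would otherwise fall into bin `bins`


def fake_conditional_sampler(rows, n, rng, noise=0.05):
    """HYPOTHETICAL stand-in for a trained CVAE/CGAN: jitter real rows from one bin."""
    samples = []
    for _ in range(n):
        x, y = rng.choice(rows)
        x_new = [v + rng.gauss(0.0, noise) for v in x]
        samples.append((x_new, y))  # the conditioning target is kept unchanged
    return samples


def augment(X, y, bins=5, target_per_bin=None, seed=42):
    """Return copies of X, y extended with synthetic rows for sparse target bins."""
    rng = random.Random(seed)
    y_min, y_max = min(y), max(y)
    groups = defaultdict(list)
    for x_row, y_val in zip(X, y):
        groups[bin_index(y_val, y_min, y_max, bins)].append((x_row, y_val))
    if target_per_bin is None:
        # By default, top every populated bin up to the size of the largest one.
        target_per_bin = max(len(rows) for rows in groups.values())
    X_aug, y_aug = [list(row) for row in X], list(y)
    # Empty bins are skipped: this stand-in can only imitate rows it has seen.
    for rows in groups.values():
        deficit = target_per_bin - len(rows)
        for x_new, y_new in fake_conditional_sampler(rows, max(deficit, 0), rng):
            X_aug.append(x_new)
            y_aug.append(y_new)
    return X_aug, y_aug


X = [[0.1, 1.0], [0.2, 1.1], [0.15, 0.9], [0.3, 1.2], [2.0, 5.0]]
y = [1.0, 1.1, 0.9, 1.2, 9.0]
X_aug, y_aug = augment(X, y, bins=2)
print(len(y_aug), y_aug.count(9.0))  # → 8 4
```

In a real pipeline the synthetic rows would come from the trained generative model rather than from jitter; binning a continuous target is simply one straightforward way to decide which regions of the target range are underrepresented.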
Installation
To install striXhooT, clone the repository and install its dependencies:
```bash
# Clone the repository
git clone https://github.com/AliBavarchee/strixhoot.git
cd strixhoot
# Install dependencies
pip install -r requirements.txt
```
Usage
Below is an example script that trains the hybrid model with striXhooT:
```python
import os
import argparse

from strixhoot.trainer import main as train_model

# Define paths
input_path = "input.csv"        # Path to the dataset
output_path = "output_results"  # Directory to save all models and results

# Define training configuration
config = {
    "input_path": input_path,
    "output_path": output_path,
    "seed": 42,                # Set a random seed for reproducibility
    "gen_model": "both",       # Train both CVAE and CGAN models
    "train_generative": True,  # Enable training of generative models
    "tune_lgbm": False,        # Set to True if hyperparameter tuning is needed
    "dim_reducer": "pca",      # Either "pca" or "svd"
    "n_iter": 20,              # Number of tuning iterations (only if tune_lgbm=True)
}

# Ensure the output directory exists
os.makedirs(output_path, exist_ok=True)

# Convert the dictionary to an argparse namespace, the object train_model receives
args = argparse.Namespace(**config)

# Train the models using the defined configuration
train_model(args)
```
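The `seed` entry is documented only as a reproducibility setting. A common way to honour such a setting is to seed every random number generator a pipeline touches; the helper below sketches that pattern. It is not code taken from striXhooT, and which libraries the trainer actually seeds is an assumption (NumPy and PyTorch are seeded here only if they are installed).

```python
import os
import random


def set_seed(seed: int) -> None:
    """Seed Python's RNG and, if available, NumPy and PyTorch (assumed dependencies)."""
    # Only affects Python subprocesses started after this point, not the current one.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass


set_seed(42)
first = [random.random() for _ in range(3)]
set_seed(42)
second = [random.random() for _ in range(3)]
print(first == second)  # → True
```

LightGBM takes its own seed through its parameters (for example `random_state` in its scikit-learn interface), so a pipeline would typically pass the same value there as well.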
Configuration Options
| Parameter | Description |
|---|---|
| `input_path` | Path to the dataset (CSV format) |
| `output_path` | Directory where results and models will be saved |
| `seed` | Random seed for reproducibility |
| `gen_model` | Choose between `cvae`, `cgan`, or `both` |
| `train_generative` | Boolean flag to enable generative model training |
| `tune_lgbm` | Boolean flag to enable LightGBM hyperparameter tuning |
| `dim_reducer` | Dimensionality reduction method: `pca` or `svd` |
| `n_iter` | Number of iterations for LightGBM hyperparameter tuning (only if `tune_lgbm=True`) |
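The two `dim_reducer` options differ mainly in centering. PCA subtracts the column means before the singular value decomposition, whereas truncated SVD (as in scikit-learn's `TruncatedSVD`) factorizes the data as-is, which is why it can work directly on sparse inputs. The NumPy sketch below shows both projections; it assumes striXhooT follows these standard definitions, which the documentation does not spell out, and `k` is an illustrative number of components.

```python
import numpy as np


def pca_project(X: np.ndarray, k: int) -> np.ndarray:
    """PCA: center the columns, then project onto the top-k right singular vectors."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:k].T


def svd_project(X: np.ndarray, k: int) -> np.ndarray:
    """Truncated SVD: the same kind of projection, but without centering first."""
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    return X @ Vt[:k].T


rng = np.random.default_rng(42)
X = rng.normal(loc=5.0, scale=1.0, size=(100, 4))  # features with a large non-zero mean
Z_pca = pca_project(X, 2)
Z_svd = svd_project(X, 2)
print(Z_pca.mean(axis=0))  # close to zero: PCA scores are centered
print(Z_svd.mean(axis=0))  # first component is dominated by the mean offset
```

On data like this, the first SVD component mostly captures the mean rather than the variance, which is the practical reason to prefer PCA for dense features and SVD when centering would destroy sparsity.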
Results and Outputs
After running the script, the following artifacts are saved in the output_path directory (output_results/ in the example above):
- Trained generative models (CVAE, CGAN, or both, depending on gen_model)
- Trained LightGBM regressor
- Performance metrics and visualizations (such as true_vs_pred.png)
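Apart from true_vs_pred.png, the artifact file names are not documented, so a quick way to see what a run produced is to walk the output directory. This is a generic standard-library sketch; `list_artifacts` is a hypothetical helper, not part of striXhooT.

```python
from pathlib import Path


def list_artifacts(output_path: str) -> list[str]:
    """Return every file under output_path as a sorted list of relative POSIX paths."""
    root = Path(output_path)
    if not root.is_dir():
        raise FileNotFoundError(f"output directory not found: {root}")
    return sorted(p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file())
```

For example, after the usage script has finished, `list_artifacts("output_results")` returns the saved models, metrics and plots, and checking whether `"true_vs_pred.png"` is among the returned names confirms that evaluation ran.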
Contributor(s)
- Ali Bavarchee (ali.bavarchee@gmail.com)
Configuration
Default parameters can be modified through:
- Command-line arguments
- JSON configuration files
- Python API parameters
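The configuration channels are listed without a schema. Assuming a JSON file simply uses the same keys as the `config` dictionary in the usage script (an assumption, not documented behaviour), it can be merged over in-code values and turned into the `argparse.Namespace` that `train_model` receives. `load_config` and `DEFAULTS` are hypothetical names; `DEFAULTS` just copies the example values and is not necessarily the package's own defaults. The checks mirror the allowed values in the configuration table, and rejecting unknown keys is a design choice of this sketch.

```python
import argparse
import json

# Example values copied from the usage script; NOT necessarily striXhooT's defaults.
DEFAULTS = {
    "input_path": "input.csv",
    "output_path": "output_results",
    "seed": 42,
    "gen_model": "both",
    "train_generative": True,
    "tune_lgbm": False,
    "dim_reducer": "pca",
    "n_iter": 20,
}


def load_config(json_path=None, **overrides) -> argparse.Namespace:
    """Merge DEFAULTS < JSON file < keyword overrides, validate, return a Namespace."""
    config = dict(DEFAULTS)
    if json_path is not None:
        with open(json_path, encoding="utf-8") as fh:
            config.update(json.load(fh))
    config.update(overrides)

    unknown = set(config) - set(DEFAULTS)
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    if config["gen_model"] not in {"cvae", "cgan", "both"}:
        raise ValueError("gen_model must be 'cvae', 'cgan' or 'both'")
    if config["dim_reducer"] not in {"pca", "svd"}:
        raise ValueError("dim_reducer must be 'pca' or 'svd'")
    if config["tune_lgbm"] and int(config["n_iter"]) < 1:
        raise ValueError("n_iter must be >= 1 when tune_lgbm is True")
    return argparse.Namespace(**config)
```

For example, `train_model(load_config("config.json", seed=0))` would read a hypothetical config.json, override its seed, and pass the result on exactly like the namespace built in the usage script.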
License
This project is licensed under the MIT License - see the LICENSE file for details.
=============================================
https://github.com/AliBavarchee/
Download files
- Source distribution: strixhoot-0.0.9.tar.gz
- Built distribution: strixhoot-0.0.9-py3-none-any.whl
File details
Details for the file strixhoot-0.0.9.tar.gz.
File metadata
- Download URL: strixhoot-0.0.9.tar.gz
- Upload date:
- Size: 17.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 22b4a23231b229fb42b2cfd2afb8a4a0d9c2f529e94b80f8a74e4eccc76da1e9 |
| MD5 | 958cf04e1562139c1c6ae6b2a370f02f |
| BLAKE2b-256 | afe980f00254cd2dae69b6ff21e35bd29176a2227365474bd66b03fd8d9f486a |
File details
Details for the file strixhoot-0.0.9-py3-none-any.whl.
File metadata
- Download URL: strixhoot-0.0.9-py3-none-any.whl
- Upload date:
- Size: 21.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 856ec1fbdea38f8e7860e41f37f0b70394dc8791a3e90bf77b94cf3cd01cb766 |
| MD5 | 45f2e46eb4166aec9a57a0fddd844bca |
| BLAKE2b-256 | 0efa2754517252cdf6b1ecb4f1bbc5eb1e62415363c202c421e6cf73a2e82f4f |
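The published digests can be used to check a downloaded file before installing it. A minimal sketch with Python's `hashlib`, assuming the archive sits in the current directory; `sha256_of` and `verify_sha256` are hypothetical helper names.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Stream the file in chunks so large archives are never fully loaded into memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected: str) -> bool:
    """Compare against a published hex digest, ignoring case and surrounding whitespace."""
    return sha256_of(path) == expected.strip().lower()


# Published SHA256 digest of the source distribution, as listed in its file hashes.
SDIST_SHA256 = "22b4a23231b229fb42b2cfd2afb8a4a0d9c2f529e94b80f8a74e4eccc76da1e9"

if Path("strixhoot-0.0.9.tar.gz").exists():
    print(verify_sha256("strixhoot-0.0.9.tar.gz", SDIST_SHA256))
```

pip can enforce the same check itself through hash-checking mode (`--require-hashes`, with `--hash=sha256:<digest>` entries in a requirements file).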