
Hybrid VAE-GAN with LightGBM for class-imbalanced regression tasks

Project description

striXhooT: Hybrid Generative Model for Regression Tasks

=============================================

StrixHoot

=============================================

License: MIT

striXhooT

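striXhooT is a hybrid machine learning package that combines conditional generative models (a Conditional Variational Autoencoder, CVAE, and a Conditional Generative Adversarial Network, CGAN) with a LightGBM regressor. It is designed for structured (tabular) datasets where synthetic training samples can improve regression performance, for example when some ranges of the target are underrepresented.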
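To make "synthetic data for imbalanced regression" concrete, here is a minimal sketch of target-conditioned augmentation. It assumes the target is split into equal-width bins and that sparse bins are topped up with synthetic rows from a conditional generator. The `generate` function is a hypothetical Gaussian-jitter stand-in, not the package's CVAE or CGAN; in striXhooT that role would be played by a trained conditional model. All function names here are illustrative.

```python
import random
from collections import defaultdict


def bin_index(y, lo, hi, n_bins):
    """Map a target value to one of n_bins equal-width bins over [lo, hi]."""
    if hi == lo:
        return 0
    idx = int((y - lo) / (hi - lo) * n_bins)
    return min(idx, n_bins - 1)  # y == hi falls into the last bin


def generate(x, y, rng, scale=0.05):
    """Hypothetical stand-in for a conditional generator: jitters a real row.

    A trained CVAE/CGAN would instead sample features conditioned on the target.
    """
    x_new = [v + rng.gauss(0.0, scale) for v in x]
    y_new = y + rng.gauss(0.0, scale)
    return x_new, y_new


def augment_to_balance(X, Y, n_bins=4, seed=42):
    """Add synthetic rows so each non-empty target bin has as many rows as the largest bin."""
    if not Y:
        return [], []
    rng = random.Random(seed)
    lo, hi = min(Y), max(Y)
    bins = defaultdict(list)
    for x, y in zip(X, Y):
        bins[bin_index(y, lo, hi, n_bins)].append((x, y))
    target = max(len(rows) for rows in bins.values())
    X_aug, Y_aug = list(X), list(Y)  # real rows first, synthetic rows appended
    for rows in bins.values():
        for _ in range(target - len(rows)):
            x, y = rng.choice(rows)  # condition on a real example from the sparse bin
            x_new, y_new = generate(x, y, rng)
            X_aug.append(x_new)
            Y_aug.append(y_new)
    return X_aug, Y_aug
```

With four targets near 1.0 and a single target at 9.0, the sketch appends three synthetic rows near 9.0, so the rare high-target region carries as much weight as the common one when the regressor is trained.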

Features

  • Train a Conditional Variational Autoencoder (CVAE) and a Conditional Generative Adversarial Network (CGAN).
  • Augment the training data with samples from the generative models to improve regression performance.
  • Train a LightGBM regressor, with optional hyperparameter tuning.
  • Reduce dimensionality with PCA or SVD.
  • Orchestrate model training and evaluation end to end.
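The feature list offers PCA or SVD without saying how the package implements either. As background only, the standard relationship (using the thin SVD) is that PCA is a truncated SVD of the column-centred feature matrix, while plain truncated SVD factorizes the matrix without centring. The symbols below are introduced here for illustration and are not taken from the package.

```latex
% X \in \mathbb{R}^{n \times d}: feature matrix; k: number of retained components
X = U \Sigma V^{\top}, \qquad Z_{\mathrm{SVD}} = X V_k = U_k \Sigma_k

% PCA: centre the columns first (\mu = vector of column means, \mathbf{1} = vector of ones)
\tilde{X} = X - \mathbf{1}\mu^{\top} = \tilde{U} \tilde{\Sigma} \tilde{V}^{\top}, \qquad
Z_{\mathrm{PCA}} = \tilde{X} \tilde{V}_k = \tilde{U}_k \tilde{\Sigma}_k

% The right singular vectors of \tilde{X} are the eigenvectors of the sample covariance
\frac{1}{n-1} \tilde{X}^{\top} \tilde{X} = \tilde{V} \, \frac{\tilde{\Sigma}^{2}}{n-1} \, \tilde{V}^{\top}
```

Skipping the centring step is the main practical difference: truncated SVD can be applied to sparse matrices directly, whereas subtracting the column means would make them dense.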

Installation

To install striXhooT from source, clone the repository and install its dependencies:

```bash
# Clone the repository
git clone https://github.com/AliBavarchee/strixhoot.git
cd strixhoot

# Install dependencies
pip install -r requirements.txt
```

Usage

Below is an example script that trains the hybrid model with striXhooT:

```python
import os
import argparse
from strixhoot.trainer import main as train_model

# Define paths
input_path = "input.csv"  # Path to the dataset
output_path = "output_results"  # Directory to save all models and results

# Define training configuration
config = {
    "input_path": input_path,
    "output_path": output_path,
    "seed": 42,  # Set a random seed for reproducibility
    "gen_model": "both",  # Train both CVAE and CGAN models
    "train_generative": True,  # Enable training of generative models
    "tune_lgbm": False,  # Set to True if hyperparameter tuning is needed
    "dim_reducer": "pca",  # Either "pca" or "svd"
    "n_iter": 20,  # Number of iterations for hyperparameter tuning (only if tune_lgbm=True)
}

# Ensure the output directory exists
os.makedirs(output_path, exist_ok=True)

# Convert dictionary to argparse namespace
args = argparse.Namespace(**config)

# Train the models using the defined configuration
train_model(args)
```
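Because the script passes a flat `argparse.Namespace` to `train_model`, one way to try other settings without copying the whole dictionary is to merge overrides onto a base configuration. This is a minimal sketch; `make_args` and `BASE_CONFIG` are hypothetical helpers, not part of strixhoot, and the base values simply repeat the example script.

```python
import argparse

# Base values copied from the example script above.
BASE_CONFIG = {
    "input_path": "input.csv",
    "output_path": "output_results",
    "seed": 42,
    "gen_model": "both",
    "train_generative": True,
    "tune_lgbm": False,
    "dim_reducer": "pca",
    "n_iter": 20,
}


def make_args(**overrides):
    """Hypothetical helper: build a Namespace from BASE_CONFIG with keyword overrides.

    Unknown keys raise KeyError so that typos are not passed silently to the trainer.
    """
    unknown = set(overrides) - set(BASE_CONFIG)
    if unknown:
        raise KeyError(f"unknown option(s): {sorted(unknown)}")
    return argparse.Namespace(**{**BASE_CONFIG, **overrides})
```

For example, `train_model(make_args(gen_model="cvae", tune_lgbm=True, n_iter=50))` would train only the CVAE and enable tuning, leaving the other values from the example unchanged.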

Configuration Options

| Parameter | Description |
|---|---|
| input_path | Path to the input dataset (CSV format) |
| output_path | Directory where trained models and results are saved |
| seed | Random seed for reproducibility |
| gen_model | Generative model(s) to train: cvae, cgan, or both |
| train_generative | Boolean; if True, train the generative model(s) |
| tune_lgbm | Boolean; if True, run LightGBM hyperparameter tuning |
| dim_reducer | Dimensionality reduction method: pca or svd |
| n_iter | Number of LightGBM hyperparameter-tuning iterations (used only when tune_lgbm=True) |
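The table above implies a few simple constraints on the values. A minimal pre-flight check, written as a hypothetical helper (the package may validate differently or not at all), could catch mistakes before a long training run:

```python
import argparse

VALID_GEN_MODELS = {"cvae", "cgan", "both"}
VALID_REDUCERS = {"pca", "svd"}


def validate_args(args: argparse.Namespace) -> list:
    """Check a configuration Namespace against the documented options.

    Returns a list of problem descriptions; an empty list means no problems were found.
    """
    problems = []
    if not str(args.input_path).lower().endswith(".csv"):
        problems.append("input_path should point to a CSV file")
    if args.gen_model not in VALID_GEN_MODELS:
        problems.append(f"gen_model must be one of {sorted(VALID_GEN_MODELS)}")
    if args.dim_reducer not in VALID_REDUCERS:
        problems.append(f"dim_reducer must be one of {sorted(VALID_REDUCERS)}")
    for flag in ("train_generative", "tune_lgbm"):
        if not isinstance(getattr(args, flag), bool):
            problems.append(f"{flag} must be True or False")
    if not isinstance(args.seed, int):
        problems.append("seed must be an integer")
    # n_iter only matters when tuning is enabled
    if args.tune_lgbm and (not isinstance(args.n_iter, int) or args.n_iter < 1):
        problems.append("n_iter must be a positive integer when tune_lgbm=True")
    return problems
```

Calling `validate_args(args)` just before `train_model(args)` and stopping on a non-empty result keeps configuration errors from surfacing halfway through training.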

Results and Outputs

After the script runs, the following artifacts are saved in the output directory (output_path, here output_results/):

  • Trained generative model(s): CVAE, CGAN, or both, depending on gen_model
  • Trained LightGBM regressor
  • Performance metrics and visualizations (for example, true_vs_pred.png)
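The README does not list which metrics are reported. To recompute common regression metrics from true and predicted values (for example, the data behind a true-vs-predicted plot), here is a dependency-free sketch of RMSE, MAE and R²; `regression_metrics` is an illustrative name, not a strixhoot function.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return RMSE, MAE and R^2 for paired sequences of true and predicted targets."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    n = len(y_true)
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    sse = sum(r * r for r in residuals)  # sum of squared errors
    mean_t = sum(y_true) / n
    sst = sum((t - mean_t) ** 2 for t in y_true)  # total sum of squares
    return {
        "rmse": math.sqrt(sse / n),
        "mae": sum(abs(r) for r in residuals) / n,
        # R^2 is undefined when all true values are identical
        "r2": 1.0 - sse / sst if sst > 0 else float("nan"),
    }
```

For `y_true = [1, 2, 3, 4]` and `y_pred = [1, 2, 3, 5]`, this gives RMSE 0.5, MAE 0.25 and R² 0.8.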


Configuration

Default parameters can be modified through:

  • Command-line arguments
  • JSON configuration files
  • Python API parameters
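The README mentions JSON configuration files but does not document their schema. Assuming a flat JSON object whose keys match the Configuration Options table, loading one into the same kind of Namespace the usage script builds could look like this. Both functions are hypothetical, and keyword overrides (for example, values parsed from the command line) take precedence over the file.

```python
import argparse
import json


def parse_json_config(text, **overrides):
    """Hypothetical: turn a flat JSON object into an argparse.Namespace.

    Keyword overrides replace values read from the JSON text.
    """
    config = json.loads(text)
    if not isinstance(config, dict):
        raise ValueError("expected a JSON object at the top level")
    config.update(overrides)
    return argparse.Namespace(**config)


def load_json_config(path, **overrides):
    """Hypothetical: read a JSON config file and apply keyword overrides."""
    with open(path, encoding="utf-8") as fh:
        return parse_json_config(fh.read(), **overrides)
```

For example, `train_model(load_json_config("config.json", seed=7))` would use the file's settings with a different seed; the file name here is only an example.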

License

This project is licensed under the MIT License - see the LICENSE file for details.

=============================================

ALI BAVARCHIEE

=============================================

GitHub: https://github.com/AliBavarchee/

LinkedIn: https://www.linkedin.com/in/ali-bavarchee-qip/


Download files


Source Distribution

strixhoot-0.0.8.tar.gz (17.6 kB)

Built Distribution


strixhoot-0.0.8-py3-none-any.whl (21.2 kB, Python 3)

File details

Details for the file strixhoot-0.0.8.tar.gz.

File metadata

  • Download URL: strixhoot-0.0.8.tar.gz
  • Size: 17.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.7

File hashes

Hashes for strixhoot-0.0.8.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | dc64a255108d0ce0c22f69706353a30426dcfc0d41b92630beab03514316db7a |
| MD5 | ff9d19eb4502729d7dfb25c39745953b |
| BLAKE2b-256 | 8bc9ff40c50aac124f0574e1de1ba922ddf91da15c047b40441b9ec4bd2a95f9 |
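To check a downloaded file against the published digests, a small standard-library sketch is enough. The helper names are illustrative, and the expected value is the SHA256 listed above for strixhoot-0.0.8.tar.gz.

```python
import hashlib


def sha256_hex(chunks):
    """Return the hex SHA256 digest of an iterable of byte chunks."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_sha256, chunk_size=1 << 16):
    """Return True if the file's SHA256 digest matches the published one."""
    with open(path, "rb") as fh:
        # Read in fixed-size chunks so large files need not fit in memory
        actual = sha256_hex(iter(lambda: fh.read(chunk_size), b""))
    return actual == expected_sha256.strip().lower()


# Published SHA256 of strixhoot-0.0.8.tar.gz
SDIST_SHA256 = "dc64a255108d0ce0c22f69706353a30426dcfc0d41b92630beab03514316db7a"
```

`verify_download("strixhoot-0.0.8.tar.gz", SDIST_SHA256)` returns True only for an unmodified download. pip can enforce the same check at install time through its hash-checking mode (`--require-hashes`).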


File details

Details for the file strixhoot-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: strixhoot-0.0.8-py3-none-any.whl
  • Size: 21.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.7

File hashes

Hashes for strixhoot-0.0.8-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 545bb4b554b8dc37906efd25bcbf8109026e0753843d3148bde8a83bad0e3c45 |
| MD5 | a79c308723fb4714ee1b08e93200cc1c |
| BLAKE2b-256 | ca88bee0147407f882d7d12af874a9398f9ecd962b05b2638de0befa3aaf7974 |

