
A hybrid generative model combining CVAE, CGAN, and Random Forest.




LatentBoostClassifier: Hybrid Generative Model: CVAE + CGAN + RF



DOI: http://dx.doi.org/10.13140/RG.2.2.11351.79522

This repository implements a hybrid generative model combining a Conditional Variational Autoencoder (CVAE), a Conditional Generative Adversarial Network (CGAN), and a Random Forest Classifier (RFC). The hybrid model is designed for complex classification tasks, leveraging latent feature extraction (CVAE), synthetic data generation (CGAN), and robust classification (RFC).


Features

  1. Conditional Variational Autoencoder (CVAE):

    • Learns a latent representation of input data.
    • Uses a custom loss function combining reconstruction loss and KL divergence.
  2. Conditional Generative Adversarial Network (CGAN):

    • Generates high-quality synthetic data conditioned on class labels.
  3. Random Forest Classifier:

    • Trained on a combination of real and synthetic features to enhance classification performance.
  4. Parallel Model Training:

    • The CVAE and CGAN models are trained concurrently using Python's multiprocessing module.
  5. Hyperparameter Tuning:

    • Utilizes Keras Tuner for optimizing the CVAE and CGAN hyperparameters.
    • Random Forest is tuned using GridSearchCV.

Installation

Prerequisites

  • Python 3.7+
  • TensorFlow 2.8+
  • Scikit-learn
  • Keras Tuner
  • Matplotlib
  • Seaborn
  • Multiprocessing (part of the Python standard library; no installation needed)

Install Required Packages

pip install tensorflow keras-tuner scikit-learn matplotlib seaborn


Install the Package

The package can be installed from PyPI:

pip install LatentBoostClassifier

Alternatively, install directly from GitHub:

pip install git+https://github.com/AliBavarchee/LatentBoostClassifier.git

Usage


Using the Package

After installation, the package can be imported and used like this:

from LatentBoostClassifier import parallel_train, visualize_hybrid_model

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Training the Hybrid Model

Train the CVAE, CGAN, and Random Forest models using your dataset:

# Import the training function
from LatentBoostClassifier import parallel_train

# Define training and testing datasets
X_train, Y_train = ...  # Load or preprocess your training data
X_test, Y_test = ...    # Load or preprocess your test data

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

Visualizing Results

Evaluate and visualize the performance of the hybrid model:

from LatentBoostClassifier import visualize_hybrid_model

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Code Overview

1. CVAE

  • Purpose: Extract latent features from input data.
  • Key Components:
    • Custom loss function: Combines reconstruction loss and KL divergence.
    • Encoder-Decoder architecture.
  • Hyperparameter Tuning:
    • Latent dimensions.
    • Dense layer units.
    • Learning rate.
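As a rough sketch of this kind of objective (not the package's exact loss function), the CVAE loss adds a reconstruction term to the KL divergence between a diagonal-Gaussian posterior and a standard-normal prior:

```python
import numpy as np

def cvae_loss(x, x_recon, z_mean, z_log_var):
    """Toy CVAE objective: reconstruction error plus KL divergence.

    Assumes a diagonal-Gaussian posterior q(z|x) parameterized by
    (z_mean, z_log_var) and a standard-normal prior p(z).
    """
    recon = np.sum((x - x_recon) ** 2, axis=-1)  # squared-error reconstruction
    kl = -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)
    return np.mean(recon + kl)

# Perfect reconstruction and a posterior equal to the prior give zero loss.
x = np.array([[0.0, 1.0]])
loss = cvae_loss(x, x, np.zeros((1, 2)), np.zeros((1, 2)))
print(loss)  # 0.0
```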

2. CGAN

  • Purpose: Generate synthetic data conditioned on class labels.
  • Key Components:
    • Generator: Produces synthetic samples.
    • Discriminator: Distinguishes real and synthetic samples.
  • Hyperparameter Tuning:
    • Latent dimensions.
    • Dense layer units.
    • Learning rate.
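The conditioning mechanism itself is simple: the generator receives a noise vector concatenated with a one-hot encoding of the target class label. A minimal sketch (shapes and helper name are illustrative, not the package's API):

```python
import numpy as np

def generator_input(noise, labels, n_classes):
    """Build a CGAN generator input: noise concatenated with a
    one-hot encoding of each sample's target class label."""
    one_hot = np.eye(n_classes)[labels]
    return np.concatenate([noise, one_hot], axis=1)

noise = np.random.rand(4, 8)      # batch of 4 latent vectors
labels = np.array([0, 1, 1, 0])   # target class for each sample
g_in = generator_input(noise, labels, n_classes=2)
print(g_in.shape)  # (4, 10)
```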

3. Random Forest

  • Purpose: Classify combined real and synthetic data.
  • Key Components:
    • Trained on latent features (CVAE) and synthetic data (CGAN).
    • Grid search for hyperparameter optimization.
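The grid-search step can be sketched with scikit-learn's GridSearchCV. The arrays below are random stand-ins for the real latent features and CGAN-generated samples, and the parameter grid is illustrative:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Toy stand-ins for CVAE latent features and CGAN synthetic samples.
rng = np.random.default_rng(0)
X_real, y_real = rng.random((40, 5)), rng.integers(0, 2, 40)
X_synth, y_synth = rng.random((40, 5)), rng.integers(0, 2, 40)

# Combine real and synthetic data, then grid-search the forest.
X = np.vstack([X_real, X_synth])
y = np.concatenate([y_real, y_synth])
param_grid = {"n_estimators": [50, 100], "max_depth": [3, None]}
search = GridSearchCV(RandomForestClassifier(random_state=0), param_grid, cv=3)
search.fit(X, y)
print(search.best_params_)
```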

Visualization

Classification Performance

  • Confusion Matrix: Visualize actual vs. predicted class distributions.
  • Classification Report: Display precision, recall, F1-score, and accuracy metrics.
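Both metrics come straight from scikit-learn; a minimal sketch with toy labels:

```python
from sklearn.metrics import classification_report, confusion_matrix

y_true = [0, 0, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]

# Rows are actual classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred)
print(cm)  # [[2 1]
           #  [1 2]]

# Per-class precision, recall, F1-score, and overall accuracy.
print(classification_report(y_true, y_pred))
```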

Latent Space

  • Use t-SNE to visualize the CVAE's latent space.
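A minimal t-SNE sketch, using random vectors in place of actual CVAE latent codes:

```python
import numpy as np
from sklearn.manifold import TSNE

latent = np.random.rand(50, 8)  # stand-in for CVAE latent codes
# Project the latent codes to 2-D for plotting.
emb = TSNE(n_components=2, perplexity=10, random_state=0).fit_transform(latent)
print(emb.shape)  # (50, 2)
```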

Synthetic Data

  • Compare real and synthetic data distributions using KDE plots.
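Seaborn's kdeplot draws these comparisons; the underlying idea can be sketched headlessly with scipy's Gaussian KDE. The distributions below are synthetic stand-ins, and the gap measure is one illustrative way to quantify the mismatch:

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, 500)   # stand-in for a real feature
synth = rng.normal(0.1, 1.1, 500)  # stand-in for a CGAN-generated feature

# Estimate both densities on a common grid.
grid = np.linspace(-4, 4, 200)
real_density = gaussian_kde(real)(grid)
synth_density = gaussian_kde(synth)(grid)

# Total-variation-style gap between the two estimated densities (0 = identical).
gap = 0.5 * np.sum(np.abs(real_density - synth_density)) * (grid[1] - grid[0])
print(round(gap, 3))
```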

ROC Curve

  • Evaluate the Random Forest classifier with an ROC curve and compute AUC.
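The ROC/AUC step is standard scikit-learn; a minimal sketch with toy scores (in practice the scores would come from something like `rf.predict_proba(X_test)[:, 1]`):

```python
from sklearn.metrics import roc_auc_score, roc_curve

y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]  # predicted P(class = 1) for each sample

# False-positive rate and true-positive rate at each score threshold.
fpr, tpr, thresholds = roc_curve(y_true, scores)
auc = roc_auc_score(y_true, scores)
print(auc)  # 0.75
```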

Key Functions

Model Training

parallel_train(X_train, Y_train, X_test, Y_test)
  • Trains the CVAE, CGAN, and Random Forest models in parallel.

Visualization

visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
  • Visualizes the hybrid model's results, including latent space and classification performance.

Examples

Training Example

# Example dataset
import numpy as np

X_train = np.random.rand(80, 10)
Y_train = np.random.randint(0, 2, size=(80, 1))
X_test = np.random.rand(20, 10)
Y_test = np.random.randint(0, 2, size=(20, 1))

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

Visualization Example

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Contributions

Feel free to fork, improve, and submit a pull request! For major changes, please open an issue first to discuss what you'd like to change.


License

This project is licensed under the MIT License. See the LICENSE file for details.

=============================================

ALI BAVARCHIEE

=============================================

| https://github.com/AliBavarchee/ |

| https://www.linkedin.com/in/ali-bavarchee-qip/ |
