
A hybrid generative model combining CVAE, CGAN, and Random Forest.


LatentBoostClassifier: Hybrid Generative Model (CVAE + CGAN + RF)

This repository implements a hybrid generative model that combines a Conditional Variational Autoencoder (CVAE), a Conditional Generative Adversarial Network (CGAN), and a Random Forest Classifier (RFC). The hybrid model targets complex classification tasks by pairing latent feature extraction (CVAE) and synthetic data generation (CGAN) with robust classification (RFC).


Features

  1. Conditional Variational Autoencoder (CVAE):

    • Learns a latent representation of input data.
    • Uses a custom loss function combining reconstruction loss and KL divergence.
  2. Conditional Generative Adversarial Network (CGAN):

    • Generates synthetic samples conditioned on class labels.
  3. Random Forest Classifier:

    • Trained on a combination of real and synthetic features to enhance classification performance.
  4. Parallel Model Training:

    • The CVAE and CGAN models are trained concurrently using Python's multiprocessing module.
  5. Hyperparameter Tuning:

    • Uses Keras Tuner to optimize the CVAE and CGAN hyperparameters.
    • Random Forest is tuned using GridSearchCV.

Installation

Prerequisites

  • Python 3.7+
  • TensorFlow 2.8+
  • Scikit-learn
  • Keras Tuner
  • Matplotlib
  • Seaborn
  • multiprocessing (Python standard library; no separate install needed)

Install Required Packages

pip install tensorflow keras-tuner scikit-learn matplotlib seaborn


Install the Package

The package can be installed from PyPI:

pip install LatentBoostClassifier

Alternatively, install directly from GitHub:

pip install git+https://github.com/AliBavarchee/LatentBoostClassifier.git

Usage


Using the Package

After installation, the package can be imported and used like this:

from LatentBoostClassifier import parallel_train, visualize_hybrid_model

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Training the Hybrid Model

Train the CVAE, CGAN, and Random Forest models using your dataset:

# Import the training function
from LatentBoostClassifier import parallel_train

# Define training and testing datasets
X_train, Y_train = ...  # Load or preprocess your training data
X_test, Y_test = ...    # Load or preprocess your test data

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

Visualizing Results

Evaluate and visualize the performance of the hybrid model:

from LatentBoostClassifier import visualize_hybrid_model

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Code Overview

1. CVAE

  • Purpose: Extract latent features from input data.
  • Key Components:
    • Custom loss function: Combines reconstruction loss and KL divergence.
    • Encoder-Decoder architecture.
  • Hyperparameter Tuning:
    • Latent dimensions.
    • Dense layer units.
    • Learning rate.
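The CVAE loss described above (reconstruction loss plus KL divergence) can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the package's implementation: it assumes mean squared error as the reconstruction term and a diagonal Gaussian encoder output (`z_mean`, `z_log_var`) with a standard normal prior, for which the KL term has a closed form. The encoder and decoder networks (which in a CVAE also receive the class label as input) are omitted, and `cvae_loss` and `reparameterize` are hypothetical helper names.

```python
import numpy as np


def cvae_loss(x, x_recon, z_mean, z_log_var):
    """Batch-averaged CVAE loss: reconstruction error plus KL divergence.

    Assumptions (not taken from LatentBoostClassifier): squared-error
    reconstruction and a diagonal Gaussian posterior q(z | x, y) compared
    against a standard normal prior p(z).
    """
    # Reconstruction term: squared error summed over features, one value per sample.
    recon = np.sum((x - x_recon) ** 2, axis=1)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.
    kl = -0.5 * np.sum(1.0 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=1)
    # Average the combined per-sample loss over the batch.
    return float(np.mean(recon + kl))


def reparameterize(z_mean, z_log_var, rng):
    """Sample z = mu + sigma * eps (reparameterization trick) with eps ~ N(0, I)."""
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps
```

With a perfect reconstruction and `z_mean = 0`, `z_log_var = 0`, both terms vanish and the loss is 0; shifting every latent mean to 1 adds 0.5 per latent dimension.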

2. CGAN

  • Purpose: Generate synthetic data conditioned on class labels.
  • Key Components:
    • Generator: Produces synthetic samples.
    • Discriminator: Distinguishes real and synthetic samples.
  • Hyperparameter Tuning:
    • Latent dimensions.
    • Dense layer units.
    • Learning rate.
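A common way to condition a GAN on class labels is to concatenate a one-hot label to the generator's noise vector and to the discriminator's input. The NumPy sketch below assumes that scheme (the README does not say which conditioning LatentBoostClassifier uses) and shows the input construction together with the standard binary cross-entropy losses, using the non-saturating generator loss. The helper names are hypothetical and the networks themselves are left out.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Convert integer labels of shape (n,) or (n, 1) into one-hot rows of shape (n, num_classes)."""
    labels = np.asarray(labels).reshape(-1).astype(int)
    return np.eye(num_classes)[labels]


def condition(inputs, labels, num_classes):
    """Append the one-hot label to each row.

    Used both for the generator input (noise + label) and for the
    discriminator input (real or synthetic sample + label).
    """
    return np.concatenate([inputs, one_hot(labels, num_classes)], axis=1)


def gan_losses(d_real, d_fake, eps=1e-7):
    """Binary cross-entropy GAN losses from discriminator probabilities.

    d_real: D(x | y) on real samples; d_fake: D(G(z | y) | y) on synthetic ones.
    Returns (discriminator_loss, non-saturating generator_loss).
    """
    # Clip to avoid log(0).
    d_real = np.clip(np.asarray(d_real, dtype=float), eps, 1 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1 - eps)
    d_loss = -np.mean(np.log(d_real) + np.log(1 - d_fake))
    g_loss = -np.mean(np.log(d_fake))
    return float(d_loss), float(g_loss)


# Batch of 5 samples with an 8-dimensional noise vector and 2 classes.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=(5, 1))
noise = rng.standard_normal((5, 8))
gen_in = condition(noise, labels, num_classes=2)  # shape (5, 10)
```

When the discriminator outputs 0.5 everywhere, the discriminator loss is 2 ln 2 and the generator loss is ln 2, the usual reference point for a balanced GAN.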

3. Random Forest

  • Purpose: Classify combined real and synthetic data.
  • Key Components:
    • Trained on latent features (CVAE) and synthetic data (CGAN).
    • Grid search for hyperparameter optimization.
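The README does not spell out how real and synthetic data are combined before the Random Forest is fitted. The sketch below assumes both are already expressed as feature rows with the same columns (for example, CVAE latent codes) and simply stacks them, then tunes a `RandomForestClassifier` with scikit-learn's `GridSearchCV`. The function name and the grid values are hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV


def fit_rf_on_augmented(real_features, real_labels, synth_features, synth_labels):
    """Fit a grid-searched Random Forest on real plus synthetic feature rows.

    Both feature arrays must have the same number of columns. The parameter
    grid is illustrative, not the package's actual search space.
    """
    # Stack real and synthetic rows into one training set.
    X = np.vstack([real_features, synth_features])
    y = np.concatenate([np.ravel(real_labels), np.ravel(synth_labels)])
    param_grid = {"n_estimators": [50, 100], "max_depth": [None, 5]}
    search = GridSearchCV(RandomForestClassifier(random_state=0), param_grid, cv=3)
    search.fit(X, y)
    return search.best_estimator_, search.best_params_
```

Whatever the combination strategy, the final evaluation should use held-out real data only, so that synthetic rows cannot inflate the reported score.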

Visualization

Classification Performance

  • Confusion Matrix: Visualize actual vs. predicted class distributions.
  • Classification Report: Display precision, recall, F1-score, and accuracy metrics.
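scikit-learn's `confusion_matrix` and `classification_report` produce these outputs directly. The NumPy sketch below only illustrates what they compute. It assumes integer class labels `0..num_classes-1`, and the helper names are hypothetical.

```python
import numpy as np


def confusion_matrix_counts(y_true, y_pred, num_classes):
    """Return a (num_classes, num_classes) matrix: rows = actual, columns = predicted."""
    y_true = np.ravel(y_true).astype(int)
    y_pred = np.ravel(y_pred).astype(int)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when (actual, predicted) pairs repeat.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def per_class_report(cm):
    """Per-class precision, recall and F1, plus overall accuracy, from a confusion matrix."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: how often each class actually occurs
    # Guard against division by zero for classes never predicted or never present.
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = tp.sum() / cm.sum()
    return precision, recall, f1, accuracy
```

For `y_true = [0, 0, 1, 1]` and `y_pred = [0, 1, 1, 1]`, the matrix is `[[1, 1], [0, 2]]`, giving precision `[1.0, 0.667]`, recall `[0.5, 1.0]` and accuracy 0.75.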

Latent Space

  • Use t-SNE to visualize the CVAE's latent space.
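A latent-space projection of this kind can be sketched with scikit-learn's `TSNE`. The two random clusters below stand in for CVAE encoder outputs. How LatentBoostClassifier extracts its latent codes is not shown in this README, so treat `embed_latent_space` as a hypothetical helper. scikit-learn requires `perplexity` to be smaller than the number of samples.

```python
import numpy as np
from sklearn.manifold import TSNE


def embed_latent_space(latent_codes, perplexity=10.0, seed=0):
    """Project latent vectors of shape (n_samples, latent_dim) to 2-D with t-SNE."""
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(np.asarray(latent_codes, dtype=float))


# Stand-in latent codes: two well-separated clusters in an 8-dimensional space.
rng = np.random.default_rng(0)
codes = np.vstack([rng.normal(0.0, 1.0, (20, 8)), rng.normal(6.0, 1.0, (20, 8))])
labels = np.array([0] * 20 + [1] * 20)
embedding = embed_latent_space(codes)  # shape (40, 2)
```

The embedding can then be drawn with `matplotlib.pyplot.scatter(embedding[:, 0], embedding[:, 1], c=labels)` to color points by class.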

Synthetic Data

  • Compare real and synthetic data distributions using KDE plots.
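Comparing real and synthetic distributions with KDE plots amounts to evaluating a kernel density estimate for each feature on a common grid. seaborn's `kdeplot` does this for you. The NumPy sketch below uses a hypothetical `gaussian_kde_1d` helper to show the computation with a Gaussian kernel and the normal-reference bandwidth rule `h = 1.06 * std * n**(-1/5)`. The "synthetic" sample is just a shifted, wider normal standing in for CGAN output.

```python
import numpy as np


def gaussian_kde_1d(samples, grid, bandwidth=None):
    """Evaluate a 1-D Gaussian kernel density estimate of `samples` at each point of `grid`."""
    samples = np.ravel(samples).astype(float)
    n = samples.size
    if bandwidth is None:
        # Normal-reference rule of thumb: h = 1.06 * std * n^(-1/5).
        bandwidth = 1.06 * samples.std(ddof=1) * n ** (-1 / 5)
    # Standardized distances between every grid point and every sample: shape (len(grid), n).
    u = (np.asarray(grid, dtype=float)[:, None] - samples[None, :]) / bandwidth
    kernel = np.exp(-0.5 * u ** 2) / np.sqrt(2 * np.pi)
    return kernel.sum(axis=1) / (n * bandwidth)


rng = np.random.default_rng(0)
real_feature = rng.normal(0.0, 1.0, 500)       # one column of real data
synthetic_feature = rng.normal(0.3, 1.2, 500)  # stand-in for the same column from the CGAN
grid = np.linspace(-10, 10, 2001)
real_density = gaussian_kde_1d(real_feature, grid)
synthetic_density = gaussian_kde_1d(synthetic_feature, grid)
```

Plotting `real_density` and `synthetic_density` against `grid` on the same axes gives the per-feature comparison. Large gaps between the curves suggest the generator has not captured that feature's distribution.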

ROC Curve

  • Evaluate the Random Forest classifier with an ROC curve and compute AUC.
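For a binary problem, the ROC curve and AUC can be computed from the Random Forest's positive-class probabilities with scikit-learn's `roc_curve` and `auc`. The sketch assumes binary labels and a fitted classifier exposing `predict_proba` (as `RandomForestClassifier` does). The wrapper name is hypothetical, and with more than two classes a one-vs-rest curve per class would be needed instead.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve


def roc_with_auc(y_true, positive_scores):
    """Return (fpr, tpr, area) for binary labels and positive-class scores.

    With a fitted classifier, positive_scores would typically be
    model.predict_proba(X_test)[:, 1].
    """
    fpr, tpr, _ = roc_curve(np.ravel(y_true), np.ravel(positive_scores))
    # Area under the ROC curve via the trapezoidal rule.
    return fpr, tpr, auc(fpr, tpr)
```

For labels `[0, 0, 1, 1]` and scores `[0.1, 0.4, 0.35, 0.8]`, three of the four positive/negative pairs are ranked correctly, so the AUC is 0.75.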

Key Functions

Model Training

parallel_train(X_train, Y_train, X_test, Y_test)
  • Trains the CVAE and CGAN models in parallel, then the Random Forest on their outputs, and returns the best CVAE, CGAN generator, and Random Forest model.
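Concurrent training of the CVAE and CGAN with Python's multiprocessing module can be sketched with `multiprocessing.Pool`. The two trainer functions below are hypothetical stubs standing in for real Keras training runs, and this is not the package's `parallel_train`. Worker functions must be defined at module level so they can be pickled. The `if __name__ == "__main__":` guard is required on platforms that start workers with the spawn method (the default on Windows and macOS).

```python
import multiprocessing as mp


def train_cvae_stub(seed):
    """Hypothetical stand-in for CVAE training; a real worker would build and fit a Keras model."""
    return {"model": "cvae", "seed": seed}


def train_cgan_stub(seed):
    """Hypothetical stand-in for CGAN training."""
    return {"model": "cgan", "seed": seed}


def train_generative_models_in_parallel(seed=0):
    """Run both trainers in separate worker processes and wait for both results."""
    with mp.Pool(processes=2) as pool:
        cvae_job = pool.apply_async(train_cvae_stub, (seed,))
        cgan_job = pool.apply_async(train_cgan_stub, (seed,))
        # .get() blocks until the corresponding worker finishes and re-raises any worker error.
        return cvae_job.get(), cgan_job.get()


if __name__ == "__main__":
    cvae_result, cgan_result = train_generative_models_in_parallel(seed=42)
    print(cvae_result, cgan_result)
```

The Random Forest step would run only after both results are available, since it consumes the CVAE's latent features and the CGAN's synthetic samples. Depending on the Keras version, trained models may not be picklable across process boundaries. Saving each model inside its worker and reloading it in the parent process avoids that issue.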

Visualization

visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
  • Visualizes the hybrid model's results, including latent space and classification performance.

Examples

Training Example

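import numpy as np
from LatentBoostClassifier import parallel_train

# Example dataset: 10 random features per sample, binary labels
X_train = np.random.rand(80, 10)
Y_train = np.random.randint(0, 2, size=(80, 1))
X_test = np.random.rand(20, 10)
Y_test = np.random.randint(0, 2, size=(20, 1))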

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

Visualization Example

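from LatentBoostClassifier import visualize_hybrid_model

# Visualize results (uses the models returned by parallel_train above)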
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Contributions

Feel free to fork, improve, and submit a pull request! For major changes, please open an issue first to discuss what you'd like to change.


License

This project is licensed under the MIT License. See the LICENSE file for details.


GitHub: https://github.com/AliBavarchee/

LinkedIn: https://www.linkedin.com/in/ali-bavarchee-qip/

Download files

Source Distribution

LatentBoostClassifier-1.2.2.tar.gz (252.0 kB)

Uploaded Source

Built Distribution


LatentBoostClassifier-1.2.2-py3-none-any.whl (250.0 kB)

Uploaded Python 3

File details

Details for the file LatentBoostClassifier-1.2.2.tar.gz.

File metadata

  • Download URL: LatentBoostClassifier-1.2.2.tar.gz
  • Upload date:
  • Size: 252.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.7

File hashes

Hashes for LatentBoostClassifier-1.2.2.tar.gz
Algorithm Hash digest
SHA256 51dc11ccb0eb6ab642a6337b1f143105f3035c1e1f39619eab49a22cd1cd7ae3
MD5 1b7a0dc3f5abe3bccc9e7dc3f4416386
BLAKE2b-256 2a07d349adc45f57b47996853bf573e24f251c930c41c2124df6fea86312868d


File details

Details for the file LatentBoostClassifier-1.2.2-py3-none-any.whl.

File metadata

File hashes

Hashes for LatentBoostClassifier-1.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 951e827d88c022c1f864c5c9728c6b8faf98218751310a3818d399af7b692c58
MD5 307536e6454c6f17645ad9ffb9a63e25
BLAKE2b-256 a08d90aba7c4ea6ae436492909233a134b2b0f13c0bf7fbc78204c1c355e6cae

