
A hybrid generative model combining CVAE, CGAN, and Random Forest.


LatentBoostClassifier: Hybrid Generative Model: CVAE + CGAN + RF

This repository implements a hybrid generative model combining a Conditional Variational Autoencoder (CVAE), a Conditional Generative Adversarial Network (CGAN), and a Random Forest Classifier (RFC). The hybrid model is designed for complex classification tasks: the CVAE extracts latent features, the CGAN generates synthetic data, and the RFC performs robust final classification.


Features

  1. Conditional Variational Autoencoder (CVAE):

    • Learns a latent representation of input data.
    • Uses a custom loss function combining reconstruction loss and KL divergence.
  2. Conditional Generative Adversarial Network (CGAN):

    • Generates high-quality synthetic data conditioned on class labels.
  3. Random Forest Classifier:

    • Trained on a combination of real and synthetic features to enhance classification performance.
  4. Parallel Model Training:

    • The CVAE and CGAN models are trained concurrently using Python's multiprocessing module.
  5. Hyperparameter Tuning:

    • Utilizes Keras Tuner for optimizing the CVAE and CGAN hyperparameters.
    • Random Forest is tuned using GridSearchCV.
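The parallel-training idea in item 4 can be illustrated with the standard library alone. This is a minimal sketch, not the package's `parallel_train`: `train_cvae_stub`, `train_cgan_stub`, `_worker` and `run_in_parallel` are hypothetical names, and the stubs merely stand in for the real Keras training loops. It assumes a POSIX system, because it requests the `fork` start method explicitly so that it also works when pasted into an interactive session.

```python
import multiprocessing as mp


def train_cvae_stub(seed):
    # Hypothetical stand-in for CVAE training; returns a dummy result.
    return {"model": "cvae", "seed": seed}


def train_cgan_stub(seed):
    # Hypothetical stand-in for CGAN training; returns a dummy result.
    return {"model": "cgan", "seed": seed}


def _worker(train_fn, seed, queue, key):
    # Run one training job in a child process and send back its result.
    queue.put((key, train_fn(seed)))


def run_in_parallel(seed=0):
    # "fork" is POSIX-only; on Windows use the default "spawn" context
    # and call this function under an `if __name__ == "__main__":` guard.
    ctx = mp.get_context("fork")
    queue = ctx.Queue()
    jobs = (("cvae", train_cvae_stub), ("cgan", train_cgan_stub))
    procs = [ctx.Process(target=_worker, args=(fn, seed, queue, key))
             for key, fn in jobs]
    for p in procs:
        p.start()
    # Collect both results before joining, so no child blocks on a full queue.
    results = dict(queue.get() for _ in procs)
    for p in procs:
        p.join()
    return results["cvae"], results["cgan"]


cvae_result, cgan_result = run_in_parallel(seed=42)
```

With real TensorFlow models, the `spawn` start method together with an `if __name__ == "__main__":` guard is usually the safer choice, because TensorFlow is not designed to be forked after it has initialized.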

Installation

Prerequisites

  • Python 3.7+
  • TensorFlow 2.8+
  • scikit-learn
  • Keras Tuner
  • Matplotlib
  • Seaborn
  • multiprocessing (part of the Python standard library; no separate installation needed)

Install Required Packages

pip install tensorflow keras-tuner scikit-learn matplotlib seaborn


Install the Package

The package can be installed from PyPI:

pip install LatentBoostClassifier

Alternatively, install directly from GitHub:

pip install git+https://github.com/AliBavarchee/LatentBoostClassifier.git

Usage


Using the Package

After installation, the package can be imported and used like this:

from LatentBoostClassifier import parallel_train, visualize_hybrid_model

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Training the Hybrid Model

Train the CVAE, CGAN, and Random Forest models using your dataset:

# Import the training function
from hybrid_model import parallel_train

# Define training and testing datasets
X_train, Y_train = ...  # Load or preprocess your training data
X_test, Y_test = ...    # Load or preprocess your test data

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

Visualizing Results

Evaluate and visualize the performance of the hybrid model:

from hybrid_model import visualize_hybrid_model

# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)

Code Overview

1. CVAE

  • Purpose: Extract latent features from input data.
  • Key Components:
    • Custom loss function: Combines reconstruction loss and KL divergence.
    • Encoder-Decoder architecture.
  • Hyperparameter Tuning:
    • Latent dimensions.
    • Dense layer units.
    • Learning rate.
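The CVAE objective (reconstruction loss plus KL divergence) can be written out as follows. This NumPy sketch assumes a diagonal Gaussian encoder that outputs `z_mean` and `z_log_var`, a standard normal prior, and squared error as the reconstruction term; the package may use a different reconstruction loss or weighting. `cvae_loss` and `reparameterize` are illustrative names, not part of the package API. In a conditional VAE the class label is typically concatenated to the encoder and decoder inputs, so it does not appear in the loss itself.

```python
import numpy as np


def cvae_loss(x, x_recon, z_mean, z_log_var):
    # Reconstruction term: sum of squared errors per sample (an assumption;
    # binary cross-entropy is a common alternative).
    recon = np.sum((x - x_recon) ** 2, axis=1)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) per sample.
    kl = -0.5 * np.sum(1 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=1)
    # Average the combined loss over the batch.
    return float(np.mean(recon + kl))


def reparameterize(z_mean, z_log_var, rng):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # which keeps sampling differentiable with respect to mu and log_var.
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps
```

For a perfect reconstruction with `z_mean = 0` and `z_log_var = 0`, both terms vanish and the loss is 0.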

2. CGAN

  • Purpose: Generate synthetic data conditioned on class labels.
  • Key Components:
    • Generator: Produces synthetic samples.
    • Discriminator: Distinguishes real and synthetic samples.
  • Hyperparameter Tuning:
    • Latent dimensions.
    • Dense layer units.
    • Learning rate.
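A common way to condition both CGAN networks on the class label, assumed here because this README does not spell out the package's exact architecture, is to concatenate a one-hot label vector to the generator's noise input and to the discriminator's sample input. The helpers `one_hot`, `generator_input` and `discriminator_input` are hypothetical and only build the input tensors; the networks themselves are omitted.

```python
import numpy as np


def one_hot(labels, num_classes):
    # Accept integer labels of shape (n,) or (n, 1) and return (n, num_classes).
    labels = np.asarray(labels).reshape(-1)
    return np.eye(num_classes)[labels]


def generator_input(labels, noise_dim, num_classes, rng):
    # Generator input: Gaussian noise concatenated with the one-hot class label.
    onehot = one_hot(labels, num_classes)
    noise = rng.standard_normal((onehot.shape[0], noise_dim))
    return np.concatenate([noise, onehot], axis=1)


def discriminator_input(samples, labels, num_classes):
    # Discriminator input: a real or synthetic sample concatenated with its label.
    return np.concatenate([samples, one_hot(labels, num_classes)], axis=1)
```

For example, three labels with an 8-dimensional noise vector and 2 classes give a generator input of shape (3, 10), whose last two columns are the one-hot labels.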

3. Random Forest

  • Purpose: Classify combined real and synthetic data.
  • Key Components:
    • Trained on latent features (CVAE) and synthetic data (CGAN).
    • Grid search for hyperparameter optimization.
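Training the forest on combined real and synthetic data with a grid search can be sketched with scikit-learn as below. This is an illustration under stated assumptions: `fit_rf_on_real_and_synthetic` is a hypothetical helper, the synthetic arrays stand in for CGAN output (which must have the same feature layout as the real data), and the small parameter grid is chosen for speed; the package's actual grid is not documented here.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV


def fit_rf_on_real_and_synthetic(X_real, y_real, X_synth, y_synth, seed=0):
    # Stack real and synthetic samples into one training set; ravel() turns
    # (n, 1) label arrays, as used in the README examples, into (n,).
    X = np.vstack([X_real, X_synth])
    y = np.concatenate([np.ravel(y_real), np.ravel(y_synth)])
    # Small illustrative grid (assumption, not the package's grid).
    param_grid = {"n_estimators": [10, 50], "max_depth": [None, 5]}
    search = GridSearchCV(
        RandomForestClassifier(random_state=seed),
        param_grid,
        cv=3,
        scoring="accuracy",
    )
    search.fit(X, y)
    # best_estimator_ is refit on all of X, y with the best parameters.
    return search.best_estimator_, search.best_params_
```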

Visualization

Classification Performance

  • Confusion Matrix: Visualize actual vs. predicted class distributions.
  • Classification Report: Display precision, recall, F1-score, and accuracy metrics.
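The two classification summaries above map onto standard scikit-learn functions. The helper `summarize_predictions` below is a hypothetical sketch, not part of the package; it returns the raw numbers, which can then be plotted, for instance with `seaborn.heatmap(cm, annot=True)`.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


def summarize_predictions(y_true, y_pred):
    # Flatten (n, 1) label arrays, as used in the README examples, to (n,).
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    # Rows are actual classes, columns are predicted classes.
    cm = confusion_matrix(y_true, y_pred)
    # output_dict=True gives per-class precision/recall/F1 plus overall accuracy.
    report = classification_report(y_true, y_pred, output_dict=True)
    return cm, report


cm, report = summarize_predictions([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
# cm.tolist() == [[1, 1], [1, 2]]; report["accuracy"] == 0.6
```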

Latent Space

  • Use t-SNE to visualize the CVAE's latent space.
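A t-SNE projection of latent vectors can be sketched with scikit-learn's `TSNE`. Here the latent vectors are random placeholders for the CVAE encoder's output (for example its `z_mean`), and `embed_latent_2d` is a hypothetical helper; the resulting 2-D points would typically be passed to `matplotlib.pyplot.scatter`, coloured by class label.

```python
import numpy as np
from sklearn.manifold import TSNE


def embed_latent_2d(latent, perplexity=10.0, seed=0):
    # Project latent vectors of shape (n_samples, latent_dim) to 2-D.
    # perplexity must be smaller than the number of samples.
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=seed)
    return tsne.fit_transform(latent)


# Placeholder latent codes standing in for CVAE encoder output.
latent = np.random.default_rng(0).standard_normal((40, 8))
embedding = embed_latent_2d(latent)
```

t-SNE preserves local neighbourhoods rather than global distances, so well-separated clusters suggest class structure in the latent space, while distances between clusters should not be over-interpreted.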

Synthetic Data

  • Compare real and synthetic data distributions using KDE plots.
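Seaborn's `kdeplot` draws these curves; the sketch below computes the same kind of per-feature estimate numerically with `scipy.stats.gaussian_kde`, so real and synthetic distributions can also be compared without a plot. The overlap coefficient (area under the minimum of the two densities) is an added summary of our choosing, not something the package is documented to report; `kde_curves` and `density_overlap` are hypothetical helpers.

```python
import numpy as np
from scipy.stats import gaussian_kde


def kde_curves(real_feature, synth_feature, num_points=200):
    # Evaluate both 1-D kernel density estimates on a shared grid.
    real_feature = np.ravel(real_feature)
    synth_feature = np.ravel(synth_feature)
    lo = min(real_feature.min(), synth_feature.min())
    hi = max(real_feature.max(), synth_feature.max())
    grid = np.linspace(lo, hi, num_points)
    return grid, gaussian_kde(real_feature)(grid), gaussian_kde(synth_feature)(grid)


def density_overlap(grid, p, q):
    # Overlap coefficient: trapezoidal integral of min(p, q) over the grid.
    # Values near 1 mean similar distributions, near 0 mean disjoint ones.
    m = np.minimum(p, q)
    return float(np.sum((m[1:] + m[:-1]) / 2 * np.diff(grid)))
```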

ROC Curve

  • Evaluate the Random Forest classifier with an ROC curve and compute AUC.
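For a binary task, the ROC curve and AUC can be computed from the forest's positive-class probabilities with scikit-learn. This is a sketch, not the package's plotting code: `roc_for_classifier` is a hypothetical helper, and `X_test` must be whatever feature matrix the forest was actually trained on (in this model, possibly latent features rather than raw inputs).

```python
import numpy as np
from sklearn.metrics import auc, roc_curve


def roc_for_classifier(model, X_test, y_test):
    # Probability of the positive class (binary labels assumed).
    scores = model.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(np.ravel(y_test), scores)
    # Area under the ROC curve via the trapezoidal rule.
    return fpr, tpr, auc(fpr, tpr)
```

For labels `[0, 0, 1, 1]` and positive-class scores `[0.1, 0.4, 0.35, 0.8]`, the AUC is 0.75. The `fpr` and `tpr` arrays can be drawn with `matplotlib.pyplot.plot(fpr, tpr)`.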

Key Functions

Model Training

parallel_train(X_train, Y_train, X_test, Y_test)
  • Trains the CVAE and CGAN models in parallel, then trains the Random Forest on their outputs, and returns best_cvae, best_cgan_generator, and best_rf_model.

Visualization

visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
  • Visualizes the hybrid model's results, including latent space and classification performance.

Examples

Training Example

import numpy as np
from LatentBoostClassifier import parallel_train

# Example dataset: 10 features, binary labels of shape (n, 1)
X_train = np.random.rand(80, 10)
Y_train = np.random.randint(0, 2, size=(80, 1))
X_test = np.random.rand(20, 10)
Y_test = np.random.randint(0, 2, size=(20, 1))

# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)

Visualization Example

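from LatentBoostClassifier import visualize_hybrid_model

# Visualize results using the models returned by parallel_train in the training example
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)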

Contributions

Feel free to fork, improve, and submit a pull request! For major changes, please open an issue first to discuss what you'd like to change.


License

This project is licensed under the MIT License. See the LICENSE file for details.


GitHub: https://github.com/AliBavarchee/

LinkedIn: https://www.linkedin.com/in/ali-bavarchee-qip/

Download files

Source Distribution

LatentBoostClassifier-1.0.0.tar.gz (11.4 kB)

Built Distribution

LatentBoostClassifier-1.0.0-py3-none-any.whl (10.5 kB, Python 3)

File details

Details for the file LatentBoostClassifier-1.0.0.tar.gz.

File metadata

  • Download URL: LatentBoostClassifier-1.0.0.tar.gz
  • Size: 11.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.7

File hashes

Hashes for LatentBoostClassifier-1.0.0.tar.gz
Algorithm Hash digest
SHA256 d020dc9e811bc61110ce637e047636910af4be3488f637e8a76eb24cffdc1e1a
MD5 ba3efe705f24b9d80880b0fd96168e41
BLAKE2b-256 16f1111ebfef9e29963ed3e8c12d903d45b0c4de566f63aff1ef4bbd12a3d75c


File details

Details for the file LatentBoostClassifier-1.0.0-py3-none-any.whl.

File hashes

Hashes for LatentBoostClassifier-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 34728752c854b70f130dbef14b1360238d8cc3d4eef48c1ae471c008e4dae434
MD5 25cce00d1b4d12aac8942f22a43d0bb1
BLAKE2b-256 3d1564be8249d14103f2aed628a371ef2363b2bfc3d23715d78969e4db03d7a7

