A hybrid generative model combining CVAE, CGAN, and Random Forest.
LatentBoostClassifier: Hybrid Generative Model: CVAE + CGAN + RF
| http://dx.doi.org/10.13140/RG.2.2.11351.79522 |
This repository implements a hybrid generative model combining a Conditional Variational Autoencoder (CVAE), a Conditional Generative Adversarial Network (CGAN), and a Random Forest Classifier (RFC). The hybrid model is designed for complex classification tasks by leveraging latent feature extraction (CVAE), synthetic data generation (CGAN), and robust classification (RFC).
Features
- Conditional Variational Autoencoder (CVAE):
  - Learns a latent representation of the input data.
  - Uses a custom loss function combining reconstruction loss and KL divergence.
- Conditional Generative Adversarial Network (CGAN):
  - Generates high-quality synthetic data conditioned on class labels.
- Random Forest Classifier:
  - Trained on a combination of real and synthetic features to enhance classification performance.
- Parallel Model Training:
  - The CVAE and CGAN models are trained concurrently using Python's multiprocessing module.
- Hyperparameter Tuning:
  - Keras Tuner optimizes the CVAE and CGAN hyperparameters.
  - The Random Forest is tuned with GridSearchCV.
Installation
Prerequisites
- Python 3.7+
- TensorFlow 2.8+
- Scikit-learn
- Keras Tuner
- Matplotlib
- Seaborn
- Multiprocessing (part of the Python standard library)
Install Required Packages
pip install tensorflow keras-tuner scikit-learn matplotlib seaborn
Install the Package
The package can be installed via PyPI:
pip install LatentBoostClassifier
Alternatively, install directly from GitHub:
pip install git+https://github.com/AliBavarchee/LatentBoostClassifier.git
Usage
Using the Package
After installation, the package can be imported and used like this:
from LatentBoostClassifier import parallel_train, visualize_hybrid_model
# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)
# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
Training the Hybrid Model
Train the CVAE, CGAN, and Random Forest models using your dataset:
# Import the training function
from LatentBoostClassifier import parallel_train
# Define training and testing datasets
X_train, Y_train = ... # Load or preprocess your training data
X_test, Y_test = ... # Load or preprocess your test data
# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)
Visualizing Results
Evaluate and visualize the performance of the hybrid model:
from LatentBoostClassifier import visualize_hybrid_model
# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
Code Overview
1. CVAE
- Purpose: Extract latent features from input data.
- Key Components:
- Custom loss function: Combines reconstruction loss and KL divergence.
- Encoder-Decoder architecture.
- Hyperparameter Tuning:
- Latent dimensions.
- Dense layer units.
- Learning rate.
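The custom CVAE loss described above can be sketched in plain NumPy (a simplified stand-in for the package's TensorFlow implementation; the function name and the use of MSE as the reconstruction term are illustrative assumptions):

```python
import numpy as np

def cvae_loss(x, x_recon, z_mean, z_log_var, beta=1.0):
    """Combined CVAE loss: reconstruction error plus KL divergence.

    The KL term regularizes the approximate posterior N(z_mean, exp(z_log_var))
    toward the standard normal prior N(0, I).
    """
    # Reconstruction loss: squared error summed over feature dimensions
    recon = np.sum((x - x_recon) ** 2, axis=-1)
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1), per sample
    kl = -0.5 * np.sum(1 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=-1)
    # Average over the batch; beta weights the regularization strength
    return np.mean(recon + beta * kl)
```

With a perfect reconstruction and a posterior equal to the prior, both terms vanish and the loss is zero.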
2. CGAN
- Purpose: Generate synthetic data conditioned on class labels.
- Key Components:
- Generator: Produces synthetic samples.
- Discriminator: Distinguishes real and synthetic samples.
- Hyperparameter Tuning:
- Latent dimensions.
- Dense layer units.
- Learning rate.
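The "conditioned on class labels" part follows the standard CGAN recipe: the generator's latent input is random noise concatenated with a one-hot label vector. A minimal sketch (the function name is hypothetical, not part of the package's API):

```python
import numpy as np

rng = np.random.default_rng(seed=0)

def conditional_latent_input(n_samples, latent_dim, labels, n_classes):
    """Concatenate random noise with one-hot class labels -- the
    conditioned input a CGAN generator maps to synthetic samples."""
    noise = rng.standard_normal((n_samples, latent_dim))
    one_hot = np.eye(n_classes)[np.asarray(labels).ravel()]
    return np.concatenate([noise, one_hot], axis=1)

# Four latent vectors of dimension 8, conditioned on binary labels
z = conditional_latent_input(4, 8, [0, 1, 1, 0], 2)
```

The discriminator receives the same label encoding alongside each (real or synthetic) sample, so both networks learn class-conditional behavior.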
3. Random Forest
- Purpose: Classify combined real and synthetic data.
- Key Components:
- Trained on latent features (CVAE) and synthetic data (CGAN).
- Grid search for hyperparameter optimization.
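The grid search over Random Forest hyperparameters can be sketched with scikit-learn's GridSearchCV (the parameter grid below is an illustrative assumption; the package's actual search space may differ):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Toy stand-in for the combined real + synthetic feature matrix
rng = np.random.default_rng(seed=0)
X = rng.random((60, 5))
y = rng.integers(0, 2, size=60)

# Hypothetical search space
param_grid = {"n_estimators": [50, 100], "max_depth": [3, None]}
search = GridSearchCV(RandomForestClassifier(random_state=0), param_grid, cv=3)
search.fit(X, y)
best_rf = search.best_estimator_  # refit on the full data with the best params
```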
Visualization
Classification Performance
- Confusion Matrix: Visualize actual vs. predicted class distributions.
- Classification Report: Display precision, recall, F1-score, and accuracy metrics.
Latent Space
- Use t-SNE to visualize the CVAE's latent space.
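A minimal t-SNE sketch using scikit-learn, with random latent codes standing in for the CVAE encoder's output:

```python
import numpy as np
from sklearn.manifold import TSNE

# Stand-in latent codes; in practice these come from the CVAE encoder
rng = np.random.default_rng(seed=0)
latent = rng.random((50, 8))

# Embed the latent codes into 2D for plotting; perplexity must be < n_samples
embedded = TSNE(n_components=2, perplexity=10, random_state=0).fit_transform(latent)
# `embedded` is (50, 2): scatter-plot it, colored by class label
```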
Synthetic Data
- Compare real and synthetic data distributions using KDE plots.
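The package's plots use seaborn KDE curves; the density estimate underneath can be sketched with SciPy, using Gaussian stand-ins for one real and one synthetic feature column:

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(seed=0)
real = rng.normal(0.0, 1.0, 500)       # stand-in for a real feature column
synthetic = rng.normal(0.1, 1.1, 500)  # stand-in for the CGAN's output

# Evaluate both kernel density estimates on a shared grid
grid = np.linspace(-4, 4, 200)
real_density = gaussian_kde(real)(grid)
synth_density = gaussian_kde(synthetic)(grid)
# Overlaying the two curves shows how closely the synthetic marginal tracks the real one
```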
ROC Curve
- Evaluate the Random Forest classifier with an ROC curve and compute AUC.
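The ROC/AUC evaluation can be sketched with scikit-learn on toy data (the dataset here is synthetic and illustrative only):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, auc

rng = np.random.default_rng(seed=0)
X = rng.random((100, 5))
y = (X[:, 0] > 0.5).astype(int)  # toy binary target correlated with feature 0

# Fit on a train split, score the held-out split
clf = RandomForestClassifier(random_state=0).fit(X[:70], y[:70])
scores = clf.predict_proba(X[70:])[:, 1]  # probability of the positive class
fpr, tpr, _ = roc_curve(y[70:], scores)
roc_auc = auc(fpr, tpr)  # area under the ROC curve
```

Plotting `fpr` against `tpr` gives the ROC curve; `roc_auc` summarizes it in one number.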
Key Functions
Model Training
parallel_train(X_train, Y_train, X_test, Y_test)
- Trains the CVAE, CGAN, and Random Forest models in parallel.
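The parallel-training pattern can be sketched with the standard library's multiprocessing module. The two trainer functions below are hypothetical stand-ins for the package's real CVAE/CGAN routines, which return fitted models rather than strings:

```python
from multiprocessing import Process, Queue

# Hypothetical stand-ins for the package's training routines
def train_cvae(queue):
    queue.put(("cvae", "best_cvae"))

def train_cgan(queue):
    queue.put(("cgan", "best_cgan_generator"))

def parallel_train_sketch():
    """Run both trainers concurrently, collecting results via a shared Queue."""
    queue = Queue()
    procs = [Process(target=train_cvae, args=(queue,)),
             Process(target=train_cgan, args=(queue,))]
    for p in procs:
        p.start()
    results = dict(queue.get() for _ in procs)  # blocks until both report
    for p in procs:
        p.join()
    return results

if __name__ == "__main__":
    print(parallel_train_sketch())
```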
Visualization
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
- Visualizes the hybrid model's results, including latent space and classification performance.
Examples
Training Example
# Example dataset
import numpy as np
from LatentBoostClassifier import parallel_train

X_train = np.random.rand(80, 10)
Y_train = np.random.randint(0, 2, size=(80, 1))
X_test = np.random.rand(20, 10)
Y_test = np.random.randint(0, 2, size=(20, 1))
# Train the hybrid model
best_cvae, best_cgan_generator, best_rf_model = parallel_train(X_train, Y_train, X_test, Y_test)
Visualization Example
# Visualize results
visualize_hybrid_model(best_cvae, best_cgan_generator, best_rf_model, X_test, Y_test, X_train, Y_train)
Contributions
Feel free to fork, improve, and submit a pull request! For major changes, please open an issue first to discuss what you'd like to change.
License
This project is licensed under the MIT License. See the LICENSE file for details.
=============================================
| https://github.com/AliBavarchee/ |
| https://www.linkedin.com/in/ali-bavarchee-qip/ |
File details
Details for the file LatentBoostClassifier-1.2.4.tar.gz.
File metadata
- Download URL: LatentBoostClassifier-1.2.4.tar.gz
- Upload date:
- Size: 413.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1be6a478f5bad17cb5b50fda9f6b8b3fc9f164c3770d698a325274eaa8bd68f7 |
| MD5 | 541880832eb8d6b7ecf338e3382db604 |
| BLAKE2b-256 | 591794448f8fac41d446e72a127bce172d120f3b655094b48ab6b87643246e3c |
File details
Details for the file LatentBoostClassifier-1.2.4-py3-none-any.whl.
File metadata
- Download URL: LatentBoostClassifier-1.2.4-py3-none-any.whl
- Upload date:
- Size: 409.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 75f011138ce3ad50182069341b15fbae55fbfdd86b9043ec2109365bb9d0fda3 |
| MD5 | a8c7faa42967fb3966ed4cfc2e02d404 |
| BLAKE2b-256 | 25bd782030846cb462758f0ec2898569d93dc073641319f6d13e4803caa9adc9 |