Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE) for Missing Data Imputation
CISS-VAE
Python implementation of the Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE)
CISS-VAE is a flexible deep learning model for missing data imputation that accommodates all three types of missing data mechanisms: Missing Completely At Random (MCAR), Missing At Random (MAR), and Missing Not At Random (MNAR). While it is particularly well-suited to MNAR scenarios where missingness patterns carry informative signals, CISS-VAE also functions effectively under MAR assumptions.
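The difference between the three mechanisms is easiest to see on simulated data. The NumPy sketch below is a generic illustration, not part of the ciss_vae package, and `make_missing` is a hypothetical helper. It removes about 30% of a variable `y` completely at random (MCAR), based on an always-observed covariate `x` (MAR), or based on the value of `y` itself (MNAR).

```python
import numpy as np

def make_missing(x, y, mechanism, rate=0.3, seed=0):
    """Return a copy of y with entries set to NaN under a toy missingness mechanism.

    x: fully observed covariate (drives MAR missingness)
    y: variable that receives missing values
    """
    rng = np.random.default_rng(seed)
    if mechanism == "MCAR":
        # Missingness is independent of all data, observed or not.
        missing = rng.random(len(y)) < rate
    elif mechanism == "MAR":
        # Missingness depends only on the observed covariate x.
        missing = x > np.quantile(x, 1 - rate)
    elif mechanism == "MNAR":
        # Missingness depends on the (unobserved) value of y itself.
        missing = y > np.quantile(y, 1 - rate)
    else:
        raise ValueError(f"unknown mechanism: {mechanism}")
    y_obs = y.astype(float).copy()
    y_obs[missing] = np.nan
    return y_obs

rng = np.random.default_rng(42)
x = rng.normal(size=1000)
y = 2.0 * x + rng.normal(size=1000)
for mech in ("MCAR", "MAR", "MNAR"):
    y_obs = make_missing(x, y, mech)
    print(mech, "missing fraction:", np.isnan(y_obs).mean(),
          "mean of observed y:", round(np.nanmean(y_obs), 3))
```

Comparing the printed means with `y.mean()` shows the practical difference: under MCAR the observed mean stays close to the full-data mean, under MAR the shift can in principle be explained by `x`, and under MNAR the information needed to correct it lies in the missing values themselves.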
Additionally, CISS-VAE incorporates an iterative impute-refit learning procedure; a validation-based convergence criterion is recommended to avoid overfitting. This procedure significantly improves imputation accuracy compared to traditional Variational Autoencoder training approaches in the presence of missing values. Overall, CISS-VAE adapts across a range of missing data mechanisms, leveraging clustering only when it offers clear benefits and delivering robust, accurate imputations under varying conditions of missingness.
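To make the impute-refit idea and the validation-based stopping rule concrete, here is a toy NumPy sketch. It is not the CISS-VAE implementation: a per-column least-squares regression stands in for the VAE (an assumption for illustration only), and `impute_refit` is a hypothetical helper. It holds out a fraction of the observed entries (compare `val_proportion` in the Quickstart), alternates refitting and re-imputing, and stops once the held-out error has failed to improve for `patience` consecutive loops (compare `max_loops` and `patience`).

```python
import numpy as np

def impute_refit(X, val_proportion=0.1, max_loops=100, patience=2, seed=0):
    """Toy impute-refit loop with validation-based early stopping.

    Stand-in model (assumption): each column is re-predicted by least squares
    on the other, currently imputed, columns. This is NOT the CISS-VAE model.
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    observed = ~np.isnan(X)
    # Hold out a random fraction of the observed entries for validation.
    val_mask = observed & (rng.random(X.shape) < val_proportion)
    train = np.where(val_mask, np.nan, X)
    to_impute = np.isnan(train)  # truly missing + held-out entries
    # Initial imputation: column means of the training entries.
    current = np.where(to_impute, np.nanmean(train, axis=0), train)
    best, best_loss, since_best = current.copy(), np.inf, 0
    for _ in range(max_loops):
        refit = current.copy()
        for j in range(X.shape[1]):
            A = np.column_stack([np.delete(current, j, axis=1), np.ones(len(X))])
            fit_rows = ~to_impute[:, j]
            coef, *_ = np.linalg.lstsq(A[fit_rows], train[fit_rows, j], rcond=None)
            refit[to_impute[:, j], j] = A[to_impute[:, j]] @ coef
        current = refit
        # Validation loss on the held-out observed entries.
        val_loss = np.mean((current[val_mask] - X[val_mask]) ** 2)
        if val_loss < best_loss:
            best, best_loss, since_best = current.copy(), val_loss, 0
        else:
            since_best += 1
            if since_best >= patience:
                break
    best[val_mask] = X[val_mask]  # put the held-out observed values back
    return best, best_loss

rng = np.random.default_rng(1)
z = rng.normal(size=(500, 1))
X_full = np.hstack([z, 2 * z, -z]) + 0.1 * rng.normal(size=(500, 3))
X = X_full.copy()
X[rng.random(X.shape) < 0.2] = np.nan  # about 20% MCAR missingness
X_imp, val_mse = impute_refit(X, val_proportion=0.1, patience=2)
print("best validation MSE:", val_mse)
```

The dataset returned is the one from the loop with the lowest validation error, not the last loop, which is what keeps extra loops after the optimum from degrading the result.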
Installation
The CISS-VAE package is currently available for Python, with an R package to be released soon (see the rCISSVAE GitHub page for updates). It can be installed from either PyPI or GitHub.
# From PyPI
pip install ciss-vae
OR
# From GitHub (latest development version)
pip install git+https://github.com/CISS-VAE/CISS-VAE-python.git
Important!
For run_cissvae to be able to handle clustering, please install the clustering dependencies scikit-learn, leidenalg, and python-igraph with pip:
pip install scikit-learn leidenalg python-igraph
OR
pip install "ciss-vae[clustering]"
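To check which of these optional dependencies are importable in the current environment, a small standard-library sketch follows. The helper name `missing_clustering_deps` is hypothetical, and it only tests whether each module can be found, not its version or how run_cissvae uses it.

```python
import importlib.util

# Optional clustering dependencies: pip package name -> importable module name.
CLUSTERING_DEPS = {
    "scikit-learn": "sklearn",
    "leidenalg": "leidenalg",
    "python-igraph": "igraph",
}

def missing_clustering_deps():
    """Return the pip names of clustering dependencies that cannot be found."""
    return [pkg for pkg, module in CLUSTERING_DEPS.items()
            if importlib.util.find_spec(module) is None]

missing = missing_clustering_deps()
if missing:
    print("Install before letting run_cissvae cluster your data:", " ".join(missing))
else:
    print("All clustering dependencies are available.")
```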
Quickstart Tutorial
The full vignette can be found here.
The input dataset should be one of the following:
- A Pandas DataFrame
- A NumPy array
- A PyTorch tensor
Missing values should be represented using np.nan or None.
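If your data come from several sources, it can help to confirm that every missing entry really is a NaN in a float array before calling run_cissvae. The sketch below uses a hypothetical helper, `to_float_with_nan`, that is not a ciss_vae function; it assumes all columns are numeric and covers DataFrames, NumPy arrays, and nested lists (PyTorch tensors are left out to avoid a torch dependency).

```python
import numpy as np
import pandas as pd

def to_float_with_nan(data):
    """Return a float NumPy array in which every missing entry is np.nan.

    Assumes numeric data. None and np.nan map to NaN; in DataFrames, pd.NA does too.
    """
    if isinstance(data, pd.DataFrame):
        return data.to_numpy(dtype=float, na_value=np.nan)
    # NumPy converts None to NaN when casting to a float dtype.
    return np.array(data, dtype=float)

df = pd.DataFrame({"a": [1.0, None, 3.0], "b": [4.0, 5.0, np.nan]})
X = to_float_with_nan(df)
print(np.isnan(X).sum(axis=0))  # missing count per column: [1 1]
mask = ~np.isnan(X)  # True where a value was observed
```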
import pandas as pd
from ciss_vae.utils.run_cissvae import run_cissvae
# optional, display vae architecture
from ciss_vae.utils.helpers import plot_vae_architecture
data = pd.read_csv("/data/test_data.csv")
clusters = data.clusters
data = data.drop(columns = ["clusters", "Unnamed: 0"])
imputed_data, vae = run_cissvae(data = data,
    ## Dataset params
    val_proportion = 0.1, ## Fraction of non-missing data held out for validation
    replacement_value = 0.0,
    columns_ignore = data.columns[:5], ## Names of columns to ignore when selecting the validation dataset (and when clustering, if you do not provide clusters), e.g. demographic columns with no missingness
    print_dataset = True,
    ## Cluster params
    clusters = None, ## Your cluster labels go here (e.g. pass the `clusters` column loaded above). If None, run_cissvae clusters the data for you
    n_clusters = None, ## Number of clusters, if you want run_cissvae to do the clustering and you know how many clusters your data should have
    seed = 42,
    ## VAE model params
    hidden_dims = [150, 120, 60], ## Sizes of the hidden layers, in order; one number per layer
    latent_dim = 15, ## Dimension of the latent embedding
    layer_order_enc = ["unshared", "unshared", "unshared"], ## Shared vs. unshared layers for the encoder, one entry per hidden layer
    layer_order_dec = ["shared", "shared", "shared"], ## Shared vs. unshared layers for the decoder, one entry per hidden layer
    latent_shared = False,
    output_shared = False,
    batch_size = 4000, ## Batch size for the data loader
    return_model = True, ## If True, returns the imputed dataset and the model; otherwise returns only the imputed dataset. Must be True to pass the model to plot_vae_architecture
    ## Initial training params
    epochs = 1000, ## default
    initial_lr = 0.01, ## default
    decay_factor = 0.999, ## default; the learning rate is multiplied by this factor after each epoch, which helps limit overfitting
    beta = 0.001, ## default
    device = None, ## If None, uses a GPU if available, otherwise the CPU. See the torch.device documentation for details
    ## Impute-refit loop params
    max_loops = 100, ## Maximum number of refit loops
    patience = 2, ## Number of loops to continue after the last best_dataset update before stopping. Increase to avoid stopping at a local optimum
    epochs_per_loop = None, ## If None, same as epochs
    initial_lr_refit = None, ## If None, continues from the learning rate at the end of initial training
    decay_factor_refit = None, ## If None, same as decay_factor
    beta_refit = None, ## If None, same as beta
    verbose = False
)
## OPTIONAL - PLOT VAE ARCHITECTURE
plot_vae_architecture(model = vae,
    title = None, ## Set title of plot
    ## Colors below are the defaults
    color_shared = "skyblue",
    color_unshared = "lightcoral",
    color_latent = "gold",
    color_input = "lightgreen",
    color_output = "lightgreen",
    figsize = (16, 8),
    return_fig = False)
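The encoder and decoder settings in the example (`hidden_dims`, `layer_order_enc`, `layer_order_dec`) decide, layer by layer, whether weights are shared across clusters. The NumPy sketch below illustrates one plausible reading of that choice; it is an assumption based on the parameter names and comments, not the package's implementation, and `init_layers`, `forward`, and `count_params` are hypothetical helpers. Each "unshared" layer keeps a separate weight matrix per cluster and routes every row through its own cluster's copy, while a "shared" layer uses one matrix for all rows. The latent layer, the decoder, and training are omitted.

```python
import numpy as np

def init_layers(input_dim, hidden_dims, layer_order, n_clusters, seed=0):
    """Create weights for a stack of dense layers.

    A 'shared' layer gets one (W, b) pair used by every cluster; an
    'unshared' layer gets a separate (W, b) pair for each cluster.
    """
    if len(hidden_dims) != len(layer_order):
        raise ValueError("layer_order needs one entry per hidden layer")
    rng = np.random.default_rng(seed)
    layers, d_in = [], input_dim
    for d_out, kind in zip(hidden_dims, layer_order):
        copies = 1 if kind == "shared" else n_clusters
        params = [(rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out)), np.zeros(d_out))
                  for _ in range(copies)]
        layers.append((kind, params))
        d_in = d_out
    return layers

def forward(x, cluster_ids, layers):
    """Apply each layer with ReLU; unshared layers use each row's cluster weights.

    cluster_ids must hold integers in 0..n_clusters-1, one per row of x.
    """
    h = x
    for kind, params in layers:
        if kind == "shared":
            W, b = params[0]
            out = h @ W + b
        else:
            out = np.empty((h.shape[0], params[0][0].shape[1]))
            for c, (W, b) in enumerate(params):
                rows = cluster_ids == c
                out[rows] = h[rows] @ W + b
        h = np.maximum(out, 0.0)
    return h

def count_params(layers):
    """Total number of weights and biases across all layers and cluster copies."""
    return sum(W.size + b.size for _, params in layers for W, b in params)

# Usage with the Quickstart's hidden_dims, a hypothetical 20-column input and 3 clusters.
rng = np.random.default_rng(1)
x = rng.normal(size=(8, 20))
cluster_ids = np.array([0, 0, 1, 1, 2, 2, 0, 1])
enc_unshared = init_layers(20, [150, 120, 60], ["unshared"] * 3, n_clusters=3)
all_shared = init_layers(20, [150, 120, 60], ["shared"] * 3, n_clusters=3)
print(forward(x, cluster_ids, enc_unshared).shape)  # (8, 60)
print(count_params(enc_unshared), count_params(all_shared))
```

In this reading, a fully unshared stack has `n_clusters` times as many parameters as a fully shared one; mixing the two, as the Quickstart does with an unshared encoder and a shared decoder, trades cluster-specific capacity against parameter count.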
The CISS-VAE package includes the option to perform automated hyperparameter tuning with Optuna.
See the tutorial for more details.
Project details
File details
Details for the file ciss_vae-1.0.1.tar.gz.
File metadata
- Download URL: ciss_vae-1.0.1.tar.gz
- Size: 57.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 438f46335d8092d4aacea42b33e3696db901993d18c7219f8cefe649d01f1248 |
| MD5 | 293f91c046772fde9f78f1f2c521f149 |
| BLAKE2b-256 | 583a9af0c5781f0bdaccd6d9d73b0082f29d649776cb031a0f903e1bb2e55825 |
Provenance
The following attestation bundles were made for ciss_vae-1.0.1.tar.gz:
- Publisher: pypi.yml on CISS-VAE/CISS-VAE-python
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: ciss_vae-1.0.1.tar.gz
- Subject digest: 438f46335d8092d4aacea42b33e3696db901993d18c7219f8cefe649d01f1248
- Sigstore transparency entry: 722612719
- Permalink: CISS-VAE/CISS-VAE-python@4ecfaff7daba787a045c2cf7a09b55e33f44eff5
- Branch / Tag: refs/tags/v.1.0.1
- Owner: https://github.com/CISS-VAE
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: pypi.yml@4ecfaff7daba787a045c2cf7a09b55e33f44eff5
- Trigger Event: release
File details
Details for the file ciss_vae-1.0.1-py3-none-any.whl.
File metadata
- Download URL: ciss_vae-1.0.1-py3-none-any.whl
- Size: 57.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7d3fee7b522497ca4c11cdaeada385af1c156c650d362e98241ed6baf8ca72c5 |
| MD5 | 276fa6fd180364e1dac86a4e0d6dd132 |
| BLAKE2b-256 | c22cf004d94e1629fd60a61a2f0df2b067c40246f17fd466e3743bfdf1dc28b9 |
Provenance
The following attestation bundles were made for ciss_vae-1.0.1-py3-none-any.whl:
- Publisher: pypi.yml on CISS-VAE/CISS-VAE-python
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: ciss_vae-1.0.1-py3-none-any.whl
- Subject digest: 7d3fee7b522497ca4c11cdaeada385af1c156c650d362e98241ed6baf8ca72c5
- Sigstore transparency entry: 722612875
- Permalink: CISS-VAE/CISS-VAE-python@4ecfaff7daba787a045c2cf7a09b55e33f44eff5
- Branch / Tag: refs/tags/v.1.0.1
- Owner: https://github.com/CISS-VAE
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: pypi.yml@4ecfaff7daba787a045c2cf7a09b55e33f44eff5
- Trigger Event: release
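To check a downloaded file against the published SHA256 digests, a standard-library sketch follows; `sha256_of` and `verify` are hypothetical helper names, and the digests are copied from the file details for release 1.0.1. pip can also enforce hashes itself in hash-checking mode (`--require-hashes`).

```python
import hashlib
from pathlib import Path

# Published SHA256 digests for the ciss_vae 1.0.1 release files.
EXPECTED = {
    "ciss_vae-1.0.1.tar.gz": "438f46335d8092d4aacea42b33e3696db901993d18c7219f8cefe649d01f1248",
    "ciss_vae-1.0.1-py3-none-any.whl": "7d3fee7b522497ca4c11cdaeada385af1c156c650d362e98241ed6baf8ca72c5",
}

def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected_sha256):
    """Return True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.lower()

path = Path("ciss_vae-1.0.1.tar.gz")
if path.exists():
    print(path.name, "OK" if verify(path, EXPECTED[path.name]) else "HASH MISMATCH")
```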