Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE) for Missing Data Imputation

CISS-VAE

Python implementation of the Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE)

CISS-VAE is a flexible deep learning model for missing data imputation that accommodates all three types of missing data mechanisms: Missing Completely At Random (MCAR), Missing At Random (MAR), and Missing Not At Random (MNAR). While it is particularly well-suited to MNAR scenarios where missingness patterns carry informative signals, CISS-VAE also functions effectively under MAR assumptions. Please see our publication for more details.

Example CISS-VAE for Imputation Workflow

A key feature of CISS-VAE is the use of unsupervised clustering to capture distinct patterns of missingness. Alongside cluster-specific representations, the method leverages shared encoder and decoder layers. This allows for knowledge transfer across clusters and enhances parameter stability, which is especially important when some clusters have small sample sizes. In situations where the data do not naturally partition into meaningful clusters, the model defaults to a pooled representation, preventing unnecessary complications from cluster-specific components.
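The clustering step operates on patterns of missingness rather than on the observed values themselves. As a minimal illustration of the idea (not the package's actual, configurable clustering algorithm), rows can be grouped by their binary missingness masks:

```python
import numpy as np

# Toy data: 6 samples, 3 features, with two distinct missingness patterns.
data = np.array([
    [1.0, np.nan, 3.0],
    [2.0, np.nan, 1.0],
    [np.nan, 5.0, 2.0],
    [4.0, np.nan, 0.5],
    [np.nan, 7.0, 1.5],
    [np.nan, 6.0, 2.5],
])

# Binary missingness mask: 1 where a value is missing.
mask = np.isnan(data).astype(int)

# Give each unique missingness pattern its own cluster label.
_, cluster_labels = np.unique(mask, axis=0, return_inverse=True)
print(cluster_labels)  # -> [0 0 1 0 1 1]
```

In the package itself, passing `clusters = None` to `run_cissvae` triggers this kind of unsupervised grouping automatically.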

Additionally, CISS-VAE incorporates an iterative learning procedure, with a validation-based convergence criterion recommended to avoid overfitting. This procedure significantly improves imputation accuracy compared to traditional Variational Autoencoder training approaches in the presence of missing values. Overall, CISS-VAE adapts across a range of missing data mechanisms, leveraging clustering only when it offers clear benefits, and delivering robust, accurate imputations under varying conditions of missingness.
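The iterative procedure alternates between imputing missing entries with the current model and refitting on the completed data, stopping once held-out validation error stops improving. A schematic of the patience-based convergence logic (hypothetical helper, not the package's internals) looks like:

```python
def impute_refit_loop(val_errors, max_loops=100, patience=2):
    """Schematic of patience-based convergence on validation error.

    `val_errors` stands in for the validation error observed after each
    impute-refit loop; the real procedure computes it from a held-out
    slice of observed values.
    """
    best_err = float("inf")
    best_loop = -1
    since_best = 0
    for loop, err in enumerate(val_errors[:max_loops]):
        if err < best_err:
            best_err, best_loop = err, loop  # new best imputation so far
            since_best = 0
        else:
            since_best += 1
            if since_best >= patience:  # no improvement for `patience` loops
                break
    return best_loop, best_err

# Validation error improves, then plateaus: stop after 2 non-improving loops.
print(impute_refit_loop([0.9, 0.7, 0.6, 0.61, 0.62, 0.5]))  # -> (2, 0.6)
```

The `max_loops` and `patience` knobs here correspond to the parameters of the same names in the quickstart below.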

Installation

The CISS-VAE package is currently available for Python, with an R package to be released soon (see the rCISSVAE GitHub page for updates). It can be installed from either PyPI or GitHub.

# From PyPI
pip install ciss-vae

OR

# From GitHub (latest development version)
pip install git+https://github.com/CISS-VAE/CISS-VAE-python.git

Quickstart Tutorial

The full vignette can be found here.

The input dataset should be one of the following:

- A Pandas DataFrame  

- A NumPy array  

- A PyTorch tensor  

Missing values should be represented using np.nan or None.
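For example, a small DataFrame with missing entries encoded as np.nan (or None) is a valid input:

```python
import numpy as np
import pandas as pd

# Toy input: missing entries are np.nan; None also works in a DataFrame.
df = pd.DataFrame({
    "biomarker_a": [1.2, np.nan, 3.1, 0.4],
    "biomarker_b": [np.nan, 2.2, None, 1.8],
    "age": [54, 61, 47, 39],  # fully observed; a candidate for columns_ignore
})
print(df.isna().sum().sum())  # -> 3 missing values
```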

import pandas as pd
from ciss_vae.utils.run_cissvae import run_cissvae

# Optional: for plotting the VAE architecture
from ciss_vae.utils.helpers import plot_vae_architecture

data = pd.read_csv("/data/test_data.csv")

clusters = data["clusters"] ## pre-computed cluster labels, if available
data = data.drop(columns = ["clusters", "Unnamed: 0"]) ## drop the label column and the CSV index column

imputed_data, vae = run_cissvae(data = data,
## Dataset params
    val_proportion = 0.1, ## Fraction of non-missing data held out for validation
    replacement_value = 0.0, 
    columns_ignore = data.columns[:5], ## Names of columns to ignore when selecting validation dataset (and clustering if you do not provide clusters). For example, demographic columns with no missingness.
    print_dataset = True, 

## Cluster params
    clusters = None, ## Your cluster labels (e.g. the `clusters` series extracted above). If None, clustering is done for you
    n_clusters = None, ## If you want run_cissvae to do clustering and you know how many clusters your data should have
    seed = 42,

## VAE model params
    hidden_dims = [150, 120, 60], ## Dimensions of hidden layers, in order. One number per layer. 
    latent_dim = 15, ## Dimensions of latent embedding
    layer_order_enc = ["unshared", "unshared", "unshared"], ## order of shared vs unshared layers for the encoder
    layer_order_dec = ["shared", "shared", "shared"], ## order of shared vs unshared layers for the decoder
    latent_shared = False, ## whether the latent layer is shared across clusters
    output_shared = False, ## whether the output layer is shared across clusters
    batch_size = 4000, ## batch size for data loader
    return_model = True, ## If True, returns both the imputed dataset and the model (needed for `plot_vae_architecture` below); otherwise returns only the imputed dataset

## Initial Training params
    epochs = 1000, ## default 
    initial_lr = 0.01, ## default
    decay_factor = 0.999, ## default, factor learning rate is multiplied by after each epoch, prevents overfitting
    beta= 0.001, ## default
    device = None, ## If None, uses GPU if available, otherwise CPU. See the torch.device documentation for details

## Impute-refit loop params
    max_loops = 100, ## max number of refit loops
    patience = 2, ## number of loops to wait after best_dataset is updated. Increase to avoid stopping at a local optimum
    epochs_per_loop = None, ## If none, same as epochs
    initial_lr_refit = None, ## If none, picks up from end of initial training
    decay_factor_refit = None, ## If none, same as decay_factor
    beta_refit = None, ## if none, same as beta
    verbose = False
)

## OPTIONAL - PLOT VAE ARCHITECTURE
plot_vae_architecture(model = vae,
                        title = None, ## Set title of plot
                        ## Colors below are default
                        color_shared = "skyblue", 
                        color_unshared ="lightcoral",
                        color_latent = "gold", 
                        color_input = "lightgreen",
                        color_output = "lightgreen",
                        figsize=(16, 8),
                        return_fig = False)

Output of plot_vae_architecture

The CISS-VAE package also includes the option to perform automated hyperparameter tuning with Optuna.

See tutorial for more details.
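Hyperparameter tuning iterates over candidate settings (e.g. latent dimension, learning rate) and keeps the configuration with the lowest validation error; Optuna automates this with smarter samplers and pruning. As a library-free sketch of the underlying idea, with a stand-in objective and a hypothetical search space:

```python
from itertools import product

def objective(params):
    """Stand-in for validation error after training with `params`.
    A real tuning loop would call run_cissvae and score the held-out values."""
    return abs(params["latent_dim"] - 15) * 0.01 + abs(params["lr"] - 0.01)

# Candidate values for two hyperparameters (hypothetical search space).
latent_dims = [5, 10, 15, 20]
learning_rates = [0.001, 0.01, 0.1]

# Exhaustive grid search: keep the configuration with the lowest error.
best = min(
    ({"latent_dim": d, "lr": lr} for d, lr in product(latent_dims, learning_rates)),
    key=objective,
)
print(best)  # -> {'latent_dim': 15, 'lr': 0.01}
```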

Citation

If you use our package in your research, please consider citing our recent publication!

Y. Khadem Charvadeh, K. Seier, K. S. Panageas, D. Vaithilingam, M. Gönen, and Y. Chen, “Clustering-Informed Shared-Structure Variational Autoencoder for Missing Data Imputation in Large-Scale Healthcare Data,” Statistics in Medicine 44, no. 28-30 (2025): e70335, https://doi.org/10.1002/sim.70335.

Authors

  • Yasin Khadem Charvadeh
  • Kenneth Seier
  • Katherine S. Panageas
  • Danielle Vaithilingam
  • Mithat Gönen
  • Yuan Chen (corresponding author, cheny19@mskcc.org)
