
Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE) for Missing Data Imputation

Python implementation of the Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE)

CISS-VAE is a flexible deep learning model for missing data imputation that accommodates all three types of missing data mechanisms: Missing Completely At Random (MCAR), Missing At Random (MAR), and Missing Not At Random (MNAR). While it is particularly well-suited to MNAR scenarios where missingness patterns carry informative signals, CISS-VAE also functions effectively under MAR assumptions.
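To make the distinction concrete, here is a small toy simulation (not part of the package) contrasting MCAR with MNAR: under MNAR, whether a value is missing depends on the value itself, so the observed data are systematically biased.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)

# MCAR: every entry has the same 20% chance of being missing.
mcar_mask = rng.random(1000) < 0.2

# MNAR: the probability of missingness depends on the (unobserved) value
# itself -- here, larger values are more likely to go missing.
p_miss = 1 / (1 + np.exp(-x))          # sigmoid of the value
mnar_mask = rng.random(1000) < p_miss

x_mnar = x.copy()
x_mnar[mnar_mask] = np.nan

# Under MNAR the observed values are shifted: the mean of what remains
# underestimates the true mean because large values preferentially vanished.
print(np.nanmean(x_mnar), x.mean())
```

This informative missingness pattern is exactly the signal CISS-VAE aims to exploit.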

Figure: example CISS-VAE imputation workflow.
A key feature of CISS-VAE is the use of unsupervised clustering to capture distinct patterns of missingness. Alongside cluster-specific representations, the method leverages shared encoder and decoder layers. This allows for knowledge transfer across clusters and enhances parameter stability, which is especially important when some clusters have small sample sizes. In situations where the data do not naturally partition into meaningful clusters, the model defaults to a pooled representation, preventing unnecessary complications from cluster-specific components.
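The shared/unshared idea can be sketched in a few lines of NumPy. This is only an illustration of the parameter-sharing pattern, not the package's actual network: one weight matrix is pooled across all clusters, while each cluster also gets a private matrix.

```python
import numpy as np

rng = np.random.default_rng(1)
n_features, hidden, n_clusters = 8, 4, 3

# One weight matrix shared by every cluster (pooled information) ...
W_shared = rng.normal(size=(n_features, hidden))
# ... plus one private weight matrix per cluster.
W_cluster = [rng.normal(size=(hidden, hidden)) for _ in range(n_clusters)]

def encode(x, cluster_id):
    """Shared layer first, then the cluster-specific layer."""
    h = np.tanh(x @ W_shared)                  # trained on samples from all clusters
    return np.tanh(h @ W_cluster[cluster_id])  # trained only on one cluster

x = rng.normal(size=n_features)
z0, z1 = encode(x, 0), encode(x, 1)
```

Because the shared layer sees every sample, small clusters still benefit from the full dataset, while the private layers capture cluster-specific structure.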

Additionally, CISS-VAE incorporates an iterative learning procedure with a validation-based convergence criterion, recommended to avoid overfitting. This procedure substantially improves imputation accuracy over traditional variational autoencoder training in the presence of missing values. Overall, CISS-VAE adapts across a range of missing data mechanisms, leverages clustering only when it offers clear benefits, and delivers robust, accurate imputations under varying conditions of missingness.
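The validation-based stopping rule follows a familiar early-stopping pattern: keep refitting, remember the best validation error so far, and stop once no improvement has been seen for a fixed number of loops. The sketch below is a schematic with a toy error curve, not the package's training code; the refit-and-score step stands in for one impute-refit loop.

```python
import random

random.seed(0)

def refit_and_score(step):
    """Toy stand-in for one impute-refit loop: returns a validation error
    that improves quickly at first, then plateaus into noise."""
    return 1.0 / (1 + step) + random.uniform(0, 0.01)

best_err, best_step = float("inf"), 0
patience, max_loops = 2, 100
for step in range(max_loops):
    err = refit_and_score(step)
    if err < best_err:
        best_err, best_step = err, step   # new best imputation: keep it
    elif step - best_step >= patience:    # no improvement for `patience` loops
        break                             # validation-based convergence
```

The `max_loops` and `patience` arguments in the quickstart below play exactly these roles.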

Installation

The CISS-VAE package is currently available for Python, with an R package to be released soon (see the rCISSVAE GitHub page for updates). It can be installed from either GitHub or PyPI.

# From PyPI
pip install ciss-vae

OR

# From GitHub (latest development version)
pip install git+https://github.com/CISS-VAE/CISS-VAE-python.git

Important!

For run_cissvae to handle clustering automatically, install the clustering dependencies scikit-learn, leidenalg, and python-igraph with pip.

pip install scikit-learn leidenalg python-igraph

OR

pip install "ciss-vae[clustering]"

Quickstart Tutorial

The full vignette can be found here.

The input dataset should be one of the following:

- A Pandas DataFrame  

- A NumPy array  

- A PyTorch tensor  

Missing values should be represented using np.nan or None.
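For example, a DataFrame or array with NaN entries is a valid input (per the list above, a PyTorch tensor works the same way); None values in float columns are coerced to NaN by pandas:

```python
import numpy as np
import pandas as pd

# Missing entries are np.nan; None is coerced to NaN in float columns.
df = pd.DataFrame({"a": [1.0, np.nan, 3.0], "b": [None, 5.0, 6.0]})
arr = df.to_numpy()   # NumPy array, with NaN marking missing values

n_missing = int(df.isna().sum().sum())   # number of missing cells
```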

import pandas as pd
from ciss_vae.utils.run_cissvae import run_cissvae

# optional: display the VAE architecture
from ciss_vae.utils.helpers import plot_vae_architecture

data = pd.read_csv("/data/test_data.csv")

clusters = data["clusters"]
data = data.drop(columns=["clusters", "Unnamed: 0"])

imputed_data, vae = run_cissvae(
    data=data,
    ## Dataset params
    val_proportion=0.1,  ## fraction of non-missing data held out for validation
    replacement_value=0.0,
    columns_ignore=data.columns[:5],  ## columns to skip when selecting the validation set (and when clustering, if no clusters are provided), e.g. demographic columns with no missingness
    print_dataset=True,

    ## Cluster params
    clusters=None,  ## your cluster labels (e.g. the clusters column extracted above); if None, clustering is done for you
    n_clusters=None,  ## set this if run_cissvae should cluster and you know how many clusters your data should have
    seed=42,

    ## VAE model params
    hidden_dims=[150, 120, 60],  ## hidden-layer dimensions, in order; one number per layer
    latent_dim=15,  ## dimension of the latent embedding
    layer_order_enc=["unshared", "unshared", "unshared"],  ## order of shared vs. unshared layers in the encoder
    layer_order_dec=["shared", "shared", "shared"],  ## order of shared vs. unshared layers in the decoder
    latent_shared=False,
    output_shared=False,
    batch_size=4000,  ## batch size for the data loader
    return_model=True,  ## if True, returns the imputed dataset and the model (needed for plot_vae_architecture); otherwise returns just the imputed dataset

    ## Initial training params
    epochs=1000,  ## default
    initial_lr=0.01,  ## default
    decay_factor=0.999,  ## default; the learning rate is multiplied by this after each epoch, which helps prevent overfitting
    beta=0.001,  ## default
    device=None,  ## if None, uses the GPU when available, otherwise the CPU; see torch.device

    ## Impute-refit loop params
    max_loops=100,  ## maximum number of refit loops
    patience=2,  ## loops to wait after best_dataset last improved; increase to avoid local extrema
    epochs_per_loop=None,  ## if None, same as epochs
    initial_lr_refit=None,  ## if None, picks up from the end of initial training
    decay_factor_refit=None,  ## if None, same as decay_factor
    beta_refit=None,  ## if None, same as beta
    verbose=False,
)

## OPTIONAL - plot the VAE architecture
plot_vae_architecture(
    model=vae,
    title=None,  ## plot title
    ## the colors below are the defaults
    color_shared="skyblue",
    color_unshared="lightcoral",
    color_latent="gold",
    color_input="lightgreen",
    color_output="lightgreen",
    figsize=(16, 8),
    return_fig=False,
)

Figure: output of plot_vae_architecture.

The CISS-VAE package also includes the option to perform automated hyperparameter tuning with Optuna.

See tutorial for more details.
