
PyTorch/TF1 implementation of a Variational Autoencoder for anomaly detection, following the paper "Variational Autoencoder based Anomaly Detection using Reconstruction Probability" by Jinwon An and Sungzoon Cho.

Project description

Variational autoencoder for anomaly detection



How to install

Python package way

The pip package contains only the model and the training_step:

    pip install vae-anomaly-detection

Hack this repository

a. Clone the repo

git clone git@github.com:Michedev/VAE_anomaly_detection.git

b. Install hatch

pip install hatch

c. Create the environment with PyTorch GPU support

hatch env create

or, for CPU only

hatch env create cpu

d. Run the training

hatch run train

or, on CPU

hatch run cpu:train

To list all the training parameters, pass --help to the training script (e.g. hatch run train --help); the full output is reproduced under "train.py help".

This version contains both the model and the training procedure.

How To Train your Model

  • Define your dataset in dataset.py, then in train.py replace the line train_set = rand_dataset() # set here your dataset with your own dataset.
  • Subclass VAEAnomalyDetection and define the methods make_encoder and make_decoder. The output of make_encoder should be a flat vector, while the output of make_decoder should match the shape of the input, with a reconstructed mean and variance for every input value (hence output_size * 2 in the tabular example below).
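As an illustration of the first step, here is a minimal sketch of a dataset that could replace rand_dataset(). The class name TabularDataset is hypothetical, and the sketch assumes that train.py consumes a torch Dataset whose items are single 1D float tensors of length input_size; check what the training step actually unpacks from a batch before relying on it. Standardizing each feature is a design choice added here, not something the package requires.

```python
import torch
from torch.utils.data import Dataset


class TabularDataset(Dataset):
    """Hypothetical replacement for rand_dataset(): wraps a 2D array of features.

    Assumption: each item is one 1D float tensor of length input_size.
    """

    def __init__(self, data):
        x = torch.as_tensor(data, dtype=torch.float32)
        # Standardize each feature with the statistics of the training data
        self.mean = x.mean(dim=0)
        self.std = x.std(dim=0).clamp_min(1e-8)
        self.x = (x - self.mean) / self.std

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx]
```

In train.py this would become train_set = TabularDataset(my_array), where my_array is your own (n_samples, input_size) data; keep self.mean and self.std so that test data can be scaled the same way.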

Make your model

Subclass VAEAnomalyDetection and define your encoder and decoder as in VAEAnomalyTabular:

import torch.nn as nn
# VAEAnomalyDetection is the base class provided by this package

class VAEAnomalyTabular(VAEAnomalyDetection):

    def make_encoder(self, input_size, latent_size):
        """
        Simple encoder for tabular data.
        To feed images to a VAE, write another encoder with Conv2d layers instead of Linear layers.
        :param input_size: number of input variables
        :param latent_size: number of output variables i.e. the size of the latent space since it's the encoder of a VAE
        :return: The untrained encoder model
        """
        return nn.Sequential(
            nn.Linear(input_size, 500),
            nn.ReLU(),
            nn.Linear(500, 200),
            nn.ReLU(),
            nn.Linear(200, latent_size * 2)
            # times 2 because this is the concatenated vector of latent mean and variance
        )

    def make_decoder(self, latent_size, output_size):
        """
        Simple decoder for tabular data.
        :param latent_size: size of input latent space
        :param output_size: number of output parameters. Must have the same value of input_size
        :return: the untrained decoder
        """
        return nn.Sequential(
            nn.Linear(latent_size, 200),
            nn.ReLU(),
            nn.Linear(200, 500),
            nn.ReLU(),
            nn.Linear(500, output_size * 2)  # times 2 because this is the concatenated vector of reconstructed mean and variance
        )
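The encoder docstring suggests using Conv2d layers for images. Below is a minimal sketch of such an encoder; make_conv_encoder and the layer sizes are hypothetical choices, not part of the package. It follows the same contract as the tabular encoder: the output is a flat vector of size latent_size * 2 (concatenated latent mean and variance). In the 2D case, input_size is the number of channels, per the train.py help.

```python
import torch
from torch import nn


def make_conv_encoder(in_channels: int, latent_size: int) -> nn.Module:
    """Hypothetical image encoder returning a flat (batch, 2 * latent_size) vector."""
    return nn.Sequential(
        nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),          # (batch, 64, 1, 1) for any image size
        nn.Flatten(),                     # (batch, 64)
        nn.Linear(64, latent_size * 2),   # concatenated latent mean and variance
    )
```

Inside a subclass, make_encoder(self, input_size, latent_size) could simply return make_conv_encoder(input_size, latent_size). A matching decoder would still need to output a mean and a variance per pixel, which this sketch does not cover.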

How to make predictions

Once the model is trained (assume for simplicity that the checkpoint is saved under saved_models/{train-datetime}/), load it and predict with this snippet:

import torch

# load X_test: the data to score, with the same features used for training
# load the saved parameters from a training run
model = VAEAnomalyTabular.load_checkpoint('saved_models/2022-01-06_15-12-23/last.ckpt')
outliers = model.is_anomaly(X_test)
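For intuition about what is_anomaly computes: in the paper, a point is anomalous when its reconstruction probability, the likelihood of x under the decoder's Gaussian averaged over L latent samples, falls below a threshold α. The sketch below illustrates that score; it is not the package's code, and is_anomaly may differ in details. The function names are hypothetical, and the sketch assumes the decoder outputs have already been turned into per-sample means and standard deviations. It works in log space because a product of many per-feature densities easily underflows.

```python
import math

import torch
from torch.distributions import Normal


def reconstruction_log_prob(x, recon_mu, recon_std):
    """Log of the reconstruction probability (1/L) * sum_l p(x | mu_l, std_l).

    x:         (batch, features) inputs
    recon_mu:  (L, batch, features) decoder means, one per latent sample
    recon_std: (L, batch, features) decoder standard deviations
    """
    # Gaussian log-likelihood per latent sample, summed over independent features
    log_p = Normal(recon_mu, recon_std).log_prob(x).sum(dim=-1)  # (L, batch)
    num_samples = recon_mu.shape[0]
    # log-mean-exp over the L samples, numerically stable
    return torch.logsumexp(log_p, dim=0) - math.log(num_samples)


def flag_anomalies(log_prob, log_alpha):
    """True where the reconstruction probability falls below the threshold alpha."""
    return log_prob < log_alpha
```

The threshold α is application-specific; one common (assumed, not prescribed) choice is a low percentile of the scores obtained on held-out normal data.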

train.py help

    usage: train.py [-h] --input-size INPUT_SIZE --latent-size LATENT_SIZE
                    [--num-resamples NUM_RESAMPLES] [--epochs EPOCHS] [--batch-size BATCH_SIZE]
                    [--device {cpu,gpu,tpu}] [--lr LR] [--no-progress-bar]
                    [--steps-log-loss STEPS_LOG_LOSS]
                    [--steps-log-norm-params STEPS_LOG_NORM_PARAMS]

    options:
    -h, --help            show this help message and exit
    --input-size INPUT_SIZE, -i INPUT_SIZE
                            Number of input features. In 1D case it is the vector length, in 2D
                            case it is the number of channels
    --latent-size LATENT_SIZE, -l LATENT_SIZE
                            Size of the latent space
    --num-resamples NUM_RESAMPLES, -L NUM_RESAMPLES
                            Number of resamples in the latent distribution during training
    --epochs EPOCHS, -e EPOCHS
                            Number of epochs to train for
    --batch-size BATCH_SIZE, -b BATCH_SIZE
    --device {cpu,gpu,tpu}, -d {cpu,gpu,tpu}, --accelerator {cpu,gpu,tpu}
                            Device to use for training. Can be cpu, gpu or tpu
    --lr LR               Learning rate
    --no-progress-bar
    --steps-log-loss STEPS_LOG_LOSS
                            Number of steps between each loss logging
    --steps-log-norm-params STEPS_LOG_NORM_PARAMS
                            Number of steps between each model parameters logging
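The --num-resamples (-L) option sets how many latent samples are drawn per input during training. A minimal sketch of how such an L typically enters a VAE loss is a Monte Carlo estimate of the negative ELBO. The function mc_elbo_loss is hypothetical, and mapping the raw encoder and decoder outputs to standard deviations with softplus is an assumption; the package may parameterize the variance differently (e.g. as a log-variance).

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def mc_elbo_loss(x, encoder, decoder, num_resamples):
    """Negative ELBO averaged over the batch, estimated with num_resamples samples.

    Assumes encoder(x) -> (batch, 2 * latent) and decoder(z) -> (..., 2 * features),
    each split into a mean and a raw scale that softplus makes positive.
    """
    mu_z, raw_z = encoder(x).chunk(2, dim=-1)
    q_z = Normal(mu_z, F.softplus(raw_z) + 1e-6)
    # rsample draws z = mu + std * eps, so gradients flow (reparameterization trick)
    z = q_z.rsample((num_resamples,))                   # (L, batch, latent)
    mu_x, raw_x = decoder(z).chunk(2, dim=-1)
    p_x = Normal(mu_x, F.softplus(raw_x) + 1e-6)
    recon = p_x.log_prob(x).sum(dim=-1).mean(dim=0)     # (batch,), averaged over L
    prior = Normal(torch.zeros_like(mu_z), torch.ones_like(mu_z))
    kl = kl_divergence(q_z, prior).sum(dim=-1)          # (batch,)
    return (kl - recon).mean()
```

A larger L lowers the variance of the reconstruction term at the cost of L decoder passes per input.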


Download files


Source Distribution

vae_anomaly_detection-2.0.1.tar.gz (9.4 kB)

Built Distribution

vae_anomaly_detection-2.0.1-py3-none-any.whl (9.6 kB)

File details

Details for the file vae_anomaly_detection-2.0.1.tar.gz.

File hashes

Hashes for vae_anomaly_detection-2.0.1.tar.gz
Algorithm Hash digest
SHA256 f0cff4fb1749d74ef6d9d68c112a55faff1fe34184ccc26d2847eb81c5de9a7d
MD5 3b604854c6c3cd380fc8245a90084cac
BLAKE2b-256 cf1bfa393ddca774665dcdd7da40fabe77acda1dff984e1aac43042a3e792f37


File details

Details for the file vae_anomaly_detection-2.0.1-py3-none-any.whl.

File hashes

Hashes for vae_anomaly_detection-2.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 f6624f3a8cc3a75e3f423ba0d070041de38796045c1178b6560cebc013b692a0
MD5 bb0cdfc50360cf73d75b23dc056eddfd
BLAKE2b-256 70b6ba45ebaaecbd6029feae5ae18c4d88212edaaa9fb5e7957dbbed3bf54d87
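To check a downloaded file against the SHA256 digests listed above, a minimal sketch using only Python's standard hashlib is shown below. The helper name sha256_of is hypothetical, and the commented-out check assumes the wheel was saved in the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 listed above for the 2.0.1 wheel
EXPECTED_WHEEL_SHA256 = "f6624f3a8cc3a75e3f423ba0d070041de38796045c1178b6560cebc013b692a0"
# Assumes the wheel is in the current directory:
# assert sha256_of("vae_anomaly_detection-2.0.1-py3-none-any.whl") == EXPECTED_WHEEL_SHA256
```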

