
TFAE is a TensorFlow extension for building and training different types of autoencoders.


TensorFlow Autoencoders (TFAE)

This package is a TensorFlow extension for building and training different types of autoencoders.


Motivation

Autoencoders (AEs) are widely used in various tasks, from computer vision to recommender systems and many others.

TensorFlow, flexible as it is, does not provide AE-specific tools.

The lack of such tools results in redundant code. Even the official tutorial on this topic includes explicit looping through training steps and manual computation of losses and gradients.

Code like this is difficult to maintain.

TFAE makes the process of building and training AEs easy yet flexible, while keeping the traditional TensorFlow style.

Current package

TFAE is a set of extensions for standard TensorFlow classes, such as Layers and Models. These can be used to build a wide range of models: shallow and deep, classical and variational.

In addition to layers and models, it contains extensions for Regularizers and Callbacks, which are used to control the training process and implement techniques such as β-VAE, cyclical β-annealing and others.

Installation

TFAE can be installed directly from PyPI:

pip install tfae

TensorFlow is the only dependency. Python 3.8-3.10 is required.

Usage

Let's take a quick example.

Here we build and train a shallow variational autoencoder regularized with KL divergence:

import tensorflow as tf

from tfae.models import Autoencoder
from tfae.bottlenecks import GaussianBottleneck
from tfae.regularizers import GaussianKLDRegularizer

model = Autoencoder(
    bottleneck=GaussianBottleneck(
        latent_dim=32,
        kernel_regularizer=GaussianKLDRegularizer(),
    ),
)

model.compile(...)

model.fit(...)

Now we can use this model to encode data:

encoded = model.encoder.predict(...)

Or generate new samples:

generated = model.decoder.predict(...)

Of course, it's possible to build deeper models and make the training process more sophisticated.

Check out more examples, or explore the API in the following section.

API reference

TFAE includes bottlenecks, models, regularizers and schedulers.

Bottlenecks

Bottlenecks are layers placed in the middle of a model, connecting the encoder with the decoder.

Every bottleneck extends BaseBottleneck, which in turn extends tf.keras.layers.Layer.

Vanilla bottleneck

VanillaBottleneck is a "semantic stub" for a dense layer, the simplest autoencoder bottleneck.

Variational bottlenecks

VariationalBottleneck is a BaseBottleneck subclass for building models that implement variational inference. TFAE includes two bottlenecks for variational inference, Gaussian and Bernoulli:

classDiagram
    `tf.keras.layers.Layer` <|-- BaseBottleneck
    BaseBottleneck <|-- VanillaBottleneck
    BaseBottleneck <|-- VariationalBottleneck
    VariationalBottleneck <|-- GaussianBottleneck
    VariationalBottleneck <|-- BernoulliBottleneck

    BaseBottleneck : int latent_dim
    VariationalBottleneck : int parameters
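Variational bottlenecks sample the latent code instead of computing it deterministically. As a rough, framework-free illustration, the sketch below applies the standard reparameterization trick for a Gaussian latent, z = μ + σ·ε with ε ~ N(0, 1). It assumes, hypothetically, that the encoder emits 2 × latent_dim values split into means and log-variances (one plausible reading of the `parameters` field in the diagram). The helper `sample_gaussian_latent` is made up for this example and is not TFAE's implementation.

```python
import math
import random
from typing import List, Optional


def sample_gaussian_latent(
    encoder_output: List[float],
    latent_dim: int,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical helper: draw z ~ N(mu, sigma^2) via the reparameterization trick.

    Assumes encoder_output holds 2 * latent_dim values:
    the first half are means, the second half are log-variances.
    """
    if len(encoder_output) != 2 * latent_dim:
        raise ValueError("expected 2 * latent_dim encoder outputs")
    rng = rng or random.Random()
    mu = encoder_output[:latent_dim]
    log_var = encoder_output[latent_dim:]
    # z = mu + sigma * eps, with sigma = exp(log_var / 2) and eps ~ N(0, 1).
    # All randomness lives in eps, so z is a smooth function of mu and log_var.
    return [
        m + math.exp(lv / 2.0) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, log_var)
    ]


# A 3-dimensional latent space, seeded for reproducibility.
z = sample_gaussian_latent([0.5, -1.0, 2.0, 0.0, 0.0, 0.0], latent_dim=3, rng=random.Random(0))
print(len(z))  # 3
```

In a real model the same computation would run on tensors so that gradients flow through μ and log σ².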

Models

TFAE includes two subclasses of tf.keras.Model: Autoencoder and DeepAutoencoder.

classDiagram
    `tf.keras.Model` <|-- Autoencoder
    Autoencoder <|-- DeepAutoencoder

    Autoencoder : BaseBottleneck bottleneck
    Autoencoder : make_encoder()
    Autoencoder : make_decoder()
    DeepAutoencoder : Callable add_hidden

Autoencoder

Autoencoder represents the simplest form of autoencoder, with only one hidden layer acting as the bottleneck. See the usage example above.

DeepAutoencoder

DeepAutoencoder extends Autoencoder and allows building deeper models in a functional way: its add_hidden callable constructs additional hidden layers.

Let's take a quick look at how add_hidden works.

It takes four parameters:

  • input layer
  • number of current layer
  • shape of the input layer
  • dimensionality of the bottleneck

It returns a tuple of the new layer and a boolean indicating whether the current layer is the last one.

This function is applied to both the encoder and the decoder (to the decoder in a "mirror manner").

The following example demonstrates how to create an encoder and a decoder with two hidden layers each, both with a pyramidal structure:

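from __future__ import annotations  # allows tuple[...] annotations on Python 3.8

import tensorflow as tf

from tfae.models import DeepAutoencoder

def add_hidden(
    x: tf.keras.layers.Layer,
    layer_num: int,
    input_shape: tf.TensorShape,
    latent_dim: int,
) -> tuple[tf.keras.layers.Layer, bool]:

    number_of_hidden_layers = 2

    # Interpolate layer widths geometrically between the input
    # dimensionality and latent_dim, which yields a pyramid shape.
    divisor = (latent_dim / input_shape[-1]) ** (layer_num / (number_of_hidden_layers + 1))
    units = int(divisor * input_shape[-1])

    x = tf.keras.layers.Dense(units)(x)

    # True tells the model that no more hidden layers should be added.
    return x, layer_num == number_of_hidden_layers

model = DeepAutoencoder(
    bottleneck=...,
    add_hidden=add_hidden,
)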
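To make the add_hidden protocol concrete without TensorFlow, the sketch below represents each layer only by its number of units and drives the callback with a hypothetical loop that starts at layer_num=1 and stops once the callback reports the last layer. The driver `build_hidden_units`, the 1-based numbering, and reading the decoder's "mirror manner" as the reversed list of sizes are assumptions about how DeepAutoencoder might behave, not its actual code. The callback uses round() instead of int() so floating-point error cannot shave a unit off a layer.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in: a "layer" is represented only by its number of units.
AddHidden = Callable[[int, int, Tuple[int, ...], int], Tuple[int, bool]]


def pyramid_add_hidden(
    x: int,
    layer_num: int,
    input_shape: Tuple[int, ...],
    latent_dim: int,
) -> Tuple[int, bool]:
    """Same geometric interpolation as the add_hidden example, on plain ints."""
    number_of_hidden_layers = 2
    scale = (latent_dim / input_shape[-1]) ** (layer_num / (number_of_hidden_layers + 1))
    units = round(scale * input_shape[-1])  # round() guards against float error
    return units, layer_num == number_of_hidden_layers


def build_hidden_units(
    add_hidden: AddHidden,
    input_shape: Tuple[int, ...],
    latent_dim: int,
) -> List[int]:
    """Hypothetical driver: call add_hidden until it reports the last layer."""
    x = input_shape[-1]  # the "input layer", represented by its width
    units: List[int] = []
    layer_num = 1
    while True:
        x, is_last = add_hidden(x, layer_num, input_shape, latent_dim)
        units.append(x)
        if is_last:
            return units
        layer_num += 1


encoder_units = build_hidden_units(pyramid_add_hidden, (512,), latent_dim=8)
decoder_units = encoder_units[::-1]  # assumed "mirror manner"
print(encoder_units, decoder_units)  # [128, 32] [32, 128]
```

With a 512-dimensional input and latent_dim=8, each hidden layer shrinks by the same factor of 4, so the bottleneck of 8 units follows naturally after the 32-unit layer.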

Custom models

Custom models can be made by extending the Autoencoder class. See an example.

Regularizers

It often proves useful to regularize the bottleneck, so that the encoder can learn a better, disentangled representation.

TFAE includes:

  • L2Regularizer for VanillaBottleneck
  • GaussianKLDRegularizer and GaussianReversedKLDRegularizer for GaussianBottleneck
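For a Gaussian bottleneck with per-dimension means μ_i and variances σ_i² over d = latent_dim dimensions, the usual closed-form KL divergence from the standard normal prior is the first line below; the second line is the divergence taken in the opposite direction. Reading "Reversed" as swapping the two arguments of the KL is an assumption based on the class name, and the source does not say how TFAE reduces the result over a batch.

```latex
\begin{aligned}
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big)
  &= -\frac{1}{2}\sum_{i=1}^{d}\left(1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2\right),\\
D_{\mathrm{KL}}\big(\mathcal{N}(0,1)\,\|\,\mathcal{N}(\mu,\sigma^2)\big)
  &= \sum_{i=1}^{d}\left(\log\sigma_i + \frac{1 + \mu_i^2}{2\sigma_i^2} - \frac{1}{2}\right).
\end{aligned}
```

Both terms vanish when every μ_i = 0 and σ_i = 1, i.e. when the posterior matches the prior.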

Every TFAE regularizer extends BaseRegularizer, which has a property beta: float, the regularization factor:

classDiagram
    `tf.keras.regularizers.Regularizer` <|-- BaseRegularizer
    BaseRegularizer <|-- L2Regularizer
    BaseRegularizer <|-- GaussianKLDRegularizer
    BaseRegularizer <|-- GaussianReversedKLDRegularizer

    BaseRegularizer: float beta

A custom regularizer can be created by extending BaseRegularizer.
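The role of beta can be sketched without TensorFlow: a regularizer maps bottleneck parameters to a scalar penalty that is scaled by the regularization factor. The class `ScaledGaussianKLD` below is hypothetical and only mirrors the idea; a real TFAE regularizer would extend BaseRegularizer and operate on tensors, and its exact call signature is not documented here. The penalty is the Gaussian KL divergence from N(0, 1), summed over latent dimensions.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ScaledGaussianKLD:
    """Hypothetical stand-in for a beta-weighted KL regularizer (not TFAE's class)."""

    beta: float = 1.0

    def __call__(self, mu: Sequence[float], log_var: Sequence[float]) -> float:
        # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions:
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld = -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))
        # The regularization factor scales the whole penalty.
        return self.beta * kld


reg = ScaledGaussianKLD(beta=0.5)
print(reg([1.0, 0.0], [0.0, 0.0]))  # 0.25
reg.beta = 0.0  # e.g. switched off during a warm-up phase
print(reg([1.0, 0.0], [0.0, 0.0]))  # 0.0
```

Keeping beta as mutable state is what makes scheduling possible: something outside the regularizer can change the factor between epochs without rebuilding the model.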

Schedulers

Recent papers have shown that a constant regularization factor can be an obstacle to learning a better latent representation.

These papers suggest varying the regularization factor, let's call it β, over time.

TFAE contains DASRScheduler, which can handle the different schedules proposed in these papers.

Every scheduler extends BaseScheduler which extends tf.keras.callbacks.Callback:

classDiagram
    `tf.keras.callbacks.Callback` <|-- BaseScheduler
    BaseScheduler <|-- DASRScheduler

    BaseScheduler: calc()

    DASRScheduler: float start_value
    DASRScheduler: float end_value
    DASRScheduler: int delay
    DASRScheduler: int attack
    DASRScheduler: int sustain
    DASRScheduler: int release
    DASRScheduler: int cycles

DASRScheduler

"DASR" stands for Delay, Attack, Sustain, Release.

Let's say that in "phase 1" we want to keep β=0 for the first 5 epochs and then gradually raise it to β=1 over 10 more epochs. In "phase 2" we'll keep β=1 until the loss stops improving:

import tensorflow as tf

from tfae.models import Autoencoder
from tfae.bottlenecks import GaussianBottleneck
from tfae.regularizers import GaussianKLDRegularizer
from tfae.schedulers import DASRScheduler

# Creating a scheduler which will keep β=0 for 5 epochs
# and then gradually raise it to β=1 over 10 more epochs:
scheduler = DASRScheduler(
    start_value=0.0,
    end_value=1.0,
    delay=5,
    attack=10,
)

# Note how we pass scheduler.value to the regularizer:
model = Autoencoder(
    bottleneck=GaussianBottleneck(
        latent_dim=32,
        kernel_regularizer=GaussianKLDRegularizer(
            beta=scheduler.value,
        ),
    ),
)

model.compile(...)

# Phase 1.
#
# The scheduler has an auto-calculated attribute "duration"
# which tells how many epochs it takes
# to go through all scheduled values of β.
#
# We also pass the scheduler as a callback
# so that it can update β during training:
model.fit(
    ...,
    epochs=scheduler.duration,
    callbacks=[
        scheduler,
    ],
)

# Phase 2.
#
# Here we continue training until the loss stops improving.
# In Keras, "epochs" is the index of the final epoch,
# so it must be set greater than initial_epoch:
model.fit(
    ...,
    initial_epoch=scheduler.duration,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(...),
    ],
)

Let's take one more example and implement a schedule for cyclical β-annealing:

scheduler = DASRScheduler(
    start_value=0.0,
    end_value=1.0,
    attack=10,
    sustain=10,
    cycles=4,
)
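One plausible reading of the DASR parameters, written as a pure function of a 0-based epoch index: hold start_value for `delay` epochs, ramp linearly to end_value over `attack` epochs, hold it for `sustain` epochs, ramp back over `release` epochs, and repeat for `cycles` cycles. The linear ramps, per-epoch granularity, the default of one cycle and the duration formula are assumptions for illustration; DASRScheduler's real internals (including its calc() method) may differ. The helpers `dasr_value` and `dasr_duration` are hypothetical.

```python
def dasr_duration(delay: int = 0, attack: int = 0, sustain: int = 0,
                  release: int = 0, cycles: int = 1) -> int:
    """Assumed formula: every cycle runs through all four stages once."""
    return cycles * (delay + attack + sustain + release)


def dasr_value(epoch: int, start_value: float, end_value: float,
               delay: int = 0, attack: int = 0, sustain: int = 0,
               release: int = 0, cycles: int = 1) -> float:
    """Hypothetical DASR schedule: beta for a given 0-based epoch."""
    period = delay + attack + sustain + release
    if period == 0 or epoch >= cycles * period:
        # After the schedule ends, keep the value of its final stage.
        return start_value if release > 0 else end_value
    t = epoch % period  # position inside the current cycle
    if t < delay:  # Delay: hold the start value
        return start_value
    t -= delay
    if t < attack:  # Attack: linear ramp that reaches end_value on its last epoch
        return start_value + (end_value - start_value) * (t + 1) / attack
    t -= attack
    if t < sustain:  # Sustain: hold the end value
        return end_value
    t -= sustain
    # Release: linear ramp back to start_value (t < release here, so release > 0)
    return end_value + (start_value - end_value) * (t + 1) / release


# The two-phase example: delay=5, attack=10.
phase1 = [dasr_value(e, 0.0, 1.0, delay=5, attack=10)
          for e in range(dasr_duration(delay=5, attack=10))]
print(len(phase1), phase1[:6], phase1[-1])  # 15 [0.0, 0.0, 0.0, 0.0, 0.0, 0.1] 1.0

# The cyclical example: attack=10, sustain=10, cycles=4.
print(dasr_duration(attack=10, sustain=10, cycles=4))  # 80
print(dasr_value(20, 0.0, 1.0, attack=10, sustain=10, cycles=4))  # 0.1
```

Under these assumptions the two-phase example spans 15 epochs, which matches passing scheduler.duration as initial_epoch for phase 2, and the cyclical example restarts its ramp every 20 epochs.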

Contribution

Feel free to create issues or open pull-requests.

Download files

Source Distribution

tfae-1.0.2.tar.gz (12.6 kB)

Built Distribution

tfae-1.0.2-py3-none-any.whl (12.7 kB)