
🧶 Disent

A modular disentangled representation learning framework built with PyTorch Lightning


Visit the docs for more info, or browse the releases.

Contributions are welcome!


⚠️ Development of official PyPI versions has been paused temporarily until I finish my degree. Large changes can be expected afterwards for release 0.2.0, fixing most issues with the library.




Overview

Disent is a modular disentangled representation learning framework for auto-encoders, built upon PyTorch Lightning. The framework consists of composable components that can be combined to build and benchmark vision-based disentanglement tasks.

The name of the framework is derived from both disentanglement and scientific dissent.

Get started with disent by installing it with pip install disent, or by cloning this repository.

Goals

Disent aims to fill the following criteria:

  1. Provide high-quality, readable, consistent and easily comparable implementations of frameworks
  2. Highlight differences between framework implementations by overriding hooks and minimising duplicate code
  3. Use best practices, e.g. torch.distributions
  4. Be extremely flexible & configurable
  5. Support low-memory systems

Citing Disent

Please use the following citation if you use Disent in your own research:

@Misc{Michlo2021Disent,
  author =       {Nathan Juraj Michlo},
  title =        {Disent - A modular disentangled representation learning framework for pytorch},
  howpublished = {Github},
  year =         {2021},
  url =          {https://github.com/nmichlo/disent}
}

Architecture

The disent directory structure:

  • disent/dataset: dataset wrappers, datasets & sampling strategies
    • disent/dataset/data: raw datasets
    • disent/dataset/sampling: sampling strategies for DisentDataset
  • disent/frameworks: frameworks, including Auto-Encoders and VAEs
  • disent/metrics: metrics for evaluating disentanglement using ground truth datasets
  • disent/model: common encoder and decoder models used for VAE research
  • disent/nn: torch components for building models, including layers, transforms, losses and general maths
  • disent/schedule: annealing schedules that can be registered to a framework
  • disent/util: helper classes, functions and callbacks; anything unrelated to a PyTorch system/model/framework
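
These packages correspond directly to the import paths used in the Python example later in this README, for instance:

# imports taken from the example further down, grouped by package
from disent.dataset import DisentDataset              # disent/dataset
from disent.dataset.data import XYObjectData          # disent/dataset/data
from disent.dataset.sampling import SingleSampler     # disent/dataset/sampling
from disent.frameworks.vae import BetaVae             # disent/frameworks
from disent.metrics import metric_dci, metric_mig     # disent/metrics
from disent.model import AutoEncoder                  # disent/model
from disent.nn.transform import ToStandardisedTensor  # disent/nn
from disent.schedule import CyclicSchedule            # disent/schedule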

Please Note The API Is Still Unstable ⚠️

Disent is still under active development. Features and APIs are not considered stable and should be expected to change! A limited set of tests currently exists; these will be expanded over time.

Hydra Experiment Directories

Easily run experiments with hydra config. Note that these files are not included when installing from pip.

  • experiment/run.py: entrypoint for running basic experiments with hydra config
  • experiment/config: root folder for hydra config files
  • experiment/util: various helper code for experiments

Features

Disent includes implementations of modules, metrics and datasets from various papers. Please note that items marked with a "🧵" were introduced in, and are unique to, disent!

Frameworks

Many popular disentanglement frameworks still need to be added. Please submit an issue if you have a request for an additional framework.

Todo:

  • FactorVAE
  • GroupVAE
  • MLVAE

Metrics

Some popular metrics still need to be added. Please submit an issue if you wish to add your own, or if you have a request. (Metrics such as MIG and DCI are already included; see the Python example below.)

Todo:

Datasets

Various common datasets used in disentanglement research are included, with hash verification and automatic chunk-size optimization of the underlying HDF5 formats for low-memory disk-based access. A minimal loading sketch follows the list below.

  • Ground Truth:
    • Cars3D
    • dSprites
    • MPI3D
    • SmallNORB
    • Shapes3D
  • Ground Truth Synthetic:
    • 🧵 XYObject: A simplistic version of dSprites with a single square.

[Image: XYObject dataset factor traversals]
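
As a quick sketch, a raw dataset from disent/dataset/data can be used directly; XYObjectData (also used in the full example below) is generated on the fly, so nothing needs to be downloaded:

from disent.dataset.data import XYObjectData

# XYObjectData is procedurally generated, no downloads are required
data = XYObjectData()

# x_shape is the observation shape, later used to construct the encoder & decoder
print(data.x_shape)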

Input Transforms + Input/Target Augmentations

  • Input based transforms are supported, as shown in the sketch below.
  • Input and Target CPU and GPU based augmentations are supported.
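
A minimal sketch of attaching an input transform, using only names from the full example below. Note that the commented-out augment= keyword is an assumption about the DisentDataset signature based on the bullets above; check the docs for the exact API:

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.nn.transform import ToStandardisedTensor

# the input transform converts raw numpy images into standardised tensors,
# exactly as in the full training example below
dataset = DisentDataset(
    dataset=XYObjectData(),
    sampler=SingleSampler(),
    transform=ToStandardisedTensor(),
    # augment=...,  # ASSUMPTION: optional input/target augmentations (CPU or GPU based)
)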

Schedules & Annealing

Hyper-parameter annealing is supported through the use of schedules. The currently implemented schedules include:

  • Linear Schedule
  • Cyclic Schedule
  • Cosine Wave Schedule
  • Various other wrapper schedules
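
As a conceptual sketch of how annealing works (plain Python, not the disent API): a schedule produces a ratio per training step, and the framework multiplies the registered hyper-parameter's saved initial value by that ratio, as described in the example below. The linear_ratio helper here is purely illustrative:

# illustrative only: ramp a value linearly from 0 up to its target over the first 1000 steps
def linear_ratio(step: int, start_step: int = 0, end_step: int = 1000) -> float:
    t = (step - start_step) / max(end_step - start_step, 1)
    return min(max(t, 0.0), 1.0)  # clamp to [0, 1]

beta_initial = 4.0  # the value saved from the config
for step in (0, 250, 500, 1000, 2000):
    print(f'step={step:<4d} beta={beta_initial * linear_ratio(step):.2f}')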

Examples

Python Example

The following is a basic working example of disent that trains a BetaVAE with a cyclic beta schedule and evaluates the trained model with various metrics.

Basic Example

import os
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.schedule import CyclicSchedule

# create the dataset & dataloaders
# - ToStandardisedTensor transforms images from numpy arrays to tensors and performs checks
data = XYObjectData()
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count())

# create the BetaVAE model
# - adjusting the beta, learning rate, and representation size.
module = BetaVae(
  model=AutoEncoder(
    # z_multiplier is needed to output mu & logvar when parameterising normal distribution
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=10),
  ),
  cfg=BetaVae.cfg(
    optimizer='adam', optimizer_kwargs=dict(lr=1e-3),
    loss_reduction='mean_sum', beta=4,
  )
)

# cyclic schedule for target 'beta' in the config/cfg. The initial value from the
# config is saved and multiplied by the ratio from the schedule on each step.
# - based on: https://arxiv.org/abs/1903.10145
module.register_schedule(
  'beta', CyclicSchedule(
    period=1024,  # repeat every: trainer.global_step % period
  )
)

# train model
# - for 2048 batches/steps
trainer = pl.Trainer(max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False)
trainer.fit(module, dataloader)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
def get_repr(x):
    return module.encode(x.to(module.device))

metrics = {
  **metric_dci(dataset, get_repr, num_train=1000, num_test=500, show_progress=True),
  **metric_mig(dataset, get_repr, num_train=2000),
}

# evaluate
print('metrics:', metrics)

Visit the docs for more examples!

Hydra Config Example

The entrypoint for basic experiments is experiment/run.py.

Some configuration will be required, but basic experiments can be adjusted by modifying the Hydra 1.0 config files in experiment/config (please note that Hydra 1.1 is not yet supported).

Modifying the main experiment/config/config.yaml is all you need for most basic experiments. The main config file contains a defaults list, with entries corresponding to yaml configuration files (config options) located in the subfolders (config groups) of experiment/config, i.e. experiment/config/<config_group>/<option>.yaml.

Config Defaults Example

defaults:
  # system
  - framework: adavae
  - model: vae_conv64
  - optimizer: adam
  - schedule: none
  # data
  - dataset: xyobject
  - dataset_sampling: full_bb
  - augment: none
  # runtime
  - metrics: fast
  - run_length: short
  - run_location: local
  - run_callbacks: vis
  - run_logging: wandb

# <rest of config.yaml left out>
...

Easily modify any of these values to adjust how the basic experiment will be run. For example, change framework: adavae to framework: betavae, or change the dataset from xyobject to shapes3d. Add new options by adding new yaml files in the config group folders.

Weights and Biases logging is enabled via run_logging: wandb (as in the defaults above); set run_logging: none to disable it. You will need to log in to W&B from the command line first. W&B logging supports visualisations of latent traversals.


Why?

  • Created as part of my Computer Science MSc, scheduled for completion in 2021.
  • I needed custom, high-quality implementations of various VAEs.
  • A PyTorch version of disentanglement_lib.
  • I didn't have time to wait for Weakly-Supervised Disentanglement Without Compromises to release their code as part of disentanglement_lib. (As of September 2020 it has been released, but it has unresolved discrepancies.)
  • disentanglement_lib still uses the outdated TensorFlow 1.0, and the flow of data is unintuitive because of its use of Gin Config.
